This repository contains the code to train a summarization model using the BART architecture. Training and evaluation are performed on a subset of the CNN_Dailymail dataset, and the model is fine-tuned to summarize news articles.
- Environment Setup: Initial setup, including installation of the required libraries such as PyTorch Lightning, Transformers, and TorchMetrics.
- Data Preparation: Loading the CNN_Dailymail dataset from parquet files and tokenizing it with the BART tokenizer (see the data sketch after this list).
- Custom Dataset: A custom PyTorch Dataset class that feeds the tokenized examples to the DataLoader efficiently.
- Model Configuration: Configuring the BART model for summarization, including hyperparameters such as the learning rate, batch sizes, and number of epochs (see the LightningModule sketch below).
- Training: Training loop for the summarization model, with loss tracking and logging.
- Validation: A validation step to monitor the model's performance during training.
- Testing: A test step that evaluates summarization quality and computes ROUGE scores (see the ROUGE sketch below).
- Model Saving: Saving the fine-tuned model and tokenizer for later use (see the saving sketch below).
- Example Usage: Generating a summary for a new article and evaluating it with ROUGE scores (see the inference sketch below).
- Google Drive Access: Mounting Google Drive in Colab for loading and saving models (covered in the saving sketch below).
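
The sketch below illustrates the data-preparation and custom-dataset steps: reading a parquet split with pandas, tokenizing with the BART tokenizer, and wrapping the result in a PyTorch `Dataset`. The file path, column names (`article`, `highlights`), checkpoint name, and sequence lengths are assumptions for illustration, not values taken from the notebook.

```python
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")  # assumed checkpoint

class SummarizationDataset(Dataset):
    """Tokenizes article/summary pairs on the fly for the DataLoader."""

    def __init__(self, df, tokenizer, max_input_len=512, max_target_len=128):
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_input_len = max_input_len
        self.max_target_len = max_target_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        inputs = self.tokenizer(
            row["article"],
            max_length=self.max_input_len,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        targets = self.tokenizer(
            row["highlights"],
            max_length=self.max_target_len,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        labels = targets["input_ids"].squeeze(0)
        labels[labels == self.tokenizer.pad_token_id] = -100  # ignore padding in the loss
        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "labels": labels,
        }

train_df = pd.read_parquet("data/train.parquet")  # assumed path to the parquet subset
train_loader = DataLoader(SummarizationDataset(train_df, tokenizer), batch_size=4, shuffle=True)
```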
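
A minimal PyTorch Lightning sketch of the model-configuration, training, and validation steps. The checkpoint name, learning rate, and epoch count are illustrative placeholders rather than the notebook's actual settings.

```python
import pytorch_lightning as pl
import torch
from transformers import BartForConditionalGeneration

class BartSummarizer(pl.LightningModule):
    """Wraps BART for summarization with train/validation loss logging."""

    def __init__(self, model_name="facebook/bart-base", lr=2e-5):
        super().__init__()
        self.save_hyperparameters()
        self.model = BartForConditionalGeneration.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask, labels=None):
        return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

    def training_step(self, batch, batch_idx):
        loss = self(**batch).loss
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self(**batch).loss
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)

# summarizer = BartSummarizer()
# trainer = pl.Trainer(max_epochs=3, accelerator="auto")
# trainer.fit(summarizer, train_loader, val_loader)
```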
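
The notebook computes ROUGE in its test step; the helper below sketches the same idea with TorchMetrics' `ROUGEScore`, applied to the underlying `BartForConditionalGeneration` model. The generation settings (`num_beams`, `max_length`) are illustrative defaults.

```python
import torch
from torchmetrics.text.rouge import ROUGEScore

def evaluate_rouge(model, tokenizer, data_loader, device="cpu", max_length=128):
    """Generates summaries for each batch and aggregates ROUGE scores."""
    rouge = ROUGEScore()
    model.eval().to(device)
    with torch.no_grad():
        for batch in data_loader:
            generated = model.generate(
                batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                max_length=max_length,
                num_beams=4,
            )
            preds = tokenizer.batch_decode(generated, skip_special_tokens=True)
            labels = batch["labels"].clone()
            labels[labels == -100] = tokenizer.pad_token_id  # undo the loss masking
            targets = tokenizer.batch_decode(labels, skip_special_tokens=True)
            rouge.update(preds, targets)
    return rouge.compute()  # dict of ROUGE-1/2/L precision, recall, and F-measure

# scores = evaluate_rouge(summarizer.model, tokenizer, test_loader)
```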
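
Saving and Google Drive access can be sketched together as below. `summarizer` refers to the trained module from the training sketch above, and the Drive folder name is a placeholder.

```python
from google.colab import drive

# Mount Google Drive (Colab only) so the trained model persists across sessions.
drive.mount("/content/drive")

save_dir = "/content/drive/MyDrive/bart-cnn-summarizer"  # assumed output folder
summarizer.model.save_pretrained(save_dir)  # underlying BartForConditionalGeneration
tokenizer.save_pretrained(save_dir)
```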
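
Finally, a sketch of the example-usage step: loading the saved model, generating a summary for a new article, and scoring it against a reference with ROUGE. The article text, reference summary, and saved-model path are placeholders.

```python
from transformers import BartForConditionalGeneration, BartTokenizer
from torchmetrics.text.rouge import ROUGEScore

save_dir = "/content/drive/MyDrive/bart-cnn-summarizer"  # assumed saved-model path
model = BartForConditionalGeneration.from_pretrained(save_dir)
tokenizer = BartTokenizer.from_pretrained(save_dir)

article = "..."    # new article text goes here
reference = "..."  # optional human-written reference summary

inputs = tokenizer(article, max_length=512, truncation=True, return_tensors="pt")
summary_ids = model.generate(**inputs, max_length=128, num_beams=4, early_stopping=True)
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
print(summary)

print(ROUGEScore()(summary, reference))  # ROUGE scores against the reference
```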