This repository implements the paper "Get To The Point: Summarization with Pointer-Generator Networks" (See, Liu and Manning, ACL 2017).
It covers the following tasks:
- The input to the model is a text paragraph from which a summary is generated.
- Build a Pointer-Generator Network architecture from scratch in PyTorch (another framework may be used instead, but not an off-the-shelf library implementation).
- The encoder is a BiLSTM and the decoder is a single-layer LSTM.
- Compute the attention distribution from the decoder state and the encoder hidden states.
- Compute the context vector as the sum of the encoder hidden states weighted by the attention distribution.
- At each decoder timestep, compute the generation probability p_gen.
- Combine the vocabulary distribution and the attention distribution in a weighted sum (weights p_gen and 1 - p_gen) to obtain the final distribution used for prediction.
- Evaluate the model with the ROUGE metric.
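The attention, context-vector and p_gen steps in the list above follow the equations of See et al. (2017): e_i^t = v^T tanh(W_h h_i + W_s s_t + b_attn), a^t = softmax(e^t), h*_t = sum_i a_i^t h_i, and p_gen = sigmoid(w_h^T h*_t + w_s^T s_t + w_x^T x_t + b_ptr), where x_t is the decoder input. The following is a minimal NumPy sketch of one decoder timestep, not the repository's PyTorch code; the function names, argument shapes and random parameters are assumptions chosen for illustration.

```python
import numpy as np


def softmax(x):
    # Numerically stable softmax over the last axis
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def attention_step(enc_states, dec_state, W_h, W_s, b_attn, v):
    """Additive attention for one decoder timestep (illustrative shapes).

    enc_states: (src_len, 2*hidden) BiLSTM encoder hidden states h_i
    dec_state:  (hidden,) decoder LSTM state s_t
    W_h: (attn_dim, 2*hidden), W_s: (attn_dim, hidden), b_attn and v: (attn_dim,)
    """
    # e_i^t = v^T tanh(W_h h_i + W_s s_t + b_attn), for every source position i
    scores = np.tanh(enc_states @ W_h.T + dec_state @ W_s.T + b_attn) @ v
    attn = softmax(scores)        # a^t, shape (src_len,)
    context = attn @ enc_states   # h*_t = sum_i a_i^t h_i, shape (2*hidden,)
    return attn, context


def generation_prob(context, dec_state, dec_input, w_h, w_s, w_x, b_ptr):
    # p_gen = sigmoid(w_h^T h*_t + w_s^T s_t + w_x^T x_t + b_ptr), a scalar in (0, 1)
    return sigmoid(context @ w_h + dec_state @ w_s + dec_input @ w_x + b_ptr)


# Toy example with random parameters (src_len=5, hidden=4, attn_dim=6, emb_dim=3)
rng = np.random.default_rng(0)
src_len, hidden, attn_dim, emb_dim = 5, 4, 6, 3
enc_states = rng.standard_normal((src_len, 2 * hidden))
dec_state = rng.standard_normal(hidden)
attn, context = attention_step(
    enc_states, dec_state,
    rng.standard_normal((attn_dim, 2 * hidden)),
    rng.standard_normal((attn_dim, hidden)),
    rng.standard_normal(attn_dim),
    rng.standard_normal(attn_dim),
)
p_gen = generation_prob(
    context, dec_state, rng.standard_normal(emb_dim),
    rng.standard_normal(2 * hidden), rng.standard_normal(hidden),
    rng.standard_normal(emb_dim), 0.0,
)
```

The encoder states have width 2*hidden because the BiLSTM concatenates its forward and backward states; in a PyTorch model the same computations would typically be batched and the weights held in `nn.Linear` layers.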
The pointer-generator network can be considered a simple extension of the attention model: it is a hybrid of a baseline sequence-to-sequence attention model and a pointer network. Words can either be generated from a fixed vocabulary or copied from the source text by pointing.
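Copying by pointing means the final distribution is P(w) = p_gen * P_vocab(w) + (1 - p_gen) * sum of a_i^t over source positions i where w_i = w, taken over an extended vocabulary that also contains the out-of-vocabulary (OOV) words of the current article. Below is a minimal NumPy sketch of that mixing step under these assumptions; `source_to_extended_ids`, `final_distribution` and the toy vocabulary are hypothetical names for illustration, not the repository's code.

```python
import numpy as np


def source_to_extended_ids(tokens, word2id):
    """Map source tokens to ids; in-article OOV words get temporary ids >= vocab size."""
    vocab_size = len(word2id)
    ids, oovs = [], []
    for tok in tokens:
        if tok in word2id:
            ids.append(word2id[tok])
        else:
            if tok not in oovs:
                oovs.append(tok)
            ids.append(vocab_size + oovs.index(tok))
    return np.array(ids), oovs


def final_distribution(p_gen, vocab_dist, attn, src_ext_ids, n_oov):
    """P(w) = p_gen * P_vocab(w) + (1 - p_gen) * sum of a_i over positions i with w_i = w."""
    vocab_size = vocab_dist.shape[0]
    final = np.zeros(vocab_size + n_oov)
    # Generation part; OOV slots stay zero because P_vocab has no entry for them
    final[:vocab_size] = p_gen * vocab_dist
    # Copy part; np.add.at accumulates when the same word occurs several times
    np.add.at(final, src_ext_ids, (1.0 - p_gen) * attn)
    return final


word2id = {"the": 0, "cat": 1, "sat": 2}          # toy fixed vocabulary
src_ids, oovs = source_to_extended_ids(["cat", "zorp", "cat"], word2id)
dist = final_distribution(0.8, np.array([0.5, 0.3, 0.2]),
                          np.array([0.6, 0.3, 0.1]), src_ids, len(oovs))
# dist is approximately [0.4, 0.38, 0.16, 0.06]; index 3 is the copied OOV "zorp"
```

Because both mixture components are probability distributions, the result still sums to 1, and an OOV word such as "zorp" can only receive probability through the copy term.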
Repository contents:
- data_reduction.py: selects a suitable subset of the full dataset used in this project.
- summary_gen.py: the decoder outputs and generated summaries are stored in a CSV file; this script opens that CSV, prints the summaries in a readable form and prints the ROUGE scores.
- extra: folder containing training logs, .ipynb (notebook) versions of the code and other extra files.
- pgn_summzarization.py: a single, fully commented file that handles the datasets, constructs the data loaders, builds the model as specified above and starts training.
- README: what you're reading right now ^_^
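summary_gen.py prints ROUGE scores, but the README does not say which ROUGE implementation it uses. To show what those scores measure, here is a minimal pure-Python sketch of ROUGE-N (clipped n-gram overlap) and ROUGE-L (longest common subsequence) with whitespace tokenization and lowercasing; these simplifications are assumptions, and reported numbers should come from an established ROUGE implementation rather than this sketch.

```python
from collections import Counter


def ngrams(tokens, n):
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def prf(overlap, cand_total, ref_total):
    # Precision, recall and F1 from an overlap count
    p = overlap / max(cand_total, 1)
    r = overlap / max(ref_total, 1)
    f1 = 2 * p * r / (p + r) if p + r > 0 else 0.0
    return p, r, f1


def rouge_n(candidate, reference, n=1):
    cand = ngrams(candidate.lower().split(), n)
    ref = ngrams(reference.lower().split(), n)
    overlap = sum((cand & ref).values())  # counts clipped to the reference
    return prf(overlap, sum(cand.values()), sum(ref.values()))


def lcs_len(a, b):
    # Classic O(len(a) * len(b)) dynamic programme, one row at a time
    prev = [0] * (len(b) + 1)
    for x in a:
        cur = [0]
        for j, y in enumerate(b, 1):
            cur.append(prev[j - 1] + 1 if x == y else max(prev[j], cur[j - 1]))
        prev = cur
    return prev[-1]


def rouge_l(candidate, reference):
    c, r = candidate.lower().split(), reference.lower().split()
    return prf(lcs_len(c, r), len(c), len(r))


cand = "the cat sat on the mat"
ref = "the cat is on the mat"
r1, r2, rl = rouge_n(cand, ref, 1), rouge_n(cand, ref, 2), rouge_l(cand, ref)
# r1 F1 = 5/6, r2 F1 = 3/5, rl F1 = 5/6
```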
Run a script with:
python3 <filename>.py
During training, the snippet:
torch.save(model.state_dict(), 'model.pt')
saves the parameters of the model that achieves the lowest validation loss. The code then reloads these parameters for testing using:
model.load_state_dict(torch.load('model.pt'))
If the parameters are loaded successfully, this call returns <All keys matched successfully>.
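The README does not show the training loop around the torch.save call. A common pattern for keeping only the best-validation checkpoint is sketched below in framework-agnostic Python; `train_with_checkpointing` and its callbacks are hypothetical, and in the PyTorch code `save_fn` would wrap torch.save(model.state_dict(), 'model.pt').

```python
import math


def train_with_checkpointing(num_epochs, train_one_epoch, evaluate, save_fn):
    """Overwrite the checkpoint only when the validation loss improves.

    train_one_epoch(): runs one training epoch (hypothetical callback)
    evaluate():        returns the current validation loss (hypothetical callback)
    save_fn():         persists the current parameters, e.g. via torch.save
    """
    best_val, best_epoch = math.inf, None
    for epoch in range(num_epochs):
        train_one_epoch()
        val_loss = evaluate()
        if val_loss < best_val:  # strictly better, so overwrite the checkpoint
            best_val, best_epoch = val_loss, epoch
            save_fn()
    return best_epoch, best_val


val_losses = iter([3.0, 2.5, 2.7, 2.1, 2.4])   # made-up losses for illustration
saves = []
best_epoch, best_val = train_with_checkpointing(
    num_epochs=5,
    train_one_epoch=lambda: None,
    evaluate=lambda: next(val_losses),
    save_fn=lambda: saves.append("model.pt"),
)
# best_epoch == 3, best_val == 2.1, and the checkpoint was written 3 times
```

Because the file is overwritten only on improvement, the file that load_state_dict reads at test time always holds the lowest-validation-loss parameters, even if later epochs overfit.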