The attention mechanism is a powerful technique that dynamically highlights the most relevant features of the input data. By assigning higher weights to the most informative words in a sentence, it improves the quality of predictions. This project applies the attention mechanism to text classification and evaluates its effectiveness on this task.
The aim of this project is to perform multiclass text classification using the attention mechanism on a dataset of customer complaints about consumer financial products.
The dataset consists of more than two million customer complaints about consumer financial products. It includes columns for the actual text of the complaint and the product category associated with each complaint. Pre-trained word vectors from the GloVe dataset (glove.6B) are used to enhance text representation.
- Language: Python
- Libraries: pandas, torch, nltk, numpy, pickle, re, tqdm, sklearn
- Installing the necessary packages via pip
- Importing the required libraries
- Defining configuration file paths
- Processing GloVe embeddings:
- Reading the text file
- Converting embeddings to float arrays
- Adding embeddings for padding and unknown items
- Saving embeddings and vocabulary
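The GloVe-processing steps above can be sketched as follows. This is a minimal illustration, not the project's actual code: the function name `load_glove`, the 50-dimensional `glove.6B.50d` assumption, and the choice of a zero vector for padding and the mean vector for unknown words are all assumptions.

```python
import pickle
import numpy as np

EMB_DIM = 50  # glove.6B.50d vectors are 50-dimensional

def load_glove(path):
    """Read GloVe vectors and prepend <pad> and <unk> embeddings."""
    vocabulary, vectors = [], []
    with open(path, encoding="utf-8") as f:
        for line in f:
            parts = line.rstrip().split(" ")
            vocabulary.append(parts[0])
            vectors.append(np.asarray(parts[1:], dtype=np.float32))
    embeddings = np.vstack(vectors)
    pad = np.zeros((1, EMB_DIM), dtype=np.float32)  # index 0: padding
    unk = embeddings.mean(axis=0, keepdims=True)    # index 1: unknown words
    embeddings = np.vstack([pad, unk, embeddings])
    vocabulary = ["<pad>", "<unk>"] + vocabulary
    return vocabulary, embeddings
```

The returned vocabulary and embedding matrix can then be pickled (e.g. with `pickle.dump`) as the project's `vocabulary.pkl` and `embeddings.pkl` files.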
- Processing text data:
- Reading the CSV file and dropping null values
- Replacing duplicate labels
- Encoding the label column and saving the encoder and encoded labels
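A minimal sketch of the label-preparation steps above. The column names `narrative` and `product` and the duplicate-label mapping are illustrative assumptions, not the dataset's actual schema:

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder

def prepare_labels(df, text_col="narrative", label_col="product"):
    """Drop null rows, merge duplicate label variants, and encode labels."""
    df = df.dropna(subset=[text_col, label_col])
    # Merge duplicate label spellings (illustrative mapping only).
    df[label_col] = df[label_col].replace(
        {"Credit card or prepaid card": "Credit card"}
    )
    encoder = LabelEncoder()
    labels = encoder.fit_transform(df[label_col])
    return df, labels, encoder
```

In the project, the fitted encoder and the encoded labels would then be pickled as `label_encoder.pkl` and `labels.pkl`.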
- Data Preprocessing:
- Conversion to lowercase
- Punctuation removal
- Digits removal
- Removing consecutive instances of 'x'
- Removing additional spaces
- Tokenizing the text
- Saving the tokens
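The cleaning steps above can be sketched with regular expressions. The exact patterns are assumptions (e.g. treating runs of two or more "x" characters as masked personal data, which is common in complaint datasets); the project tokenizes with nltk, and the whitespace split below is a dependency-free stand-in:

```python
import re

def clean_text(text):
    text = text.lower()                       # conversion to lowercase
    text = re.sub(r"[^\w\s]", " ", text)      # punctuation removal
    text = re.sub(r"\d+", " ", text)          # digits removal
    text = re.sub(r"x{2,}", " ", text)        # consecutive 'x' (masked data)
    text = re.sub(r"\s+", " ", text).strip()  # removing additional spaces
    return text

def tokenize(text):
    # The project uses nltk.word_tokenize; splitting on whitespace
    # after cleaning is a simple stand-in for illustration.
    return clean_text(text).split()
```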
- Model:
- Creating the attention model
- Defining a function for the PyTorch dataset
- Functions for training and testing the model
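The model and dataset steps above might look like the sketch below. This is one common way to build an attention classifier (a bidirectional LSTM with additive attention over time steps); the project's actual architecture in `model.py` may differ, and all class and parameter names here are illustrative:

```python
import torch
import torch.nn as nn
from torch.utils.data import Dataset

class AttentionClassifier(nn.Module):
    def __init__(self, embeddings, hidden_dim, num_classes):
        super().__init__()
        # Initialize from the pre-trained GloVe matrix; index 0 is <pad>.
        self.embedding = nn.Embedding.from_pretrained(
            torch.as_tensor(embeddings, dtype=torch.float32), padding_idx=0)
        self.lstm = nn.LSTM(embeddings.shape[1], hidden_dim,
                            batch_first=True, bidirectional=True)
        self.attn = nn.Linear(2 * hidden_dim, 1)
        self.fc = nn.Linear(2 * hidden_dim, num_classes)

    def forward(self, x):
        h, _ = self.lstm(self.embedding(x))           # (B, T, 2H)
        scores = self.attn(h).squeeze(-1)             # (B, T)
        weights = torch.softmax(scores, dim=1)        # attention over steps
        context = (weights.unsqueeze(-1) * h).sum(1)  # weighted sum (B, 2H)
        return self.fc(context)

class ComplaintsDataset(Dataset):
    """Wraps pre-padded index sequences and encoded labels."""
    def __init__(self, sequences, labels):
        self.sequences = torch.as_tensor(sequences, dtype=torch.long)
        self.labels = torch.as_tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.sequences[idx], self.labels[idx]
```

The attention weights give a per-word relevance score, which is what lets the model emphasize the most informative words in each complaint.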
- Training:
- Loading the necessary files
- Splitting data into train, test, and validation sets
- Creating PyTorch datasets
- Creating data loaders
- Creating the model object
- Moving the model to GPU if available
- Defining the loss function and optimizer
- Training the model
- Testing the model
- Making predictions on new text data
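The training and evaluation steps above can be sketched as follows. These helper functions are illustrative, not the project's exact `Engine.py` code; they assume cross-entropy loss and a model/optimizer created as described above:

```python
import torch
import torch.nn as nn

def train_epoch(model, loader, optimizer, device):
    """One pass over the training data; returns the mean loss."""
    model.train()
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * y.size(0)
    return total_loss / len(loader.dataset)

@torch.no_grad()
def evaluate(model, loader, device):
    """Accuracy on a validation or test loader."""
    model.eval()
    correct = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        correct += (model(x).argmax(dim=1) == y).sum().item()
    return correct / len(loader.dataset)
```

The device would be chosen with `torch.device("cuda" if torch.cuda.is_available() else "cpu")`, and the trained weights saved with `torch.save` (the project stores them as `attention.pth`).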
- Input: Contains the data required for analysis, including:
  - complaints.csv
  - glove.6B.50d.txt (download from the GloVe project page)
- Source: Contains modularized code for the various project steps:
  - model.py
  - data.py
  - utils.py

  These Python files contain helper functions used in Engine.py.
- Output: Contains files required for model training, including:
  - attention.pth
  - embeddings.pkl
  - label_encoder.pkl
  - labels.pkl
  - vocabulary.pkl
  - tokens.pkl
- config.py: Contains project configurations.
- Engine.py: The main file that runs the entire project; it trains the model and saves it in the Output folder.