genderBERT

Deep learning model for gender classification on texts using pretrained BERT models. Given a text written by a human as input, our models predict the gender of its author. The trained models with the highest accuracy on each data set (see section 'Results on data') can be downloaded via figshare.com.

Models used

BERT (Bidirectional Encoder Representations from Transformers) (source, paper): Transformer model pretrained on a large corpus of English data with self-supervised learning. We use the base, uncased version: the base version has 12 layers and 110M parameters, and the uncased version does not distinguish between lower- and upper-case letters. Both choices reduce complexity and runtime (see this overview for all pretrained versions).

alBERT (A Lite BERT) (source, paper): Lite version of BERT that uses two parameter-reduction techniques (factorized embedding parameterization and cross-layer parameter sharing); a configuration similar to BERT-large has 18x fewer parameters and trains about 1.7x faster.

roBERTa (Robustly optimized BERT approach) (source, paper): BERT with an optimized pretraining procedure: trained longer on more data with bigger batches and longer sequences, using dynamic masking and dropping BERT's next sentence prediction objective.

distilBERT (Distilled BERT) (source, paper): Smaller and faster variant of BERT obtained by applying knowledge distillation during the pretraining phase.

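VCDNN (Variable-Component Deep Neural Network) (paper): Deep neural network originally proposed for robust speech recognition; listed for comparison in the results.

HAN (Hierarchical Attention Networks) (paper): Hierarchical attention network for document classification; listed for comparison in the results.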

Data sets used for training and testing

Amazon (source): contains product reviews from Amazon

StackExchange (source): contains content posted on StackExchange

Reddit (source): contains comments on Reddit

Code

▶️ For more detailed information, follow the docstring link at the end of each description.

  • main.py - Main file of the project. Uses the tokenization and model functionality to create new models according to the configuration set in config.json (docstring)
  • tokenizer.py - Prepares the data for the main file. Tokenizes the given data set, padding short texts and truncating oversized ones, and stores the resulting tokenized texts together with the corresponding attention masks. (docstring)
  • model.py - Implements functions for loading embeddings, creating models, and training, validation and testing. Uses a given pretrained model and a tokenized data set and runs training/validation/testing as specified in the selected config.json entry.
  • majority_voting.py - Applies majority voting to the predictions of a BERT model and displays accuracy and F1 after voting. Whenever a user with multiple texts receives different gender predictions and one gender has the majority, all minority predictions are changed to the majority prediction. Predictions of users without a majority for one gender (50/50 case) are not changed. (docstring)
  • customBERT.py - Additional (failed) approach in which BERT is extended by 3 adjustable layers (e.g. linear layers). All attempts resulted in an accuracy below 0.75.
  • run_cluster.py - Script for consecutively training the same model with different learning rates, maximum token counts and truncation methods.
  • config.json - Collection (dict) of possible setups for model.py. main.py uses the ID of a setup to extract its options.
  • bert-base-uncased-vocab.txt - Vocabulary of the pretrained BERT model bert-base-uncased, used to convert words to token IDs (source). Each line holds one token, and its line index is the token ID, starting at 0 for the [PAD] token (see BERT Tokenizer).
  • data_subset.csv - Sample of Amazon training data (10k raw elements) for testing general code functionality on local machines.
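
The tokenizer.py steps described above (vocabulary lookup, truncation of oversized texts, padding, attention mask) can be illustrated with a minimal stdlib-only sketch. It is not the repository's code: it assumes a simplified lowercase whitespace split instead of BERT's WordPiece tokenization, always keeps the head of oversized texts, and uses a toy vocabulary; `load_vocab`, `encode` and `toy_vocab` are hypothetical names. As in bert-base-uncased-vocab.txt, a token's ID is its line index.

```python
from typing import Dict, Iterable, List, Tuple


def load_vocab(lines: Iterable[str]) -> Dict[str, int]:
    """Map each token to its line index (line 0 -> ID 0, and so on)."""
    return {line.rstrip("\r\n"): idx for idx, line in enumerate(lines)}


def encode(text: str, vocab: Dict[str, int], max_tokencount: int) -> Tuple[List[int], List[int]]:
    """Return (token_ids, attention_mask), both exactly max_tokencount long."""
    if max_tokencount < 2:
        raise ValueError("max_tokencount must leave room for [CLS] and [SEP]")
    # Simplification (assumption): lowercase whitespace split instead of WordPiece.
    tokens = text.lower().split()
    # Truncate oversized texts, keeping the head and reserving 2 slots for [CLS]/[SEP].
    tokens = tokens[: max_tokencount - 2]
    unk_id = vocab["[UNK]"]
    ids = [vocab["[CLS]"]] + [vocab.get(tok, unk_id) for tok in tokens] + [vocab["[SEP]"]]
    # Attention mask: 1 marks real tokens, 0 marks padding.
    mask = [1] * len(ids)
    n_pad = max_tokencount - len(ids)
    ids += [vocab["[PAD]"]] * n_pad
    mask += [0] * n_pad
    return ids, mask


# Toy vocabulary standing in for bert-base-uncased-vocab.txt (hypothetical contents).
toy_vocab = load_vocab(["[PAD]\n", "[UNK]\n", "[CLS]\n", "[SEP]\n", "great\n", "product\n"])
ids, mask = encode("Great product", toy_vocab, max_tokencount=6)
print(ids)   # [2, 4, 5, 3, 0, 0]
print(mask)  # [1, 1, 1, 1, 0, 0]
```

Feeding the lines of bert-base-uncased-vocab.txt to `load_vocab` instead of the toy list would map [PAD] to ID 0, as described above; a faithful reproduction would additionally need WordPiece subword splitting.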
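
The majority-voting rule described for majority_voting.py can be sketched as follows. This is a minimal illustration, not the repository's implementation: it assumes binary labels with a hypothetical encoding (0 = male, 1 = female) and one user ID per prediction; `majority_vote`, `accuracy` and `f1_for_label` are hypothetical names. F1 is the standard per-class score 2TP / (2TP + FP + FN).

```python
from collections import Counter, defaultdict
from typing import Dict, Hashable, List, Sequence


def majority_vote(user_ids: Sequence[Hashable], predictions: Sequence[int]) -> List[int]:
    """Give every text of a user the user's majority prediction; ties stay unchanged."""
    counts = defaultdict(Counter)
    for user, pred in zip(user_ids, predictions):
        counts[user][pred] += 1
    majority: Dict[Hashable, int] = {}
    for user, user_counts in counts.items():
        top = user_counts.most_common(2)
        # A user gets a majority label only if the top label is not tied (e.g. 50/50).
        if len(top) == 1 or top[0][1] > top[1][1]:
            majority[user] = top[0][0]
    return [majority.get(user, pred) for user, pred in zip(user_ids, predictions)]


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def f1_for_label(y_true: Sequence[int], y_pred: Sequence[int], label: int) -> float:
    """Per-class F1 = 2TP / (2TP + FP + FN); 0.0 if there are no true positives."""
    tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
    fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))
    fn = sum(t == label and p != label for t, p in zip(y_true, y_pred))
    return 2 * tp / (2 * tp + fp + fn) if tp else 0.0


# Hypothetical data: user "a" has a 2:1 majority, user "b" is a 50/50 tie.
users = ["a", "a", "a", "b", "b"]
truth = [0, 0, 0, 1, 1]
preds = [0, 0, 1, 1, 0]
voted = majority_vote(users, preds)
print(voted)                                           # [0, 0, 0, 1, 0]
print(accuracy(truth, preds), accuracy(truth, voted))  # 0.6 0.8
```

User "a" has its minority prediction flipped to the majority label, while user "b" keeps both predictions because neither gender has a majority.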

How to use

  • 1. Add your configuration to config.json:

    • EPOCHS - Number of training epochs
    • LEARNING_RATE - Learning rate for the model
    • BATCH_SIZE - Batch size for the model
    • MAX_TOKENCOUNT - Maximum token count per text (used for embedding)
    • TRUNCATING_METHOD - Method for truncating oversized texts (Head, Tail, Headtail)
    • TOGGLE_PHASES - Bool array of the form [Do_Train_Phase, Do_Val_Phase, Do_Test_Phase]
    • SAVE_MODEL - Path to save the model to; the model is not saved if none
    • PRELOAD_MODEL - Path to load a model from; no model is loaded if none
    • LOAD_EMBEDDINGS - Array of 3 paths [train, val, test] to load embeddings from; embeddings are not loaded if none
    • ROWS_COUNTS - Number of rows to use, given as an array of 3 values [train, val, test]
    • MODEL_TYPE - One of: bert, albert, roberta, distilbert, custombert
    • DATASET_TYPE - Unique name for the dataset (used for storing data)
    • PATH_TRAIN - Path to the train data
    • COLUMNS_TRAIN - Format of the train data (e.g. ["UserId", "Gender", "ReviewText"])
    • PATH_VALIDATION - Path to the validation data
    • COLUMNS_VALIDATION - Format of the validation data (e.g. ["UserId", "Gender", "ReviewText"])
    • PATH_TEST - Path to the test data
    • COLUMNS_TEST - Format of the test data (e.g. ["Gender", "ReviewText"])
    • BASE_FREEZE - Freeze base layers if True
  • 2. Run main.py with the config ID as a command-line argument (e.g. python main.py 8 uses the config entry with ID 8).
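
The three TRUNCATING_METHOD options can be sketched as below. Head keeps the beginning and Tail the end of an oversized text. The README does not say how Headtail splits the token budget, so this sketch assumes a hypothetical `head_ratio` parameter (default 0.25) giving the share kept from the beginning; `truncate` is likewise a hypothetical name, not a function from the repository.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")
METHODS = ("Head", "Tail", "Headtail")


def truncate(tokens: Sequence[T], max_len: int, method: str = "Head",
             head_ratio: float = 0.25) -> List[T]:
    """Cut a token sequence down to at most max_len tokens."""
    if method not in METHODS:
        raise ValueError(f"unknown truncating method: {method}")
    if max_len < 0 or not 0.0 <= head_ratio <= 1.0:
        raise ValueError("need max_len >= 0 and 0 <= head_ratio <= 1")
    if len(tokens) <= max_len:
        return list(tokens)  # nothing to cut
    if method == "Head":
        return list(tokens[:max_len])  # keep the beginning
    if method == "Tail":
        return list(tokens[len(tokens) - max_len:])  # keep the end
    # Headtail (assumption): head_ratio of the budget from the start, the rest from the end.
    n_head = int(max_len * head_ratio)
    n_tail = max_len - n_head
    return list(tokens[:n_head]) + list(tokens[len(tokens) - n_tail:])


tokens = list(range(10))
print(truncate(tokens, 4, "Head"))      # [0, 1, 2, 3]
print(truncate(tokens, 4, "Tail"))      # [6, 7, 8, 9]
print(truncate(tokens, 4, "Headtail"))  # [0, 7, 8, 9]
```

Slicing with `len(tokens) - n` rather than `-n` keeps the tail correct when the remaining budget is 0.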
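
A sketch of how a config entry might be selected by its ID, as main.py does with the command-line argument. It assumes config.json is a single JSON object whose keys are the setup IDs; since JSON object keys are always strings, the ID taken from the command line can be used directly. `EXAMPLE_CONFIG_JSON` and all of its values (paths, epochs, batch size, row counts) are hypothetical, and JSON null is assumed to stand for "none"; only the key names and column lists come from this README. `select_config` and `enabled_phases` are hypothetical helpers.

```python
import json
from typing import Any, Dict, List

# Key names documented in this README.
EXPECTED_KEYS = {
    "EPOCHS", "LEARNING_RATE", "BATCH_SIZE", "MAX_TOKENCOUNT", "TRUNCATING_METHOD",
    "TOGGLE_PHASES", "SAVE_MODEL", "PRELOAD_MODEL", "LOAD_EMBEDDINGS", "ROWS_COUNTS",
    "MODEL_TYPE", "DATASET_TYPE", "PATH_TRAIN", "COLUMNS_TRAIN", "PATH_VALIDATION",
    "COLUMNS_VALIDATION", "PATH_TEST", "COLUMNS_TEST", "BASE_FREEZE",
}

# Hypothetical entry with ID "8"; values are illustrative, None (JSON null) means "none".
EXAMPLE_CONFIG_JSON = json.dumps({
    "8": {
        "EPOCHS": 3, "LEARNING_RATE": 2e-5, "BATCH_SIZE": 16, "MAX_TOKENCOUNT": 512,
        "TRUNCATING_METHOD": "Headtail", "TOGGLE_PHASES": [True, True, False],
        "SAVE_MODEL": "models/bert_amazon", "PRELOAD_MODEL": None, "LOAD_EMBEDDINGS": None,
        "ROWS_COUNTS": [10000, 1000, 1000], "MODEL_TYPE": "bert", "DATASET_TYPE": "amazon",
        "PATH_TRAIN": "data/train.csv", "COLUMNS_TRAIN": ["UserId", "Gender", "ReviewText"],
        "PATH_VALIDATION": "data/val.csv",
        "COLUMNS_VALIDATION": ["UserId", "Gender", "ReviewText"],
        "PATH_TEST": "data/test.csv", "COLUMNS_TEST": ["Gender", "ReviewText"],
        "BASE_FREEZE": False,
    }
})


def select_config(configs: Dict[str, Dict[str, Any]], config_id: str) -> Dict[str, Any]:
    """Pick one setup by its (string) ID and check that no documented key is missing."""
    config = configs[config_id]
    missing = EXPECTED_KEYS.difference(config)
    if missing:
        raise KeyError(f"config {config_id} lacks {sorted(missing)}")
    return config


def enabled_phases(config: Dict[str, Any]) -> List[str]:
    """Translate TOGGLE_PHASES [train, val, test] into the names of the phases to run."""
    return [name for name, on in zip(("train", "val", "test"), config["TOGGLE_PHASES"]) if on]


config = select_config(json.loads(EXAMPLE_CONFIG_JSON), "8")
print(config["MODEL_TYPE"], enabled_phases(config))  # bert ['train', 'val']
```

With a real file, calling `select_config(json.load(f), sys.argv[1])` on an opened config.json would correspond to running `python main.py 8`.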
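
run_cluster.py is described as training the same model consecutively with different learning rates, maximum token counts and truncation methods. One way to generate such a sweep from a base config entry is sketched below, using the LEARNING_RATE, MAX_TOKENCOUNT and TRUNCATING_METHOD keys; `sweep`, `run_all` and the `train` callback are hypothetical stand-ins, not functions from the repository, and the grid values are illustrative.

```python
import copy
import itertools
from typing import Any, Callable, Dict, Iterable, List


def sweep(base: Dict[str, Any], learning_rates: Iterable[float],
          max_tokencounts: Iterable[int],
          truncating_methods: Iterable[str]) -> List[Dict[str, Any]]:
    """Return one config per combination, each a modified deep copy of base."""
    runs = []
    for lr, max_tok, method in itertools.product(learning_rates, max_tokencounts,
                                                 truncating_methods):
        cfg = copy.deepcopy(base)  # keep base (and its nested lists) untouched
        cfg.update(LEARNING_RATE=lr, MAX_TOKENCOUNT=max_tok, TRUNCATING_METHOD=method)
        runs.append(cfg)
    return runs


def run_all(runs: List[Dict[str, Any]], train: Callable[[Dict[str, Any]], None]) -> None:
    """Train the runs consecutively, one after another."""
    for cfg in runs:
        train(cfg)


base = {"MODEL_TYPE": "bert", "EPOCHS": 3, "TOGGLE_PHASES": [True, True, False]}
runs = sweep(base, [2e-5, 3e-5], [256, 512], ["Head", "Tail", "Headtail"])
print(len(runs))  # 12
# Stand-in for a real training call: just report which setup would run.
run_all(runs, lambda cfg: print(cfg["LEARNING_RATE"], cfg["MAX_TOKENCOUNT"],
                                cfg["TRUNCATING_METHOD"]))
```

Running the combinations sequentially rather than in parallel keeps only one model in memory at a time, which matters for BERT-sized models on a single GPU.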

Results on data

Amazon data

                    Without majority voting         With majority voting
Model               Male F1  Female F1  Accuracy    Male F1  Female F1  Accuracy
VCDNN               0.751    0.770      0.761       0.823    0.834      0.831
BERT                0.735    0.764      0.750       0.808    0.830      0.820
roBERTa             0.712    0.758      0.737       0.783    0.819      0.783
distilBERT          0.731    0.754      0.743       0.802    0.819      0.802

StackOverflow data

                    Without majority voting         With majority voting
Model               Male F1  Female F1  Accuracy    Male F1  Female F1  Accuracy
HAN                 0.640    0.642      0.641       0.735    0.719      0.727
BERT (2e-5)         0.652    0.643      0.648       0.738    0.722      0.730
roBERTa (2e-5)      0.658    0.653      0.655       0.724    0.710      0.717
distilBERT (2e-5)   0.640    0.649      0.644       0.711    0.715      0.713

Reddit data

                    Without majority voting         With majority voting
Model               Male F1  Female F1  Accuracy    Male F1  Female F1  Accuracy
HAN                 0.644    0.660      0.652       0.909    0.907      0.908
VDCNN               0.718    0.659      0.692       0.879    0.848      0.865
BERT (10%)          0.702    0.686      0.695       0.914    0.905      0.914
roBERTa (10%)       0.685    0.681      0.683       0.916    0.909      0.913
distilBERT (10%)    0.695    0.665      0.681       0.901    0.887      0.895