Tools for adapting universal language models to specific tasks
Please note: these are tools for rapid prototyping, not for brute-force hyperparameter tuning.
Adapted from: Google's BERT
pip install critical_path
- Download a pretrained BERT model; start with BERT-Base Uncased if you're not sure where to begin
- Unzip the model and make note of the path
- Full implementation examples can be found here:
- Train and evaluate the SQuAD dataset
- Train and evaluate custom datasets for multi-label classification tasks (multiple labels possible)
- Train and evaluate custom datasets for single-label classification tasks (one label possible)
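The practical difference between the two classification modes is how per-label scores become predictions. A common approach, and the one assumed in this sketch (it is not taken from critical_path's source), is an independent sigmoid per label plus a threshold for multi-label tasks, and a softmax with argmax for single-label tasks. The function names and the 0.5 threshold are hypothetical.

```python
import math
from typing import List


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def decode_multilabel(logits: List[float], labels: List[str], threshold: float = 0.5) -> List[str]:
    """Multi-label (assumed): each label is scored independently, so zero, one, or many can fire."""
    return [label for label, z in zip(labels, logits) if sigmoid(z) >= threshold]


def decode_single_label(logits: List[float], labels: List[str]) -> str:
    """Single-label (assumed): softmax makes labels compete, so exactly one wins."""
    m = max(logits)  # subtract the max before exponentiating for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(labels)), key=lambda i: probs[i])
    return labels[best]


labels = ["label_1", "label_2", "label_3"]
logits = [2.0, 0.3, -1.5]
print(decode_multilabel(logits, labels))
print(decode_single_label(logits, labels))
```

The same logits can yield several labels in the multi-label case but only one in the single-label case, which is why the two tasks need different models and dataloaders.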
base_model_folder_path = "../models/uncased_L-12_H-768_A-12/" # Folder containing downloaded Base Model
name_of_config_json_file = "bert_config.json" # Located inside the Base Model folder
name_of_vocab_file = "vocab.txt" # Located inside the Base Model folder
output_directory = "../models/trained_BERT/" # Trained model and results landing folder
# Multi-Label and Single-Label Specific
data_dir = None # Directory containing the .tsv data, typically for CoLA/MRPC or other datasets with a known structure
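The configuration builds file paths by plain string concatenation (`base_model_folder_path + name_of_config_json_file`), which only works when the folder path ends in a slash. A slightly more defensive sketch, assuming you just want to join and sanity-check the paths before passing them to `set_model_paths()`, uses `os.path.join`. The helper `resolve_model_files` is hypothetical; only the keyword names `bert_config_file` and `bert_vocab_file` come from the `set_model_paths()` call shown in this README.

```python
import os
from typing import Dict


def resolve_model_files(base_model_folder_path: str,
                        name_of_config_json_file: str = "bert_config.json",
                        name_of_vocab_file: str = "vocab.txt",
                        check_exists: bool = True) -> Dict[str, str]:
    """Hypothetical helper: join the model folder with its file names,
    whether or not the folder path has a trailing slash."""
    paths = {
        "bert_config_file": os.path.join(base_model_folder_path, name_of_config_json_file),
        "bert_vocab_file": os.path.join(base_model_folder_path, name_of_vocab_file),
    }
    if check_exists:
        # Fail early with a clear message instead of deep inside model loading
        missing = [p for p in paths.values() if not os.path.isfile(p)]
        if missing:
            raise FileNotFoundError(f"Missing BERT model files: {missing}")
    return paths
```

Usage, under the same assumptions: `paths = resolve_model_files("../models/uncased_L-12_H-768_A-12")` followed by `Flags.set_model_paths(**paths, bert_output_dir=output_directory, data_dir=data_dir)`.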
"""Settable parameters and their default values
Note: Most default values are perfectly fine
"""
# Administrative
init_checkpoint = None
save_checkpoints_steps = 1000
iterations_per_loop = 1000
do_lower_case = True
# Technical
batch_size_train = 32
batch_size_eval = 8
batch_size_predict = 8
num_train_epochs = 3.0
max_seq_length = 128
warmup_proportion = 0.1
learning_rate = 3e-5
# SQuAD Specific
doc_stride = 128
max_query_length = 64
n_best_size = 20
max_answer_length = 30
is_squad_v2 = False # SQuAD 2.0 has examples with no answer ("impossible" questions); SQuAD 1.x does not
verbose_logging = False
null_score_diff_threshold = 0.0
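`doc_stride`, `max_seq_length`, and `max_query_length` interact. In Google's original `run_squad.py`, a long context is split into windows that each hold at most `max_seq_length - len(query_tokens) - 3` context tokens (the 3 covers `[CLS]` and two `[SEP]` tokens), and each window starts at most `doc_stride` tokens after the previous one. The sketch below reproduces that windowing arithmetic in isolation, assuming critical_path keeps Google's behavior; `doc_spans` is a hypothetical name and works on token counts rather than real tokens.

```python
from typing import List, Tuple


def doc_spans(num_doc_tokens: int, num_query_tokens: int,
              max_seq_length: int = 128, doc_stride: int = 128,
              max_query_length: int = 64) -> List[Tuple[int, int]]:
    """Return (start, length) windows over the context tokens,
    mirroring the sliding-window logic of Google's run_squad.py (assumed here)."""
    num_query_tokens = min(num_query_tokens, max_query_length)  # long questions are truncated
    max_tokens_for_doc = max_seq_length - num_query_tokens - 3  # [CLS] query [SEP] doc [SEP]
    spans = []
    start = 0
    while start < num_doc_tokens:
        length = min(num_doc_tokens - start, max_tokens_for_doc)
        spans.append((start, length))
        if start + length == num_doc_tokens:
            break  # this window reaches the end of the context
        start += min(length, doc_stride)
    return spans


print(doc_spans(300, 20))                                      # defaults listed above
print(doc_spans(600, 20, max_seq_length=384, doc_stride=128))  # Google's SQuAD settings
```

With the defaults listed above (`max_seq_length = 128`, `doc_stride = 128`), a window holds at most 125 context tokens, so consecutive windows never overlap. Overlap only appears when the window is larger than `doc_stride`, as with `max_seq_length=384`.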
from critical_path.BERT.configs import ConfigClassifier
Flags = ConfigClassifier()
Flags.set_model_paths(
    bert_config_file=base_model_folder_path + name_of_config_json_file,
    bert_vocab_file=base_model_folder_path + name_of_vocab_file,
    bert_output_dir=output_directory,
    data_dir=data_dir)
Flags.set_model_params(
    batch_size_train=8,
    max_seq_length=256,
    num_train_epochs=3)
# Retrieve a handle for the configs
FLAGS = Flags.get_handle()
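`num_train_epochs`, `batch_size_train`, `warmup_proportion`, and `learning_rate` together define the learning-rate schedule. Google's `run_classifier.py` converts epochs into optimizer steps and applies linear warmup followed by linear decay to zero. The sketch below reproduces that arithmetic, assuming critical_path inherits it unchanged; both function names are hypothetical.

```python
from typing import Tuple


def train_and_warmup_steps(num_examples: int, batch_size_train: int,
                           num_train_epochs: float,
                           warmup_proportion: float = 0.1) -> Tuple[int, int]:
    """Convert epochs into optimizer steps (assumed to match Google's run_classifier.py)."""
    num_train_steps = int(num_examples / batch_size_train * num_train_epochs)
    num_warmup_steps = int(num_train_steps * warmup_proportion)
    return num_train_steps, num_warmup_steps


def learning_rate_at(step: int, num_train_steps: int, num_warmup_steps: int,
                     learning_rate: float = 3e-5) -> float:
    """Linear warmup from 0 up to learning_rate, then linear decay back to 0."""
    if step < num_warmup_steps:
        return learning_rate * step / num_warmup_steps
    return learning_rate * max(0.0, 1.0 - step / num_train_steps)


# Example: 10,000 training examples with the overrides above (batch size 8, 3 epochs)
steps, warmup = train_and_warmup_steps(10_000, batch_size_train=8, num_train_epochs=3)
print(steps, warmup)  # 3750 375
```

Halving the batch size doubles the number of steps, and with it the number of warmup steps, since warmup is a proportion of total steps.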
A single GTX 1070 running BERT-Base Uncased can handle the following maximum batch sizes:
| Model | max_seq_length | batch_size |
|---|---|---|
| BERT-Base Uncased | 256 | 6 |
| ... | 384 | 4 |
For full batch size and sequence length guidelines, see Google's recommendations.
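The configuration example sets `batch_size_train=8` with `max_seq_length=256`, which is above the 6 measured for a GTX 1070. A small guard built only from the two rows of the table can flag such combinations before a run runs out of GPU memory. `fits_gtx1070` is a hypothetical helper and deliberately conservative: it applies the limit of the first table row whose sequence length is at or above the requested one, and returns `None` beyond 384 tokens, where the table gives no data.

```python
from typing import Optional

# (max_seq_length, max batch_size) for BERT-Base Uncased on one GTX 1070, copied from the table
GTX1070_LIMITS = [(256, 6), (384, 4)]


def fits_gtx1070(max_seq_length: int, batch_size: int) -> Optional[bool]:
    """Hypothetical check: True/False if the table covers this length, None otherwise."""
    for seq_len_limit, batch_limit in GTX1070_LIMITS:
        if max_seq_length <= seq_len_limit:
            return batch_size <= batch_limit
    return None  # longer than any measured sequence length


print(fits_gtx1070(256, 8))  # False
```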
"""For Multi-Label Classification"""
from critical_path.BERT.model_multilabel_class import MultiLabelClassifier
model = MultiLabelClassifier(FLAGS)
- SQuAD has dedicated dataloaders
  - `read_squad_examples()` and `write_squad_predictions()` in /BERT/model_squad
- Multi-Label Classification has a generic dataloader
  - `DataProcessor` in /BERT/model_multilabel_class
  - Note: this requires data labels to be in string format, e.g.
    `labels = [ ["label_1", "label_2", "label_3"], ["label_2"] ]`
- Single-Label Classification dataloaders
  - `ColaProcessor` is implemented in /BERT/model_classifier
  - More dataloader formats are implemented in pytorch-pretrained-BERT
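Because the multi-label `DataProcessor` expects each example's labels as a list of strings, a classifier ultimately needs them as a fixed-length 0/1 ("multi-hot") vector ordered like `label_list`. The sketch below shows that conversion as a general technique; it is an assumption, not code taken from /BERT/model_multilabel_class, and `to_multi_hot` is a hypothetical name.

```python
from typing import List


def to_multi_hot(labels: List[List[str]], label_list: List[str]) -> List[List[int]]:
    """Hypothetical helper: map each example's label strings to a 0/1 vector ordered like label_list."""
    index = {label: i for i, label in enumerate(label_list)}
    vectors = []
    for example_labels in labels:
        vec = [0] * len(label_list)
        for label in example_labels:
            if label not in index:
                raise ValueError(f"Unknown label {label!r}; add it to label_list")
            vec[index[label]] = 1
        vectors.append(vec)
    return vectors


label_list = ["label_1", "label_2", "label_3"]
labels = [["label_1", "label_2", "label_3"], ["label_2"]]
print(to_multi_hot(labels, label_list))  # [[1, 1, 1], [0, 1, 0]]
```

Raising on unknown labels catches typos in string labels early, which matters because a misspelled label would otherwise silently become a missing one.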
"""For Multi-Label Classification with a custom .csv reading function"""
from critical_path.BERT.model_multilabel_class import DataProcessor
# read_toxic_data() is dataset specific - see /bert_multilabel_example.py
input_ids, input_text, input_labels, label_list = read_toxic_data(randomize=True)
processor = DataProcessor(label_list=label_list)
train_examples = processor.get_samples(
input_ids=input_ids,
input_text=input_text,
input_labels=input_labels,
set_type='train')
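`read_toxic_data()` is specific to the example's dataset. For your own data, a reader only has to return the four values unpacked above: IDs, texts, per-example label lists (as strings), and the full `label_list`. The sketch below assumes a CSV with an ID column, a text column, and one 0/1 column per label (the layout used by the Kaggle Toxic Comment Classification data, for instance); `read_multilabel_csv` and its parameters are hypothetical.

```python
import csv
import random
from typing import IO, List, Tuple


def read_multilabel_csv(f: IO[str], id_col: str = "id", text_col: str = "comment_text",
                        randomize: bool = False, seed: int = 42
                        ) -> Tuple[List[str], List[str], List[List[str]], List[str]]:
    """Hypothetical reader returning (input_ids, input_text, input_labels, label_list)
    from a CSV with one 0/1 column per label."""
    reader = csv.DictReader(f)
    # Every column that is not the ID or the text is treated as a label column
    label_list = [c for c in reader.fieldnames if c not in (id_col, text_col)]
    rows = list(reader)
    if randomize:
        random.Random(seed).shuffle(rows)  # seeded shuffle for reproducible splits
    input_ids = [row[id_col] for row in rows]
    input_text = [row[text_col] for row in rows]
    # Keep labels as strings, as the multi-label DataProcessor requires
    input_labels = [[label for label in label_list if row[label].strip() == "1"] for row in rows]
    return input_ids, input_text, input_labels, label_list
```

Usage, with a hypothetical file name: `with open("train.csv", newline="", encoding="utf-8") as f: input_ids, input_text, input_labels, label_list = read_multilabel_csv(f, randomize=True)`.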
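"""Train and predict a Multi-Label Classifier"""
do_train = True
do_predict = True
if do_train:
    model.train(train_examples, label_list)
if do_predict:
    # predict_examples must be built like train_examples, via processor.get_samples()
    model.predict(predict_examples, label_list)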
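`predict_examples` has to come from somewhere. One option, assuming you hold out part of the data that was read (and shuffled) earlier, is to split the parallel lists at the same index and call `processor.get_samples()` once per split. `split_parallel_lists` is a hypothetical helper, and the `set_type` value for the held-out split depends on what `DataProcessor.get_samples()` accepts in your version.

```python
from typing import List, Sequence, Tuple


def split_parallel_lists(*lists: Sequence, holdout_fraction: float = 0.1
                         ) -> Tuple[List[list], List[list]]:
    """Hypothetical helper: split equally long lists at the same index
    into (train_parts, holdout_parts). Shuffle beforehand; this takes the tail."""
    n = len(lists[0])
    if any(len(lst) != n for lst in lists):
        raise ValueError("All lists must have the same length")
    cut = n - int(n * holdout_fraction)
    train_parts = [list(lst[:cut]) for lst in lists]
    holdout_parts = [list(lst[cut:]) for lst in lists]
    return train_parts, holdout_parts
```

For example, `(tr_ids, tr_text, tr_labels), (pr_ids, pr_text, pr_labels) = split_parallel_lists(input_ids, input_text, input_labels)` keeps IDs, texts, and labels aligned, which a separate split of each list would not guarantee.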
- Full implementations: