TensorFlow 2 implementations of the relation classification architectures from Google's Matching the Blanks paper.


BERT for Relation Extraction

bert_re contains TensorFlow 2 implementations of BERT-based models for relation extraction, as given in Figure 3 of "Matching the Blanks: Distributional Similarity for Relation Learning".

These models are:

| Model | Input | Classification Head |
|-------|-------|---------------------|
| 3a | Standard | CLS Token |
| 3b | Standard | Mention Pooling |
| 3c | Positional Emb. | Mention Pooling |
| 3d | Entity Markers | CLS Token |
| 3e | Entity Markers | Mention Pooling |
| 3f | Entity Markers | Entity Start Token |

Model 3c is not yet implemented here.
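
To make the input and head types concrete, here is a minimal NumPy sketch of the entity-marker input and the three classification heads. It is illustrative only and is not the code in bert_re. The function names, the marker strings ([E1], [/E1], [E2], [/E2]), the inclusive span convention, and the use of max pooling for mention pooling (as described in the paper) are all assumptions.

```python
import numpy as np


def add_entity_markers(tokens, span1, span2):
    """Wrap two non-overlapping, inclusive (start, end) token spans in markers.

    The marker strings are hypothetical placeholders.
    """
    (s1, e1), (s2, e2) = span1, span2
    out = []
    for i, tok in enumerate(tokens):
        if i == s1:
            out.append("[E1]")
        if i == s2:
            out.append("[E2]")
        out.append(tok)
        if i == e1:
            out.append("[/E1]")
        if i == e2:
            out.append("[/E2]")
    return out


def cls_head(hidden):
    """CLS Token head: the hidden state at position 0 ([CLS])."""
    return hidden[0]


def mention_pooling_head(hidden, span1, span2):
    """Mention Pooling head: max-pool each mention's tokens, then concatenate."""
    (s1, e1), (s2, e2) = span1, span2
    m1 = hidden[s1:e1 + 1].max(axis=0)
    m2 = hidden[s2:e2 + 1].max(axis=0)
    return np.concatenate([m1, m2])


def entity_start_head(hidden, e1_marker_pos, e2_marker_pos):
    """Entity Start Token head: concatenate the states at the two start markers."""
    return np.concatenate([hidden[e1_marker_pos], hidden[e2_marker_pos]])


tokens = ["[CLS]", "the", "engine", "powers", "the", "car", "[SEP]"]
marked = add_entity_markers(tokens, (2, 2), (5, 5))

# Random stand-in for BERT-base output: one 768-dim vector per token.
hidden = np.random.default_rng(0).normal(size=(len(marked), 768))

cls_vec = cls_head(hidden)
pooled_vec = mention_pooling_head(hidden, (3, 3), (8, 8))  # spans after marking
start_vec = entity_start_head(hidden, marked.index("[E1]"), marked.index("[E2]"))
```

The CLS Token head yields a 768-dimensional vector, while the two concatenating heads yield 1536-dimensional vectors, so the classification layer that follows would need a matching input size.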

Installation and Usage

git clone https://github.com/jvasilakes/BERT-RE.git
cd BERT-RE
pip install -r requirements.txt
python setup.py develop

See examples.py for usage.

SemEval2010 Task 8

python run_semeval.py --model_id ${MODEL_ID} \
    --bert_model_dir /path/to/uncased_L-12_H-768_A-12/ \
    --train_file /path/to/SemEval2010_task8_all_data/SemEval2010_task8_training/TRAIN_FILE.TXT \
    --test_file data/SemEval2010_task8_all_data/SemEval2010_task8_testing_keys/TEST_FILE_FULL.TXT \
    --outdir /path/to/desired/output/directory/ \
    --learning_rate 3e-5 --batch_size 16 --epochs 10

Here ${MODEL_ID} is one of 3a, 3b, 3d, 3e, or 3f. Model 3c is not yet implemented.

The training dataset of 8000 examples was randomly split into 80% training (6400 examples) and 20% validation (1600 examples) using sklearn.model_selection.train_test_split with random_state=0.
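
The split can be sketched as follows. This is a sketch rather than the repository's own data-loading code: the placeholder sentences and labels and the variable names (with the held-out 20% named val_* here) are assumptions. Only the 80/20 proportions and random_state=0 come from the description above.

```python
from sklearn.model_selection import train_test_split

# Placeholder data standing in for the 8000 SemEval training examples.
examples = [f"sentence {i}" for i in range(8000)]
labels = [i % 19 for i in range(8000)]

# Hold out 20% of the examples, using the fixed seed given above.
train_x, val_x, train_y, val_y = train_test_split(
    examples, labels, test_size=0.2, random_state=0
)

print(len(train_x), len(val_x))  # 6400 1600
```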

Each model was trained with the following hyper-parameters:

  • BERT model: BERT-base
  • Learning rate: Linear warmup to 3e-5 at step 400 followed by polynomial decay.
  • Batch size: 16
  • Optimizer: Adam (epsilon = 1e-8)
  • Loss: Categorical cross entropy from softmax activations
  • Classification layer dropout rate: 0.1
  • Epochs: Early stopping monitoring validation loss.
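
The learning rate schedule in the list above can be sketched in plain Python. This is an assumption-laden illustration, not the schedule object used in run_semeval.py. Only the 3e-5 peak and the 400-step warmup come from the list. The total step count (4000), the final rate (0.0), the decay power (1.0), and the warmup starting from zero are hypothetical choices. With 6400 training examples and a batch size of 16, one epoch is 400 steps if each batch counts as one step, so the warmup would cover the first epoch.

```python
def learning_rate(step, peak_lr=3e-5, warmup_steps=400,
                  total_steps=4000, end_lr=0.0, power=1.0):
    """Linear warmup to peak_lr, then polynomial decay towards end_lr.

    total_steps, end_lr, and power are hypothetical values.
    """
    if step < warmup_steps:
        # Linear ramp from 0 to peak_lr over the warmup period.
        return peak_lr * step / warmup_steps
    decay_steps = total_steps - warmup_steps
    # Fraction of the decay period completed, capped at 1.
    progress = min(step - warmup_steps, decay_steps) / decay_steps
    return (peak_lr - end_lr) * (1.0 - progress) ** power + end_lr


for step in (0, 200, 400, 2200, 4000):
    print(step, learning_rate(step))
```

With power=1.0 the decay is linear. A larger power would make the rate fall faster just after the warmup ends.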

Note that unlike the paper, we use a dense layer with softmax activations for computing output probabilities, rather than layer normalization with linear activations.
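
As a rough illustration of this output layer, the NumPy sketch below applies a dense layer with softmax activations and computes categorical cross entropy on a random batch. It is not the repository's TensorFlow code. The helper names, the weight initialisation, and the random inputs are assumptions, dropout is omitted, and 19 classes assumes the usual SemEval2010 Task 8 label set (9 directed relations in both directions plus Other).

```python
import numpy as np


def dense_softmax(features, weights, bias):
    """Dense layer followed by a numerically stable softmax."""
    logits = features @ weights + bias
    logits = logits - logits.max(axis=-1, keepdims=True)  # avoid overflow in exp
    exp = np.exp(logits)
    return exp / exp.sum(axis=-1, keepdims=True)


def categorical_cross_entropy(probs, one_hot):
    """Mean cross entropy between one-hot targets and predicted probabilities."""
    return float(-(one_hot * np.log(probs + 1e-12)).sum(axis=-1).mean())


rng = np.random.default_rng(0)
num_classes = 19                        # assumed label count
features = rng.normal(size=(16, 768))   # one batch of 768-dim feature vectors
weights = rng.normal(scale=0.02, size=(768, num_classes))
bias = np.zeros(num_classes)

probs = dense_softmax(features, weights, bias)
one_hot = np.eye(num_classes)[rng.integers(0, num_classes, size=16)]
loss = categorical_cross_entropy(probs, one_hot)
```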


We report weighted averages of precision, recall, and F1. The number of training epochs reported refers to the number of epochs completed before validation loss stopped improving.

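| Model | Precision | Recall | F1 | Support | # Train epochs |
|-------|-----------|--------|----|---------|----------------|
| 3a | 0.73 | 0.74 | 0.73 | 2717 | 3 |
| 3b | 0.83 | 0.82 | 0.82 | 2717 | 4 |
| 3d | 0.81 | 0.82 | 0.81 | 2717 | 3 |
| 3e | 0.82 | 0.82 | 0.82 | 2717 | 4 |
| 3f | 0.82 | 0.83 | 0.82 | 2717 | 3 |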
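
For readers unfamiliar with weighted averaging, the sketch below shows one way such scores can be computed with scikit-learn. It is not necessarily how run_semeval.py produces its report, and the toy labels are made up for illustration. Each class's precision, recall, and F1 is weighted by that class's support (its number of true examples) before averaging.

```python
from sklearn.metrics import precision_recall_fscore_support

# Toy labels for three classes; not taken from the SemEval data.
y_true = [0, 0, 0, 1, 1, 2]
y_pred = [0, 0, 1, 1, 1, 2]

# average="weighted" weights each class's score by its support.
precision, recall, f1, _ = precision_recall_fscore_support(
    y_true, y_pred, average="weighted"
)

print(round(precision, 4), round(recall, 4), round(f1, 4))
```

In single-label classification the weighted recall equals plain accuracy, which is 5/6 for these toy labels.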

The training logs can be viewed on TensorBoard.dev.


SemEval2010 Task 8 data obtained from sahitya0000

Pretrained BERT-base weights obtained from the official Google release

BERT layer implemented using BERT for TensorFlow v2

Tokenization borrowed from the Hugging Face Transformers Library