

Primary language: Python. License: Apache License 2.0 (Apache-2.0).

R-BERT


(Unofficial) PyTorch implementation of R-BERT: "Enriching Pre-trained Language Model with Entity Information for Relation Classification"

Model Architecture

Method

  1. Get three vectors from BERT.
    • [CLS] token vector
    • averaged entity_1 vector
    • averaged entity_2 vector
  2. Pass each vector through a fully-connected layer.
    • dropout -> tanh -> fc-layer
  3. Concatenate the three vectors.
  4. Pass the concatenated vector through a final fully-connected layer.
    • dropout -> fc-layer
  • Follows exactly the same conditions as described in the paper:
    • The hidden-state vectors of entity_1 and entity_2 are averaged separately, including the $ and # marker tokens.
    • Dropout and tanh are applied before the fully-connected layers.
    • No [SEP] token is added at the end of the sequence. (To add one, pass the --add_sep_token option.)
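The steps above can be sketched in plain Python (no PyTorch) to make the data flow concrete. This is a minimal illustration under stated assumptions, not the repository's model code: the helper names (mark_entities, average_span, fc_layer, r_bert_logits), the params dictionary keys, the whitespace tokenizer, the SemEval-style <e1>...</e1> / <e2>...</e2> input tags, and the dropout rate of 0.1 are all hypothetical. Entity 1 is wrapped in $ markers and entity 2 in # markers, each entity vector is averaged over its span including the markers, and, as we read the paper, one fully-connected layer is shared by the two entity vectors. The [CLS] vector is simply taken as the first hidden state, which is a simplification.

```python
import math
import random

# Hypothetical plain-Python sketch of the R-BERT head (not the repo's code).
# Vectors are lists of floats; matrices are lists of rows (out_dim x in_dim).


def mark_entities(tagged, add_sep=False):
    """Replace <e1>/</e1> with '$' and <e2>/</e2> with '#' marker tokens.

    Assumes whitespace tokenization (BERT itself uses WordPiece) and
    prepends [CLS]. Returns the tokens plus inclusive (start, end) entity
    spans that cover the marker tokens themselves.
    """
    tags = ("<e1>", "</e1>", "<e2>", "</e2>")
    for tag in tags:
        tagged = tagged.replace(tag, f" {tag} ")
    tokens = ["[CLS]"] + tagged.split()
    if add_sep:  # mirrors the idea of --add_sep_token; off by default
        tokens.append("[SEP]")
    pos = {tag: tokens.index(tag) for tag in tags}
    marker = {"<e1>": "$", "</e1>": "$", "<e2>": "#", "</e2>": "#"}
    tokens = [marker.get(tok, tok) for tok in tokens]
    return tokens, (pos["<e1>"], pos["</e1>"]), (pos["<e2>"], pos["</e2>"])


def average_span(hidden, span):
    """Element-wise mean of hidden[start..end], both ends inclusive."""
    start, end = span
    rows = hidden[start:end + 1]
    return [sum(col) / len(rows) for col in zip(*rows)]


def linear(x, weight, bias):
    """Affine map: weight @ x + bias."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]


def fc_layer(x, weight, bias, p, training, use_tanh=True):
    """dropout -> (tanh) -> linear, in the order given in the method steps."""
    if training and p > 0.0:
        # Inverted dropout: zero with probability p, rescale the rest by 1/(1-p).
        x = [v / (1.0 - p) if random.random() >= p else 0.0 for v in x]
    if use_tanh:
        x = [math.tanh(v) for v in x]
    return linear(x, weight, bias)


def r_bert_logits(hidden, e1_span, e2_span, params, p=0.1, training=False):
    """Combine [CLS] and the two averaged entity vectors into class logits."""
    cls_vec = hidden[0]  # assumption: first hidden state stands in for [CLS]
    e1_vec = average_span(hidden, e1_span)
    e2_vec = average_span(hidden, e2_span)
    h_cls = fc_layer(cls_vec, params["cls_w"], params["cls_b"], p, training)
    # Both entity vectors pass through the same (shared) layer.
    h_e1 = fc_layer(e1_vec, params["ent_w"], params["ent_b"], p, training)
    h_e2 = fc_layer(e2_vec, params["ent_w"], params["ent_b"], p, training)
    concat = h_cls + h_e1 + h_e2  # list concatenation, length 3 * hidden_dim
    # Final layer: dropout -> fc-layer, no tanh.
    return fc_layer(concat, params["out_w"], params["out_b"], p, training, use_tanh=False)


tokens, e1, e2 = mark_entities("The <e1>company</e1> fabricates plastic <e2>chairs</e2>.")
print(tokens)  # ['[CLS]', 'The', '$', 'company', '$', 'fabricates', 'plastic', '#', 'chairs', '#', '.']
print(e1, e2)  # (2, 4) (7, 9)
```

In a real model, hidden would be BERT's last-layer output for these tokens and the weights would be learned; the sketch only shows which vectors are averaged, transformed, and concatenated.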

Dependencies

  • perl (for computing the official F1 score)
  • python>=3.6
  • torch==1.6.0
  • transformers==3.3.1

How to run

$ python3 main.py --do_train --do_eval
  • Predictions are written to proposed_answers.txt in the eval directory.

Official Evaluation

$ python3 official_eval.py
# macro-averaged F1 = 88.29%
  • Evaluation uses the official evaluation Perl script.
    • It reports the macro-averaged F1 score, excluding the Other relation.
  • Detailed results are written to result.txt in the eval directory.
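The official score comes from the Perl script; the Python sketch below only illustrates what "macro-averaged F1 excluding Other" means and is not a substitute for it. It assumes SemEval-2010 Task 8 style labels such as Cause-Effect(e1,e2), counts a prediction as correct only when its direction also matches, merges both directions of a relation when counting precision and recall, and averages per-type F1 over the relation types that appear in the data (the official script uses its own fixed label set). The function names are hypothetical, and the official script's counting rules may differ in details.

```python
from collections import Counter


def relation_type(label):
    """'Cause-Effect(e1,e2)' -> 'Cause-Effect'; 'Other' stays 'Other'."""
    return label.split("(", 1)[0]


def macro_f1_excluding_other(gold, pred, other="Other"):
    """Simplified macro-F1 over relation types, ignoring the Other class.

    A prediction is correct only if the full label (including direction)
    matches. Precision and recall are counted per relation type with both
    directions merged; the per-type F1 scores are then averaged.
    """
    correct, predicted, actual = Counter(), Counter(), Counter()
    for g, p in zip(gold, pred):
        actual[relation_type(g)] += 1
        predicted[relation_type(p)] += 1
        if g == p:
            correct[relation_type(g)] += 1
    types = sorted((set(actual) | set(predicted)) - {other})
    f1_scores = []
    for t in types:
        prec = correct[t] / predicted[t] if predicted[t] else 0.0
        rec = correct[t] / actual[t] if actual[t] else 0.0
        f1_scores.append(2 * prec * rec / (prec + rec) if prec + rec else 0.0)
    return sum(f1_scores) / len(f1_scores) if f1_scores else 0.0


gold = ["Cause-Effect(e1,e2)", "Cause-Effect(e2,e1)", "Other", "Component-Whole(e1,e2)"]
pred = ["Cause-Effect(e1,e2)", "Cause-Effect(e1,e2)", "Component-Whole(e1,e2)", "Component-Whole(e1,e2)"]
print(round(macro_f1_excluding_other(gold, pred), 4))  # 0.5833
```

Here the wrong-direction Cause-Effect prediction costs both precision and recall for that type, and predicting a relation for an Other example only lowers the precision of the predicted type, since Other itself is never scored.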

Prediction

$ python3 predict.py --input_file {INPUT_FILE_PATH} --output_file {OUTPUT_FILE_PATH} --model_dir {SAVED_CKPT_PATH}
