
XLNet: fine-tuning on an RTX 2080 GPU (8 GB)

Primary language: Python. License: Apache License 2.0 (Apache-2.0).

Introduction

This fork is a slight modification of XLNet that makes it possible to fine-tune the large model on the SQuAD 2.0 dataset using an RTX 2080 (8 GB) GPU.

The modifications are:

  • Use FP16 (half precision).
  • Reduce batch_size to 4.
  • Reduce seq_len to 340.
  • Train only half of the network, i.e., layers 12, 13, ..., 23, and freeze the others (0, 1, ..., 11).
  • Replace the single FC layers (1024 -> 1) with a deeper stack of FC layers (512 -> 256 -> 1) for start_logits, end_logits and CLS.
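The layer freezing and the deeper heads can be illustrated in plain Python. This is a minimal sketch, not the code in this fork: it assumes that XLNet-Large's 24 Transformer blocks are stored under TensorFlow variable scopes named like model/transformer/layer_<i>/..., and it treats the embeddings and other shared Transformer variables as frozen, which the list above does not state. In TensorFlow 1.x, a filtered variable list like this would typically be passed as var_list to the optimizer's minimize or compute_gradients. The helpers is_trainable and dense_stack_params, and the example variable names, are hypothetical.

```python
import re

# Assumption: Transformer blocks live under ".../transformer/layer_<i>/..."
# with i in 0..23 for XLNet-Large (24 layers, d_model = 1024).
TRANSFORMER_LAYER_RE = re.compile(r"/transformer/layer_(\d+)/")


def is_trainable(var_name, first_trainable_layer=12):
    """Decide whether a variable is updated during fine-tuning (hypothetical helper)."""
    match = TRANSFORMER_LAYER_RE.search(var_name)
    if match is not None:
        # Inside a Transformer block: train only the upper half (12..23).
        return int(match.group(1)) >= first_trainable_layer
    # Outside the blocks: freeze shared Transformer variables such as the
    # embeddings (assumption) and train everything else, e.g. the QA heads.
    return "/transformer/" not in var_name


def dense_stack_params(dims):
    """Number of weights + biases in a stack of FC layers, e.g. [1024, 512, 256, 1]."""
    return sum(d_in * d_out + d_out for d_in, d_out in zip(dims, dims[1:]))


if __name__ == "__main__":
    # Hypothetical variable names, for illustration only.
    names = [
        "model/transformer/word_embedding/lookup_table",
        "model/transformer/layer_0/rel_attn/q/kernel",
        "model/transformer/layer_11/ff/layer_1/kernel",
        "model/transformer/layer_12/ff/layer_1/kernel",
        "model/transformer/layer_23/rel_attn/o/kernel",
        "start_logits/dense_0/kernel",
    ]
    for name in names:
        print(f"{'train ' if is_trainable(name) else 'freeze'}  {name}")

    # Parameters per head: original single FC layer vs. the deeper stack.
    print(dense_stack_params([1024, 1]))            # 1025
    print(dense_stack_params([1024, 512, 256, 1]))  # 656385
```

With first_trainable_layer=12, exactly 12 of the 24 blocks receive updates. Each deeper head has about 656k parameters instead of 1,025, which is still tiny compared with the hundreds of millions of parameters in XLNet-Large.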

The files changed are:

With these modifications I achieved an F1 score of 86.23 on the SQuAD 2.0 dev set after training for 85,000 steps (~3 epochs over the full dataset). Training took about 5-6 hours.
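As a rough sanity check of the "~3 epochs" figure, the sketch below multiplies the number of steps by the batch size and divides by the number of training examples. The count of 130,319 training questions is an assumption taken from the published SQuAD 2.0 statistics, not from this fork; sliding-window preprocessing can split one long context into several features, so the real number of features is somewhat larger and the true epoch count somewhat lower than this estimate. The function names are hypothetical.

```python
# Assumption: SQuAD 2.0's training split has 130,319 questions (published
# dataset statistics). Sliding-window preprocessing may yield more features.
SQUAD2_TRAIN_QUESTIONS = 130_319


def passes_over_data(steps, batch_size, num_examples):
    """Approximate number of epochs covered by `steps` optimizer steps."""
    return steps * batch_size / num_examples


def steps_per_second(steps, hours):
    """Average training throughput for a run of the given length."""
    return steps / (hours * 3600)


if __name__ == "__main__":
    print(f"{passes_over_data(85_000, 4, SQUAD2_TRAIN_QUESTIONS):.2f}")  # 2.61
    for hours in (5, 6):
        print(f"{hours} h: {steps_per_second(85_000, hours):.2f} steps/s")
    # 5 h: 4.72 steps/s
    # 6 h: 3.94 steps/s
```

This gives about 2.6 passes over the training questions, the same order as the ~3 epochs stated above, and an average of roughly 4 to 4.7 steps per second over a 5-6 hour run.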

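Evaluation output on the SQuAD 2.0 dev set (has_ans_exact and has_ans_f1 are fractions, not percentages):

best_exact 83.4077318285185
best_exact_thresh -1.920951247215271
best_f1 86.23180344890973
best_f1_thresh -1.8610079288482666
has_ans_exact 0.8658906882591093
has_ans_f1 0.9299826812846799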
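The best_*_thresh entries come from a search over the model's no-answer score. The sketch below is modeled on the threshold search in the official SQuAD 2.0 evaluation script: questions are sorted by no-answer score, the threshold is raised one question at a time, and a question is predicted unanswerable whenever its no-answer score is above the threshold. It is an illustration, not this fork's code, and the toy data is made up. Negative thresholds such as -1.86 suggest that the no-answer score here is a raw logit rather than a probability.

```python
def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
    """Find the no-answer threshold that maximizes the overall score.

    preds:          qid -> predicted answer text ("" means "no answer")
    scores:         qid -> per-question EM or F1 (0..1) of the span prediction
    na_probs:       qid -> no-answer score (higher = more likely unanswerable)
    qid_to_has_ans: qid -> True if the question has a gold answer
    """
    # Start from "predict no answer everywhere": only unanswerable questions score.
    cur_score = sum(1 for qid in qid_to_has_ans if not qid_to_has_ans[qid])
    best_score, best_thresh = cur_score, 0.0
    # Raise the threshold through the sorted no-answer scores; each question at
    # or below the threshold switches from "no answer" to its span prediction.
    for qid in sorted(na_probs, key=na_probs.get):
        if qid not in scores:
            continue
        if qid_to_has_ans[qid]:
            diff = scores[qid]               # gain the span's EM/F1
        else:
            diff = -1 if preds[qid] else 0   # lose the point earned by abstaining
        cur_score += diff
        if cur_score > best_score:
            best_score, best_thresh = cur_score, na_probs[qid]
    return 100.0 * best_score / len(scores), best_thresh


if __name__ == "__main__":
    # Made-up toy example: q2 is unanswerable but the model proposed a span.
    preds = {"q1": "a", "q2": "b", "q3": "c"}
    scores = {"q1": 1.0, "q2": 0.0, "q3": 0.5}
    na_probs = {"q1": -3.0, "q2": 2.0, "q3": -1.0}
    has_ans = {"q1": True, "q2": False, "q3": True}
    print(find_best_thresh(preds, scores, na_probs, has_ans))
```

On the toy data the best threshold is -1.0: q2 (no-answer score 2.0) is answered as unanswerable, while q1 and q3 keep their spans, for a score of 2.5 out of 3.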

I consider this a very good result, given the limited hardware it was trained on.

Those with TPU access could use the original implementation to train all the layers, replace the single FC layer with a deeper FC stack, and see how much this improves the network.