Load GPT-2 checkpoint and generate texts in PyTorch
PyTorch GPT-2
Install
pip install torch-gpt-2
Demo
import os
import sys

from torch_gpt_2 import load_trained_model_from_checkpoint, get_bpe_from_files, generate

if len(sys.argv) != 2:
    print('python3 demo.py MODEL_FOLDER')
    sys.exit(-1)

model_folder = sys.argv[1]

# Paths to the files shipped with an OpenAI GPT-2 checkpoint folder
config_path = os.path.join(model_folder, 'hparams.json')
checkpoint_path = os.path.join(model_folder, 'model.ckpt')
encoder_path = os.path.join(model_folder, 'encoder.json')
vocab_path = os.path.join(model_folder, 'vocab.bpe')

print('Load net from checkpoint...')
net = load_trained_model_from_checkpoint(config_path, checkpoint_path)

print('Load BPE from files...')
bpe = get_bpe_from_files(encoder_path, vocab_path)

print('Generate text...')
output = generate(net, bpe, ['From the day forth, my arm'], length=20, top_k=1)

# If you are using the 117M model and top_k equals 1, the result should be:
# "From the day forth, my arm was broken, and I was in a state of pain. I was in a state of pain,"
print(output[0])
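Run the script as python3 demo.py MODEL_FOLDER, where MODEL_FOLDER is a directory containing the hparams.json, model.ckpt, encoder.json, and vocab.bpe files from an OpenAI GPT-2 release such as 117M.

The model and BPE only need to be loaded once, and generate accepts a list of prompts, so several texts can be produced in a single call. The snippet below is a minimal sketch that reuses the net and bpe objects from the demo; it relies only on the length and top_k arguments shown above, and the prompts are arbitrary examples. With top_k greater than 1 the decoder samples among the top-k candidate tokens rather than always taking the single most likely one, so the output is no longer deterministic.

# Sketch: batched generation, reusing net and bpe from demo.py above
prompts = [
    'From the day forth, my arm',
    'The Empire State Building is',
]
# top_k > 1 samples among the k most likely tokens at each step
outputs = generate(net, bpe, prompts, length=40, top_k=5)
for prompt, text in zip(prompts, outputs):
    print('--- %s ---' % prompt)
    print(text)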