google-research/text-to-text-transfer-transformer

Pretrain T5 using PyTorch

LiweiPeng opened this issue · 2 comments

Hi, I'd like to pretrain a T5 model using PyTorch. Can someone share a link to a T5 pretraining implementation in PyTorch? Thanks.

Hi, you can just use the example provided here:
https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/models/hf_model.py#L27
except that instead of providing the name of a pre-trained model in the constructor:
model = t5.models.HfPyTorchModel("t5-base", "/tmp/hft5/", device)
you should provide a T5Config directly, and the model will be initialized from scratch, e.g.:
model = t5.models.HfPyTorchModel(transformers.T5Config(), "/tmp/hft5/", device)
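A minimal, self-contained sketch of that from-scratch setup could look like this (the model directory and device handling are just placeholders):

import torch
import transformers
import t5.models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Randomly initialized T5 using the default T5Config hyperparameters.
config = transformers.T5Config()
model = t5.models.HfPyTorchModel(config, "/tmp/hft5/", device)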
Here is the T5Config class with its default hparam values:
https://github.com/huggingface/transformers/blob/master/src/transformers/configuration_t5.py#L34
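If you want a non-default size, you can also pass the hyperparameters explicitly. The keyword names below come from that config class; the values shown are just an illustration of a small model, not an official preset:

config = transformers.T5Config(
    vocab_size=32128,
    d_model=512,
    d_kv=64,
    d_ff=2048,
    num_layers=6,
    num_heads=8,
    dropout_rate=0.1,
)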
I think you can reuse the hyperparameter values corresponding to one of the released models. Note that T5Config.from_json_file expects a local path, so download the config JSON first and pass that path, or load the config by model name:
config = transformers.T5Config.from_pretrained("t5-base")
A list of available json files is here:
https://github.com/huggingface/transformers/blob/master/src/transformers/configuration_t5.py#L25
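Putting it together, here is a rough pretraining sketch. Caveats: the task name "c4_v220_span_corruption" and the train() keyword arguments mirror the fine-tuning example in this repo's README, and the step counts and sequence lengths are only illustrative, so double-check everything against hf_model.py and your registered Tasks/Mixtures before relying on it:

import functools

import torch
import transformers
import t5.models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Start from the hyperparameters of a released size, but with random weights.
config = transformers.T5Config.from_pretrained("t5-base")
model = t5.models.HfPyTorchModel(config, "/tmp/hft5/", device)

# Pretrain on a span-corruption task (the task name is an assumption; use
# whichever Task or Mixture you have registered).
model.train(
    mixture_or_task_name="c4_v220_span_corruption",
    steps=100000,
    save_steps=5000,
    sequence_length={"inputs": 512, "targets": 114},
    split="train",
    batch_size=32,
    optimizer=functools.partial(transformers.AdamW, lr=1e-4),
)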

We've released nanoT5, which reproduces T5 (similar to BART) pre-training in PyTorch rather than Flax.

You can take a look!

Any suggestions are more than welcome.