
Pretrain T5 using PyTorch

LiweiPeng opened this issue · 2 comments

Hi, I'd like to pretrain a T5 model using PyTorch. Can someone share a link to a T5 pretraining implementation in PyTorch? Thanks.

Hi, you can just use the example provided here:
except that instead of providing the name of a pre-trained model in the constructor:
model = t5.models.HfPyTorchModel("t5-base", "/tmp/hft5/", device)
you should provide a T5Config directly, and the model will be initialized from scratch, e.g.:
model = t5.models.HfPyTorchModel(transformers.T5Config(), "/tmp/hft5/", device)
Here is the T5Config class with its default hparam values:
I think you can use the hyperparameter values corresponding to one of the released models like so:
config = transformers.T5Config.from_json_file(json_file="")
A list of available json files is here:
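
To make the from-scratch idea concrete, here is a minimal sketch that goes through transformers directly rather than through t5.models.HfPyTorchModel. The helper names (write_config, build_model_from_scratch) and the T5_SMALL_HPARAMS dict are hypothetical, made up for illustration. The values are the ones I believe match the T5Config defaults, so treat them as an assumption and check them against the actual json file for the model size you want.

```python
import json

# Hypothetical example hyperparameters (assumed to match the T5Config
# defaults, i.e. the "small" architecture). Replace with the contents of
# the json file for the model size you actually want.
T5_SMALL_HPARAMS = {
    "vocab_size": 32128,
    "d_model": 512,
    "d_kv": 64,
    "d_ff": 2048,
    "num_layers": 6,
    "num_heads": 8,
}


def write_config(path, hparams=T5_SMALL_HPARAMS):
    """Write the hyperparameters to a JSON file that T5Config can read."""
    with open(path, "w") as f:
        json.dump(hparams, f, indent=2)
    return path


def build_model_from_scratch(config_path):
    """Load the JSON config and build a randomly initialized T5 model."""
    # Imported here so write_config works without transformers installed.
    import transformers

    config = transformers.T5Config.from_json_file(config_path)
    # Building from a config (instead of from_pretrained) loads no
    # pre-trained weights; the parameters are freshly initialized.
    return transformers.T5ForConditionalGeneration(config)
```

Calling build_model_from_scratch(write_config("config.json")) should give you an untrained model with that architecture. This only covers model construction; the pretraining data pipeline and training loop still come from whichever example you follow.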

We've released nanoT5, which reproduces pre-training of the T5 model (similar to BART) in PyTorch (not Flax).

You can take a look!

Any suggestions are more than welcome.