GPT2

Slight variations of GPT2 implementations

Primary language: Python

Notes

Language Models are Unsupervised Multitask Learners (Radford et al., 2019)

Model

  • We use a Transformer (Vaswani et al., 2017) based architecture for our LMs. The model largely follows the details of the OpenAI GPT model (Radford et al., 2018) with a few modifications. Layer normalization (Ba et al., 2016) was moved to the input of each sub-block, similar to a pre-activation residual network (He et al., 2016), and an additional layer normalization was added after the final self-attention block. A modified initialization which accounts for the accumulation on the residual path with model depth is used: the weights of residual layers are scaled at initialization by a factor of 1/√N, where N is the number of residual layers. The vocabulary is expanded to 50,257 tokens, the context size is increased from 512 to 1024 tokens, and a larger batch size of 512 is used.
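The pre-LayerNorm ordering and the scaled residual initialization are the two changes easiest to misread, so here is a minimal PyTorch sketch of both. The names (Block, n_embd, n_head, n_layer) are illustrative conventions, not this repository's exact code, and N is taken here as 2 × n_layer, i.e. one residual branch for attention and one for the MLP per block:

    import math
    import torch
    import torch.nn as nn

    class Block(nn.Module):
        """One GPT-2-style block with layer norm at the sub-block inputs."""
        def __init__(self, n_embd=768, n_head=12):
            super().__init__()
            self.ln_1 = nn.LayerNorm(n_embd)
            self.attn = nn.MultiheadAttention(n_embd, n_head, batch_first=True)
            self.ln_2 = nn.LayerNorm(n_embd)
            self.mlp = nn.Sequential(
                nn.Linear(n_embd, 4 * n_embd),
                nn.GELU(),
                nn.Linear(4 * n_embd, n_embd),
            )

        def forward(self, x):
            # Pre-activation ordering: normalize first, then transform,
            # then add the residual. (Causal masking is omitted to keep
            # the sketch short.)
            h = self.ln_1(x)
            attn_out, _ = self.attn(h, h, h, need_weights=False)
            x = x + attn_out
            x = x + self.mlp(self.ln_2(x))
            return x

    def scale_residual_weights(blocks, n_layer=12):
        # Each block has two residual branches (attention and MLP), so N
        # is taken as 2 * n_layer; the output projection of each branch
        # is scaled by 1/sqrt(N) at initialization.
        n = 2 * n_layer
        with torch.no_grad():
            for block in blocks:
                block.attn.out_proj.weight.mul_(1.0 / math.sqrt(n))
                block.mlp[2].weight.mul_(1.0 / math.sqrt(n))

    blocks = nn.ModuleList(Block() for _ in range(12))
    scale_residual_weights(blocks, n_layer=12)

The additional layer normalization after the final block that the quote mentions corresponds to transformer.ln_f in the shape listing below.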

Architecture details

All twelve transformer blocks (transformer.h.0 through transformer.h.11) have identical parameter shapes, so the full dump is summarized below with i ranging over 0–11:

  transformer.wte.weight                torch.Size([50257, 768])
  transformer.wpe.weight                torch.Size([1024, 768])
  transformer.h.{i}.ln_1.weight         torch.Size([768])
  transformer.h.{i}.ln_1.bias           torch.Size([768])
  transformer.h.{i}.attn.c_attn.weight  torch.Size([768, 2304])
  transformer.h.{i}.attn.c_attn.bias    torch.Size([2304])
  transformer.h.{i}.attn.c_proj.weight  torch.Size([768, 768])
  transformer.h.{i}.attn.c_proj.bias    torch.Size([768])
  transformer.h.{i}.ln_2.weight         torch.Size([768])
  transformer.h.{i}.ln_2.bias           torch.Size([768])
  transformer.h.{i}.mlp.c_fc.weight     torch.Size([768, 3072])
  transformer.h.{i}.mlp.c_fc.bias       torch.Size([3072])
  transformer.h.{i}.mlp.c_proj.weight   torch.Size([3072, 768])
  transformer.h.{i}.mlp.c_proj.bias     torch.Size([768])
  transformer.ln_f.weight               torch.Size([768])
  transformer.ln_f.bias                 torch.Size([768])
  lm_head.weight                        torch.Size([50257, 768])
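A few shapes are worth noting: c_attn projects the 768-dimensional hidden state to 2304 = 3 × 768 channels, the concatenated query, key, and value projections; the weights appear as (in_features, out_features) because GPT-2 implementations such as Hugging Face's store them in Conv1D-style layers; and lm_head.weight shares its storage with transformer.wte.weight through weight tying. A minimal sketch of how such a listing can be dumped, assuming the Hugging Face transformers package and its pretrained "gpt2" (124M) checkpoint, not necessarily how it was generated here:

    from transformers import GPT2LMHeadModel

    model = GPT2LMHeadModel.from_pretrained("gpt2")

    # state_dict() also reports the tied lm_head.weight (it shares storage
    # with transformer.wte.weight); depending on the transformers version
    # it may additionally include non-parameter attention-mask buffers.
    for name, tensor in model.state_dict().items():
        print(name, tensor.shape)

Summing these shapes gives roughly 124M parameters for this configuration, of which the token embedding alone accounts for 50,257 × 768 ≈ 38.6M.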