Reproduces and extends Pretrained Transformers as Universal Computation Engines.
Open our pipeline in Colab.
Check our results on the wandb page.
The MNIST database contains 60,000 training images and 10,000 testing images of handwritten digits. We use the standard MNIST benchmark, where the model must classify a 32 × 32 black-and-white image. The tokens given to the model are 4 × 4 image patches, so the models are fed 64 tokens of dimension 16.
The CIFAR-10 dataset contains 60,000 32 × 32 color images in 10 classes (airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks), with 6,000 images per class. We use the standard CIFAR-10 benchmark, where the tokens given to the model are 4 × 4 image patches, so the models are fed 64 tokens of dimension 16.
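Both image datasets are tokenized the same way. Below is a minimal sketch of the patch extraction, assuming single-channel 32 × 32 inputs so that each 4 × 4 patch flattens to a 16-dimensional token (the function name and the exact preprocessing in our pipeline may differ):

```python
import torch

def patchify(images, patch_size=4):
    """Split a batch of single-channel images into flattened patches.

    images: (batch, 32, 32) tensor -> (batch, 64, 16), i.e. 64 tokens of dimension 16.
    """
    b, h, w = images.shape
    patches = images.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    # patches has shape (batch, 8, 8, 4, 4); flatten the patch grid and each patch
    return patches.reshape(b, (h // patch_size) * (w // patch_size), patch_size * patch_size)

tokens = patchify(torch.rand(16, 32, 32))  # torch.Size([16, 64, 16])
```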
The model is presented with a sequence of n MNIST digits (28 × 28 pixels) and should predict their sum. The task is parametrized by the sequence length (for n=1 it is equivalent to the standard MNIST task). The task is taken from this paper.
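A minimal sketch of how such a sample can be built from torchvision's MNIST (the dataset wrapper used in our pipeline may differ):

```python
import torch
from torchvision import transforms
from torchvision.datasets import MNIST

mnist = MNIST(root="data", train=True, download=True, transform=transforms.ToTensor())

def addition_sample(n=10):
    """Draw n random MNIST digits and return (digit images, sum of their labels)."""
    idx = torch.randint(len(mnist), (n,))
    images = torch.stack([mnist[int(i)][0] for i in idx])  # (n, 1, 28, 28)
    target = sum(mnist[int(i)][1] for i in idx)             # integer in [0, 9 * n]
    return images, target
```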
The Speech Commands dataset contains short audio clips of a fixed number of command words such as "stop", "go", "up", "down", etc., spoken by many speakers. Google released two versions of the dataset: the first version contains 65k samples over 30 classes, and the second contains 110k samples over 35 classes. In our project, we used the second version, which is available in the torchaudio library.
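A minimal sketch of loading the second version through torchaudio (assuming a torchaudio version that supports the `subset` argument; paths are illustrative):

```python
import torchaudio

# Speech Commands v2: 35 classes, ~1 s clips sampled at 16 kHz
train_set = torchaudio.datasets.SPEECHCOMMANDS(
    root="data", url="speech_commands_v0.02", download=True, subset="training")
test_set = torchaudio.datasets.SPEECHCOMMANDS(
    root="data", url="speech_commands_v0.02", download=True, subset="testing")

waveform, sample_rate, label, speaker_id, utterance_number = train_set[0]
```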
The cytochrome P450 (CYP) genes are involved in the formation and breakdown (metabolism) of various molecules and chemicals within cells. Specifically, CYP3A4 is an important enzyme in the body, mainly found in the liver and in the intestine. It oxidizes small foreign organic molecules (xenobiotics), such as toxins or drugs, so that they can be removed from the body.
- Complete the table?
- Better baselines: https://paperswithcode.com/sota/image-classification-on-cifar-10
- Determine what score a "random" model should achieve
Dataset | Metric Name | Result |
---|---|---|
MNIST | Accuracy | 99.5% |
CIFAR 10 | Accuracy | 99.5% |
MNIST Digits Addition (n=10) | Mean Absolute Error | 1.42 (dummy: 7.31) |
Cyp3A4 Inhibition | Accuracy | 82.1% |
Speech Command | Accuracy | 98.1% |
The baseline for this dataset is an LSTM, taken from the original paper.
The baseline for this dataset is ViT-H, taken from this paper.
The baseline for this dataset is the NAC model from the "Neural Arithmetic Logic Units" paper. It is a model designed for addition/subtraction tasks, and it performs a linear affine transformation of its input.
The baseline for this dataset is the Molecule Attention Transformer, fine-tuned with the huggingmolecules package (on the same data split, with the default hyperparameter settings).
The baseline for this dataset is the Audio Spectrogram Transformer; more details can be found in the paper.
Can pretrained language models transfer to different modalities?
- Train Frozen Pretrained Transformer (FPT) on all datasets with the default parameter set (a sketch of how these freeze flags map onto GPT-2 modules is shown after this list):
experiments_params = dict(
# ...
model_name='gpt2',
pretrained=True,
freeze_trans=True,
freeze_in=False,
freeze_pos=False,
freeze_ln=False,
freeze_attn=True,
freeze_ff=True,
freeze_out=False,
# ...
)
- Compare the results with the baselines. Are they comparable?
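As referenced above, a minimal sketch of what the freeze flags correspond to, assuming the HuggingFace GPT-2 backbone (the actual wiring in our pipeline may differ in details):

```python
from transformers import GPT2Model

model = GPT2Model.from_pretrained("gpt2")  # pretrained=True

def set_frozen(module, frozen):
    for p in module.parameters():
        p.requires_grad = not frozen

# freeze_attn / freeze_ff / freeze_ln act per transformer block,
# freeze_pos acts on the positional embeddings.
for block in model.h:
    set_frozen(block.attn, True)    # freeze_attn=True
    set_frozen(block.mlp, True)     # freeze_ff=True
    set_frozen(block.ln_1, False)   # freeze_ln=False
    set_frozen(block.ln_2, False)
set_frozen(model.wpe, False)        # freeze_pos=False
set_frozen(model.ln_f, False)       # final layer norm, trainable when freeze_ln=False

# freeze_in / freeze_out refer to the task-specific input and output projections,
# which are separate linear layers added around the backbone and stay trainable here.
```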
Result = average of the test metric over k test steps (k = 100 for all experiments)
Dataset | Metric Name | Result | #runs | #steps | Parameters | wandb |
---|---|---|---|---|---|---|
MNIST | Accuracy | 98.15% | 1 | 250 | steps_per_iter=200 test_steps_per_iter=100 learning_rate=1e-3 batch_size=16 patch_size=4 | 3boh2kd2 |
CIFAR10 | Accuracy | 63.24% | 1 | 550 | steps_per_iter=200 test_steps_per_iter=100 learning_rate=1e-3 batch_size=16 patch_size=4 | 3qo22alh |
MNIST Digits Addition | MSE | 7.404 | 1 | 200 | steps_per_iter=200 test_steps_per_iter=50 learning_rate=1e-3 batch_size=16 patch_size=28 n=10 | 2mfojag2 |
Cyp3A4 Inhibition | Accuracy | 75.65% | 1 | 280 | steps_per_iter=200 test_steps_per_iter=50 learning_rate=1e-3 batch_size=16 | 3kzurs4w |
Speech Command | Accuracy | 64.97% | 1 | 500 | steps_per_iter=200 test_steps_per_iter=25 learning_rate=1e-4 batch_size=16 patch_size=80 | 3mi762gc |
- Our preliminary results support the authors' thesis.
- Comparison with baselines:
Model | MNIST | CIFAR10 | MNIST Digits Addition | Cyp3A4 Inhibition | Speech Command |
---|---|---|---|---|---|
FPT | 98.15% | 63.24% | 7.404* | 75.65% | 64.97% |
Baseline | 99.5% | 99.5% | 1.42 | 82.1% | 98.1% |
What is the importance of the pretraining modality?
- Train Unfrozen Random Transformer (URT) on all datasets without pretraining or freezing (see the sketch after this list):
experiments_params = dict(
# ...
model_name='gpt2',
pretrained=False,
freeze_trans=False,
freeze_in=False,
freeze_pos=False,
freeze_ln=False,
freeze_attn=False,
freeze_ff=False,
freeze_out=False,
# ...
)
- Compare the results with the baselines and with the results from [question 1](#Question 1). Are they comparable?
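In the URT setting, pretrained=False means the GPT-2 backbone is randomly initialized instead of loaded from a checkpoint, and with no freeze_* flag set every parameter stays trainable. A minimal sketch, assuming the HuggingFace GPT-2 classes:

```python
from transformers import GPT2Config, GPT2Model

# pretrained=False: random initialization with the gpt2 architecture
model = GPT2Model(GPT2Config.from_pretrained("gpt2"))

# freeze_trans=False etc.: nothing is frozen, everything is trained from scratch
assert all(p.requires_grad for p in model.parameters())
```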
Dataset | Metric Name | # runs | mean | std | # trainable weights | Reference |
---|---|---|---|---|---|---|
MNIST | Accuracy | 6 | 60.76% | 0.1154 | 124,460,554 | WandB |
CIFAR 10 | Accuracy | 3 | 21.73% | 0.0156 | 124,485,130 | WandB |
MNIST Digits Addition (Regression, n=10) | Mean Absolute Error | 3 | 7.41 | 0.0316 | 125,043,457 | WandB |
MNIST Digits Addition (Classification, n=10) | Mean Absolute Error | 5 | 6.55 | 0.9497 | 125,112,667 | WandB |
Cyp3A4 Inhibition | Accuracy | 3 | 61.82% | 0.0308 | 125,209,346 | WandB |
Speech Command | Accuracy | 2 | 62.52% | - | 124,528,931 | WandB |
Model | MNIST | CIFAR10 | MNIST Digits Addition | Cyp3A4 Inhibition | Speech Command |
---|---|---|---|---|---|
URT | 60.76% | 21.73% | 7.41* | 61.82% | 63.56% |
FPT | 98.15% | 63.24% | 7.404* | 75.65% | 64.97% |
Baseline | 99.5% | 99.5% | 1.42 | 82.1% | 98.1% |
Does the transformer architecture provide inductive bias that transfers well to various modalities?
- Train Frozen Random Transformer (FRT) on all datasets without the pretraining, but with freezing:
experiments_params = dict(
# ...
model_name='gpt2',
pretrained=False,
freeze_trans=True,
freeze_in=False,
freeze_pos=False,
freeze_ln=False,
freeze_attn=True,
freeze_ff=True,
freeze_out=False,
# ...
)
- Compare the results with the baselines and with the results from [question 1](#Question 1) and [question 2](#Question 2). Are they comparable?
Result = average of the test metric over k test steps (k = 100 for all experiments)
Dataset | Metric Name | Result | #runs | #steps | Parameters | wandb |
---|---|---|---|---|---|---|
MNIST | Accuracy | 97.32% 97.08% 96.76% | 3 | 350 250 220 | steps_per_iter=200 test_steps_per_iter=50 learning_rate=1e-3 batch_size=16 patch_size=4 | 3v1wtr64 1roiff6y rdqaxnlm |
CIFAR 10 | Accuracy | 56.06% 58.54% 58.08% | 3 | 375 550 520 | steps_per_iter=200 test_steps_per_iter=50 learning_rate=1e-3 batch_size=16 patch_size=4 | 2i50a27d 1t4rqtyu 2h7i4yza |
MNIST Digits Addition | MSE | 7.543 7.501 7.69 | 3 | 150 200 90 | steps_per_iter=200 test_steps_per_iter=50 learning_rate=1e-3 batch_size=16 patch_size=28 n=10 | 1dj9wo3f 1drqzu9o 1x5p2yog |
Cyp3A4 Inhibition | Accuracy | 73.08% 75.8% 76.52% | 3 | 50 75 280 | steps_per_iter=200 test_steps_per_iter=50 learning_rate=1e-3 batch_size=16 | fitarqv4 3cnuxa09 3d155atc |
Speech Command | Accuracy | 32.87% 24.94% 37.44% | 3 | 240 120 350 | steps_per_iter=200 test_steps_per_iter=25 learning_rate=1e-4 batch_size=16 patch_size=80 | 1abgqjcm 2e3sow6q voonklgk |
- Comparison with baselines:
Model | MNIST | CIFAR10 | MNIST Digits Addition | Cyp3A4 Inhibition | Speech Command |
---|---|---|---|---|---|
FRT (mean) | 97.05% | 57.56% | 7.578 | 75.13% | 31.75% |
FRT (best) | 97.32% | 58.54% | 7.501 | 76.52% | 37.44% |
URT | 60.76% | 21.73% | 7.41* | 61.82% | 63.56% |
FPT | 98.15% | 63.24% | 7.404* | 75.65% | 64.97% |
Baseline | 99.5% | 99.5% | 1.42 | 82.1% | 98.1% |
Can pretrained visual models transfer to different modalities?
- Implement support for using ViT as the pretrained transformer.
- Train Visual Frozen Pretrained Transformer (V-FPT) on all datasets with the default parameter set, but with ViT as the pretrained transformer (a loading/freezing sketch follows this list):
experiments_params = dict(
# ...
model_name='vit',
pretrained=True,
freeze_trans=True,
freeze_in=False,
freeze_pos=False,
freeze_ln=False,
freeze_attn=True,
freeze_ff=True,
freeze_out=False,
# ...
)
- Compare the results with the baselines and with the results from [question 1](#Question 1). Are they comparable?
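Assuming the HuggingFace ViT implementation (the checkpoint name below is illustrative), the same freezing scheme can be applied to the ViT encoder. A minimal sketch:

```python
from transformers import ViTModel

vit = ViTModel.from_pretrained("google/vit-base-patch16-224")  # pretrained=True

def set_frozen(module, frozen):
    for p in module.parameters():
        p.requires_grad = not frozen

for layer in vit.encoder.layer:
    set_frozen(layer.attention, True)          # freeze_attn=True
    set_frozen(layer.intermediate, True)       # freeze_ff=True (first MLP projection)
    set_frozen(layer.output, True)             # freeze_ff=True (second MLP projection)
    set_frozen(layer.layernorm_before, False)  # freeze_ln=False
    set_frozen(layer.layernorm_after, False)
```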
Dataset | Metric Name | # runs | mean | std | # trainable weights | Reference |
---|---|---|---|---|---|---|
MNIST | Accuracy | 3 | 73.59% | 0.0626 | 59,146 | WandB |
CIFAR 10 | Accuracy | 1 | 44.58% | NaN | 83,722 | WandB |
MNIST Digits Addition (Regression, n=10) | Mean Absolute Error | 1 | 7.29 | NaN | 642,049 | WandB |
MNIST Digits Addition (Classification, n=10) | Mean Absolute Error | 2 | 1.79 | 0.3288 | 711,259 | WandB |
Cyp3A4 Inhibition | Accuracy | 3 | 73.87% | 0.0204 | 807,938 | WandB |
Speech Command | Accuracy | 1 | 36.16% | NaN | 127,523 | WandB |
- ViT also performs reasonably well as a Universal Computation Engine, although it lags behind the GPT-2-based FPT on most tasks.
Model | MNIST | CIFAR10 | MNIST Digits Addition | Cyp3A4 Inhibition | Speech Command |
---|---|---|---|---|---|
FRT | 97.08% | 58.54% | 7.501* | 76.52% | 37.44% |
URT | 60.76% | 21.73% | 7.41* | 61.82% | 63.56% |
FPT | 98.15% | 63.24% | 7.404* | 75.65% | 64.97% |
V-FPT | 73.59% | 44.58% | 7.29* | 73.87% | 36.16% |
Baseline | 99.5% | 99.5% | 1.42 | 82.1% | 98.1% |
Does pretraining scenario influence FPT accuracy?
- Multiple GPT-2-based pretrained transformers have been selected from HuggingFace with:
  - different pretraining languages: uer/gpt2-chinese-poem, LorenzoDeMattei/GePpeTto, rinna/japanese-gpt2-medium, ...
  - different model sizes (embedding size, number of attention heads): tiny-gpt2, gpt2, gpt2-medium, gpt2-large, ...
  - different specializations: magic-the-gathering, gpt2-chess-uci, CodeGPT-small-py, ...
- The selected pretrained models have been trained and tested on the various tasks. For full results, take a look at the experiments/Results.ipynb notebook.
pretrained model | mean accuracy | # trainable weights |
---|---|---|
ceostroff/harry-potter-gpt2-fanfiction | 96.28% | 59,146 |
gpt2 | 95.94% | 59,146 |
sberbank-ai/rugpt3small_based_on_gpt2 | 95.37% | 59,146 |
... | ... | ... |
shtoshni/gpt2-chess-uci | 90.11% | 59,146 |
minimaxir/magic-the-gathering | 69.79% | 7,818 |
sshleifer/tiny-gpt2 | 14.74% | 84 |
pretrained model | mean MAE | # trainable weights |
---|---|---|
ceostroff/harry-potter-gpt2-fanfiction | 1.7950 | 1,497,691 |
chrisliu298/arxiv_ai_gpt2 | 1.9036 | 1,308,251 |
distilgpt2 | 1.9068 | 1,479,259 |
gpt2 | 2.0455 | 1,497,691 |
... | ... | ... |
minimaxir/magic-the-gathering | 3.2861 | 149,339 |
shtoshni/gpt2-chess-uci | 3.4264 | 1,104,475 |
sshleifer/tiny-gpt2 | 7.4578 | 3,911 |
pretrained model | mean accuracy | # trainable weights |
---|---|---|
gpt2 | 72.51% | 62,228 |
sberbank-ai/rugpt3small_based_on_gpt2 | 66.70% | 62,228 |
microsoft/CodeGPT-small-py | 66.70% | 62,228 |
... | ... | ... |
gpt2-large | 49.93% | 226,580 |
microsoft/DialoGPT-medium | 49.79% | 132,116 |
chrisliu298/arxiv_ai_gpt2 | 49.76% | 226,580 |
- The number of trainable weights is correlated with the overall model score.
- Additional pretraining on a specialized domain may further increase model performance.