Implementing SYNTHESIZER: Rethinking Self-Attention in Transformer Models using PyTorch
-
-
Authors: Yi Tay, Dara Bahri, Donald Metzler, Da-Cheng Juan, Zhe Zhao, Che Zheng
-
1. Dense Synthesizer
2. Fixed Random Synthesizer
3. Random Synthesizer
4. Factorized Dense Synthesizer
5. Factorized Random Synthesizer
6. Mixture of Synthesizers
import torch
from synthesizer import Transformer, SynthesizerDense, SynthesizerRandom, FactorizedSynthesizerDense, FactorizedSynthesizerRandom, MixtureSynthesizers, get_n_params, calculate_flops
def _report(name, model, x):
    """Run ``model`` on ``x`` and print its output shapes, size and FLOPs.

    Prints two lines: the sizes of the two tensors returned by the forward
    pass, then ``'<name>, n_params: <n>, flops: <f>'`` using
    ``get_n_params(model)`` and ``calculate_flops(model.children())``.

    Args:
        name: label used in the summary line.
        model: module whose forward returns ``(out, attention_map)``.
        x: input batch passed to ``model``.
    """
    # Shape/size inspection only: no autograd graph is needed for this demo.
    with torch.no_grad():
        out, attention_map = model(x)
    print(out.size(), attention_map.size())
    n_params, flops = get_n_params(model), calculate_flops(model.children())
    print('{}, n_params: {}, flops: {}'.format(name, n_params, flops))


def main(batch_size=2, channel_dim=1024, sentence_length=32):
    """Compare vanilla attention with every Synthesizer variant on one batch.

    A single random input of shape ``[batch_size, sentence_length,
    channel_dim]`` is fed to each model in turn, and ``_report`` prints the
    resulting tensor sizes plus parameter and FLOP counts.

    Args:
        batch_size: number of sequences in the random input batch.
        channel_dim: feature dimension of each token; passed to every model.
        sentence_length: tokens per sequence; also passed to the models whose
            constructors take a sequence length.
    """
    x = torch.randn([batch_size, sentence_length, channel_dim])

    # Factories instead of instances: each model is constructed right before
    # it runs (same order as the original step-by-step demo) and can be
    # released once its report is printed.
    model_factories = [
        ('vanilla',
         lambda: Transformer(channel_dim)),
        ('dense_synthesizer',
         lambda: SynthesizerDense(channel_dim, sentence_length)),
        ('random_synthesizer',
         lambda: SynthesizerRandom(channel_dim, sentence_length)),
        ('random_synthesizer_fix',
         lambda: SynthesizerRandom(channel_dim, sentence_length, fixed=True)),
        ('factorized_synthesizer_random',
         lambda: FactorizedSynthesizerRandom(channel_dim)),
        ('factorized_synthesizer_dense',
         lambda: FactorizedSynthesizerDense(channel_dim, sentence_length)),
        ('mixture_synthesizer',
         lambda: MixtureSynthesizers(channel_dim, sentence_length)),
    ]
    for name, build in model_factories:
        _report(name, build(), x)


if __name__ == '__main__':
    main()
torch.Size([2, 32, 1024]) torch.Size([2, 32, 32])
vanilla, n_params: 3148800, flops: 3145729
torch.Size([2, 32, 1024]) torch.Size([2, 32, 32])
dense_synthesizer, n_params: 1083456, flops: 1082370
torch.Size([2, 32, 1024]) torch.Size([1, 32, 32])
random_synthesizer, n_params: 1050624, flops: 1048577
torch.Size([2, 32, 1024]) torch.Size([1, 32, 32])
random_synthesizer_fix, n_params: 1050624, flops: 1048577
torch.Size([2, 32, 1024]) torch.Size([2, 32, 32])
factorized_synthesizer_random, n_params: 1066000, flops: 1064961
torch.Size([2, 32, 1024]) torch.Size([2, 32, 32])
factorized_synthesizer_dense, n_params: 1061900, flops: 1060865
torch.Size([2, 32, 1024]) torch.Size([2, 32, 32])
mixture_synthesizer, n_params: 3149824, flops: 3145729