$ pip install gtrxl-torch
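import torch
from gtrxl_torch.gtrxl_torch import GTrXL

# Build a GTrXL with a single transformer block.
model = GTrXL(
    d_model=512,
    nheads=4,
    transformer_layers=1
)

# Random input: 32 time steps, batch of 16, 512 features (d_model).
x = torch.randn(32, 16, 512)
output = model(x)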
Dimension ➯ [Sequence, Batch, Memory Size]
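The layout above is the default one (batch_first=False). As a rough illustration of how batch_first changes the expected input layout, here is a small standard-library sketch. The helper name expected_input_shape is hypothetical and not part of gtrxl-torch; it only restates the (S, N, E) versus (N, S, E) convention.

```python
def expected_input_shape(seq_len, batch_size, d_model, batch_first=False):
    """Hypothetical helper: shape of the input tensor for a given layout."""
    if batch_first:
        return (batch_size, seq_len, d_model)  # (N, S, E)
    return (seq_len, batch_size, d_model)      # (S, N, E), the default layout


print(expected_input_shape(32, 16, 512))                    # (32, 16, 512)
print(expected_input_shape(32, 16, 512, batch_first=True))  # (16, 32, 512)
```

The first call matches the tensor used in the usage example, torch.randn(32, 16, 512).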
model.save()
model.load()
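save() and load() take no arguments here, so the file location presumably comes from the chkpt_dir and network_name constructor parameters (defaults models and network.pt). The sketch below shows that assumption using only the standard library. The helper checkpoint_path is hypothetical, and joining the two values with os.path.join is a guess about the library's behaviour, not something this README confirms.

```python
import os


def checkpoint_path(chkpt_dir="models", network_name="network.pt"):
    """Hypothetical helper: assumed checkpoint location (directory + file name)."""
    return os.path.join(chkpt_dir, network_name)


# Default location under the documented defaults.
print(checkpoint_path())
# A custom directory and file name.
print(checkpoint_path("runs/exp1", "gtrxl.pt"))
```

If this assumption holds, giving each experiment its own chkpt_dir or network_name keeps checkpoints from overwriting each other.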
d_model: int. The number of expected features in the encoder/decoder inputs.
nheads: int. The number of heads in the multi-head attention models.
transformer_layers: int. Number of Transformer blocks.
hidden_dims: int. Number of hidden neurons for the position-wise MLP.
n_layers: int. Number of RNN (GRU) layers.
layer_norm_eps: float, default 1e-5. The eps value in layer normalization components.
batch_first: bool, default False. If True, inputs are shaped (N, S, E), i.e. batch first.
chkpt_dir: str, default models. Directory name where the model is saved.
activation: str, default relu. Activation function for the MLP.
network_name: str, default network.pt. Name of the model (file) you're saving.
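To keep these settings in one place, the parameters can be collected in a small config object before constructing the model. This is only a sketch: GTrXLConfig and to_kwargs are hypothetical names, not part of gtrxl-torch. hidden_dims and n_layers have no documented default, so they are left as None and omitted, letting the library use its own values. The divisibility check assumes the heads split d_model evenly, as torch.nn.MultiheadAttention requires.

```python
from dataclasses import dataclass, asdict
from typing import Optional


@dataclass
class GTrXLConfig:
    """Hypothetical container for the constructor arguments listed above."""
    d_model: int
    nheads: int
    transformer_layers: int
    hidden_dims: Optional[int] = None  # no documented default; None = library default
    n_layers: Optional[int] = None     # no documented default; None = library default
    layer_norm_eps: float = 1e-5
    batch_first: bool = False
    chkpt_dir: str = "models"
    activation: str = "relu"
    network_name: str = "network.pt"

    def to_kwargs(self):
        """Return keyword arguments, dropping anything left unset (None)."""
        # Assumption: d_model must split evenly across attention heads.
        if self.d_model % self.nheads != 0:
            raise ValueError("d_model must be divisible by nheads")
        return {k: v for k, v in asdict(self).items() if v is not None}


cfg = GTrXLConfig(d_model=512, nheads=4, transformer_layers=1)
kwargs = cfg.to_kwargs()
print(kwargs)
# With gtrxl-torch installed, the model would then be built as:
# model = GTrXL(**kwargs)
```

Validating the values up front surfaces a mismatched d_model / nheads pair before any weights are allocated.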
- Alterations to the transformer model (GTrXL) ➱ Click Here