
Question: problem with xval implementation

I tried xval on a simple mock example, deliberately overfitting to see what the model generates, but it produces weird results:

import torch
import json
from x_transformers import Decoder, XValTransformerWrapper, XValAutoregressiveWrapper

model = XValTransformerWrapper(
    num_tokens = 4,
    numerical_token_id = 3,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 12, heads = 8)
)
model = XValAutoregressiveWrapper(model)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

def train(model, optimizer, epochs=10):
    # Constant mock data
    ids = torch.tensor([[1, 2, 3, 0, 0, 3, 3, 3, 2, 2, 0, 1, 0, 1, 2, 1]])
    nums = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1.4426, 1]])
    mask = (nums != 1)
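    for epoch in range(epochs):
        optimizer.zero_grad()
        loss = model(ids, nums, mask=mask)
        loss.backward()   # backpropagate so the weights actually update
        optimizer.step()
        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item()}")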

# Train the model for more epochs
train(model, optimizer, epochs=20)

Epoch 20/20, Loss: 0.0900447741150856

# then generate
start_ids = torch.randint(0, 4, (1, 1))
start_nums = torch.randn(1, 1)

ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 50)

# (1, 50), (1, 50), (1, 50)
ids_out, num_out, is_number_mask


(tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0]]),
 tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan]]),
 tensor([[False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False]]))

Can you point out if I am doing anything wrong here?

@HarshaSatyavardhan hey Harsha, thanks for your interest. i somehow had the key padding mask included in the readme when it should not be there. the numerical mask is auto-handled based on the numerical_token_id

could you try running the script below? it should work

the nans you see are not an error; they are just there to explicitly remind the researcher which values are not a number

import torch
from x_transformers import (
    Decoder,
    XValTransformerWrapper,
    XValAutoregressiveWrapper
)
from einops import repeat

model = XValTransformerWrapper(
    num_tokens = 4,
    numerical_token_id = 3,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 12, heads = 8)
)

model = XValAutoregressiveWrapper(model).cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)

def train(model, optimizer, epochs=10):
    # Constant mock data
    ids = torch.tensor([[1, 2, 3, 0, 3, 1]]).cuda()
    nums = torch.tensor([[0., 0., 3.14, 0., 2.72, 0.]]).cuda()

    batched_ids = repeat(ids, '1 n -> b n', b = 32)
    batched_nums = repeat(nums, '1 n -> b n', b = 32)

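    for epoch in range(epochs):
        optimizer.zero_grad()
        loss = model(batched_ids, batched_nums)
        loss.backward()   # backpropagate so the weights actually update
        optimizer.step()
        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item()}")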

train(model, optimizer, epochs=50)
start_ids = torch.ones((1, 1), dtype = torch.long).cuda()  # token ids must be integers
start_nums = torch.zeros(1, 1).cuda()

ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 5)

print(ids_out, num_out, is_number_mask)
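On the nans mentioned above: a minimal sketch of how the three returned tensors might be read together. The tensors here are hand-written stand-ins with the same layout as the `generate` output, not real model output, and the values are assumptions chosen for illustration.

```python
import torch

nan = float('nan')

# Hypothetical stand-ins for (ids_out, num_out, is_number_mask);
# token id 3 plays the role of numerical_token_id, as in the thread.
ids_out = torch.tensor([[2, 3, 0, 3, 1]])
num_out = torch.tensor([[nan, 3.1, nan, 2.7, nan]])
is_number_mask = torch.tensor([[False, True, False, True, False]])

# Keep only the values that sit at numeric positions.
numbers = num_out[is_number_mask]

# nan marks exactly the positions that are not numbers.
nan_matches_mask = torch.equal(torch.isnan(num_out), ~is_number_mask)

print(numbers, nan_matches_mask)
```

In the output posted in the question, every generated id is 0 and the mask is all False, so an all-nan `num_out` is consistent with no numerical token having been generated.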

@HarshaSatyavardhan there was an issue with the numerical loss 🤦‍♂️ should converge better now, even if it sort of worked before

@lucidrains which is correct for nums: using 1 or 0 where there is no number?
nums = torch.tensor([[0., 0., 3.14, 0., 2.72, 0.]]).cuda()
or this
nums = torch.tensor([[1., 1., 3.14, 1., 2.72, 1.]]).cuda()

I think using 1 leads to better results. Is this correct?
According to the paper, these values are multiplied with the embeddings, so wouldn't multiplying by zero cause a problem?

@HarshaSatyavardhan yes, you are correct! wow you understand the paper well

i protect against that here, so you can actually put any value there (even nans, to be explicit on what is not a number)
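To illustrate the point about multiplying by zero, here is a minimal sketch in plain PyTorch. It is not the library's actual code: the small embedding table, the variable names, and the use of `torch.where` as the guard are assumptions made for illustration. The idea is that the scale is taken from `nums` only where the token equals the numerical token id and is fixed to 1 everywhere else, so placeholder values (0, 1, or nan) never reach the embeddings.

```python
import torch

torch.manual_seed(0)

NUMERICAL_TOKEN_ID = 3  # same role as numerical_token_id in the thread
nan = float('nan')

ids = torch.tensor([[1, 2, 3, 0, 3, 1]])
token_emb = torch.nn.Embedding(4, 8)       # hypothetical small embedding table
emb = token_emb(ids).detach()              # shape (1, 6, 8)

# The number mask is derived from the ids alone.
is_number = ids == NUMERICAL_TOKEN_ID

# Naive scaling with 0 as the placeholder wipes out non-number embeddings.
zero_nums = torch.tensor([[0., 0., 3.14, 0., 2.72, 0.]])
naive = emb * zero_nums.unsqueeze(-1)

# Guarded scaling: placeholders are replaced by 1, so any value is safe, even nan.
nums = torch.tensor([[nan, 0., 3.14, 1., 2.72, nan]])
scale = torch.where(is_number, nums, torch.ones_like(nums))
guarded = emb * scale.unsqueeze(-1)

print(naive[~is_number].abs().sum(), torch.isnan(guarded).any())
```

With the guard in place, non-number positions keep their original embedding and only the numerical-token positions are scaled by their value.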

@HarshaSatyavardhan do let me know if/once you train anything significant with xval

if it works well for you, i may build out a special tokenizer to do this type of numerical encoding