Question: problem with xval implementation
I have tried xval with a simple mock example, trying to overfit and see what the model generates, but it generates weird results.
import torch
from x_transformers import Decoder, XValTransformerWrapper, XValAutoregressiveWrapper
model = XValTransformerWrapper(
    num_tokens = 4,
    numerical_token_id = 3,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 12, heads = 8)
)
model = XValAutoregressiveWrapper(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
def train(model, optimizer, epochs=10):
    # Constant mock data
    ids = torch.tensor([[1, 2, 3, 0, 0, 3, 3, 3, 2, 2, 0, 1, 0, 1, 2, 1]])
    nums = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1.4426, 1]])
    mask = (nums != 1)

    for epoch in range(epochs):
        optimizer.zero_grad()
        loss = model(ids, nums, mask=mask)
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item()}")
# Train the model for more epochs
train(model, optimizer, epochs=20)
Epoch 20/20, Loss: 0.0900447741150856
# then generate
start_ids = torch.randint(0, 4, (1, 1))
start_nums = torch.randn(1, 1)
ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 50)
# each of shape (1, 50)
ids_out, num_out, is_number_mask
Results:
(tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0]]),
tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
nan, nan]]),
tensor([[False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False]]))
Can you point out if I am doing anything wrong here?
@HarshaSatyavardhan hey Harsha, thanks for your interest. i somehow had the key padding mask included in the readme when it should not be there. the numerical mask is auto-handled based on the numerical_token_id
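for intuition, the wrapper derives the mask from the ids themselves, roughly like this (a simplified sketch, not the exact internals):

import torch

ids = torch.tensor([[1, 2, 3, 0, 3, 1]])
numerical_token_id = 3

# the mask of numerical positions falls out of the ids alone,
# so there is no need to pass a separate mask in yourself
is_number = ids == numerical_token_id
# tensor([[False, False,  True, False,  True, False]])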
could you try running the script below? it should work
the nans you see are not an error, just there to explicitly remind the researcher which values are not a number (there's a small readout sketch after the script showing how to use the mask)
import torch
from x_transformers import (
    Decoder,
    XValTransformerWrapper,
    XValAutoregressiveWrapper
)
from einops import repeat
model = XValTransformerWrapper(
    num_tokens = 4,
    numerical_token_id = 3,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 12, heads = 8)
)
model = XValAutoregressiveWrapper(model).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
def train(model, optimizer, epochs=10):
    # Constant mock data
    ids = torch.tensor([[1, 2, 3, 0, 3, 1]]).cuda()
    nums = torch.tensor([[0., 0., 3.14, 0., 2.72, 0.]]).cuda()

    batched_ids = repeat(ids, '1 n -> b n', b = 32)
    batched_nums = repeat(nums, '1 n -> b n', b = 32)

    for epoch in range(epochs):
        optimizer.zero_grad()
        loss = model(batched_ids, batched_nums)
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item()}")
train(model, optimizer, epochs=50)
start_ids = torch.ones((1, 1), dtype = torch.long).cuda() # token ids need to be integers for the embedding lookup
start_nums = torch.zeros(1, 1).cuda()
ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 5)
print(ids_out, num_out, is_number_mask)
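to read the generation back out, something like this (just usage code on my end, nothing extra needed from the library): the nans in num_out sit exactly where is_number_mask is False, so index with the mask rather than using num_out directly

# pull out just the generated numbers as a flat tensor
numbers = num_out[is_number_mask]

# or keep the dense shape but replace the nan placeholders with zeros
dense = num_out.masked_fill(~is_number_mask, 0.)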
@HarshaSatyavardhan there was an issue with the numerical loss 🤦‍♂️ should converge better now, even if it sort of worked before
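for reference, an xval-style loss pairs cross entropy on the token stream with a regression term at numerical positions only, something like this simplified sketch (my reconstruction from the paper, not necessarily the exact fix that went in):

import torch
import torch.nn.functional as F

def xval_loss(logits, num_pred, target_ids, target_nums, numerical_token_id):
    # standard cross entropy over the tokens; logits: (b, n, c)
    ce = F.cross_entropy(logits.transpose(1, 2), target_ids)

    # regression loss only at positions that actually hold a number,
    # so placeholder values never contribute (assumes at least one number)
    is_number = target_ids == numerical_token_id
    mse = F.mse_loss(num_pred[is_number], target_nums[is_number])
    return ce + mse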
@lucidrains which is correct for nums where there is no number: 1 or 0?

nums = torch.tensor([[0., 0., 3.14, 0., 2.72, 0.]]).cuda()

or this:

nums = torch.tensor([[1., 1., 3.14, 1., 2.72, 1.]]).cuda()

I think using 1 leads to better results; is this correct or wrong? According to the paper, we multiply these values with the embeddings, so don't you think multiplying by zero causes problems?
@HarshaSatyavardhan yes, you are correct! wow you understand the paper well
i protect against that here, so you can actually put any value there (even nans, to be explicit on what is not a number)
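roughly, the protection looks like this (a simplified sketch of the idea, not the exact code):

import torch

def embed_with_values(token_emb, ids, nums, numerical_token_id):
    # scale each numerical token's embedding by its value, per the paper
    is_number = ids == numerical_token_id

    # force the value to 1. at non-number positions before multiplying,
    # so a 0. (or nan) there can never zero out or poison the embedding
    scale = nums.masked_fill(~is_number, 1.)
    return token_emb(ids) * scale.unsqueeze(-1)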
@HarshaSatyavardhan do let me know if/once you train anything significant with xval
if it works well for you, i may build out a special tokenizer to do this type of numerical encoding
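something along these lines, just a rough sketch to make the idea concrete (completely hypothetical, all names are mine):

import re
import torch

NUM_RE = re.compile(r'-?\d+\.?\d*')

def encode_with_numbers(text, vocab, numerical_token_id):
    # swap every literal number for the numerical placeholder token,
    # collecting its value at the same position of a parallel float tensor
    ids, nums = [], []
    for piece in re.split(f'({NUM_RE.pattern})', text):
        if not piece:
            continue
        if NUM_RE.fullmatch(piece):
            ids.append(numerical_token_id)
            nums.append(float(piece))
        else:
            for word in piece.split():
                ids.append(vocab[word])
                nums.append(1.)  # inert value at non-number positions
    return torch.tensor([ids]), torch.tensor([nums])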