Is there any training example about tabtransformer?
pancodex opened this issue · 1 comment
pancodex commented
Hi,
I want to use it on a tabular dataset for a supervised learning task, but I don't really know how to train this model with a dataset (it seems there is no such content in the README file). Could you please help me? Thank you.
Alexx776 commented
Hello, here is a simple example of training the model. I hope it helps:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from tab_transformer_pytorch import TabTransformer
from sklearn.metrics import accuracy_score, classification_report
# Generate fake data: 5 categorical columns (10 categories each),
# 10 continuous columns, and a 3-class label derived from simple rules
def generate_data(num_samples):
    x_categ = torch.randint(0, 10, (num_samples, 5))
    x_cont = torch.randn(num_samples, 10)
    y = torch.zeros(num_samples)
    for i in range(num_samples):
        if x_categ[i, 0] > 5 and x_cont[i, 0] > 0:
            y[i] = 2
        elif x_categ[i, 1] < 3 or x_cont[i, 1] < -1:
            y[i] = 1
        else:
            y[i] = 0
    return x_categ, x_cont, y.long()
num_samples = 10000
x_categ, x_cont, y = generate_data(num_samples)
# Per-column statistics for the continuous features. TabTransformer
# normalizes the continuous inputs itself when continuous_mean_std is
# passed, so we give it the raw statistics instead of normalizing twice.
cont_mean = x_cont.mean(dim=0)
cont_std = x_cont.std(dim=0)
cont_mean_std = torch.stack([cont_mean, cont_std], dim=1)  # shape (10, 2)
# Model
model = TabTransformer(
    categories = (10, 10, 10, 10, 10),   # number of unique values in each categorical column
    num_continuous = 10,                 # number of continuous columns
    dim = 64,                            # embedding dimension
    dim_out = 3,                         # one logit per class
    depth = 6,                           # number of transformer layers
    heads = 8,                           # attention heads per layer
    attn_dropout = 0.1,
    ff_dropout = 0.1,
    mlp_hidden_mults = (4, 2),           # hidden sizes of the final MLP, as multiples of dim
    mlp_act = nn.ReLU(),
    continuous_mean_std = cont_mean_std  # (num_continuous, 2) per-column mean and std
)
dataset = TensorDataset(x_categ, x_cont, y)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 20
# Train
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    all_preds = []
    all_labels = []
    for batch_categ, batch_cont, batch_y in train_loader:
        optimizer.zero_grad()
        outputs = model(batch_categ, batch_cont)  # logits, shape (batch, 3)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        predicted = outputs.argmax(dim=1)
        all_preds.extend(predicted.numpy())
        all_labels.extend(batch_y.numpy())
    avg_loss = total_loss / len(train_loader)
    accuracy = accuracy_score(all_labels, all_preds)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")