lucidrains/tab-transformer-pytorch

Is there any training example for TabTransformer?

pancodex opened this issue · 1 comment

Hi,
I want to use it on a tabular dataset for a supervised learning task, but I don't really know how to train this model with a dataset (it seems there is no such content in the README file). Could you please help me? Thank you.

Hello, here is a simple example of training the model. I hope it helps.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from tab_transformer_pytorch import TabTransformer
from sklearn.metrics import accuracy_score, classification_report

# Generate synthetic data: 5 categorical features (10 categories each),
# 10 continuous features, and a 3-class label derived from simple rules
def generate_data(num_samples):
    x_categ = torch.randint(0, 10, (num_samples, 5))  # categorical features
    x_cont = torch.randn(num_samples, 10)             # continuous features

    # Assign one of 3 classes based on simple rules over the features
    y = torch.zeros(num_samples)
    for i in range(num_samples):
        if x_categ[i, 0] > 5 and x_cont[i, 0] > 0:
            y[i] = 2
        elif x_categ[i, 1] < 3 or x_cont[i, 1] < -1:
            y[i] = 1
        else:
            y[i] = 0

    return x_categ, x_cont, y.long()

num_samples = 10000
x_categ, x_cont, y = generate_data(num_samples)

# Per-feature mean/std of the continuous columns, shape (10, 2).
# The model normalizes the continuous inputs internally when
# continuous_mean_std is passed, so we do not normalize x_cont manually here
cont_mean_std = torch.stack([x_cont.mean(dim=0), x_cont.std(dim=0)], dim=1)

# Model
model = TabTransformer(
    categories = (10, 10, 10, 10, 10),   # number of unique values per categorical feature
    num_continuous = 10,                 # number of continuous features
    dim = 64,                            # embedding dimension
    dim_out = 3,                         # number of output classes
    depth = 6,                           # number of transformer layers
    heads = 8,                           # number of attention heads
    attn_dropout = 0.1,                  # dropout after attention
    ff_dropout = 0.1,                    # dropout in the feedforward layers
    mlp_hidden_mults = (4, 2),           # hidden-dim multipliers of the final MLP
    mlp_act = nn.ReLU(),                 # activation of the final MLP
    continuous_mean_std = cont_mean_std  # (optional) per-feature mean/std for normalization
)
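
# Optional sanity check: one small forward pass to confirm the output
# shape is (batch, dim_out) before committing to a full training run
with torch.no_grad():
    assert model(x_categ[:4], x_cont[:4]).shape == (4, 3)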

dataset = TensorDataset(x_categ, x_cont, y)
train_size = int(0.8 * len(dataset))   # 80/20 train/test split
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

criterion = nn.CrossEntropyLoss()  # expects raw logits and integer class labels
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

num_epochs = 20

# Train
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    all_preds = []    # predictions accumulated over the epoch
    all_labels = []   # ground-truth labels accumulated over the epoch

    for batch_categ, batch_cont, batch_y in train_loader:
        optimizer.zero_grad()                     # reset gradients

        outputs = model(batch_categ, batch_cont)  # forward pass -> (batch, 3) logits
        loss = criterion(outputs, batch_y)

        loss.backward()                           # backpropagate
        optimizer.step()                          # update parameters

        total_loss += loss.item()

        predicted = outputs.argmax(dim=1)         # predicted class per sample
        all_preds.extend(predicted.tolist())
        all_labels.extend(batch_y.tolist())
    
    avg_loss = total_loss / len(train_loader)
    
    accuracy = accuracy_score(all_labels, all_preds)
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")