AdityaNG/kan-gpt

CUDA out of memory

Opened this issue · 1 comment

import torch
import torch.nn as nn

# KAN and NewGELU are assumed to be importable from the surrounding project,
# e.g. the KAN layer from pykan and NewGELU from the repo's GPT code.


class KanMLP(nn.Module):
    """MLP block built from KAN layers."""

    def __init__(self,
                 in_features=1152,
                 hidden_features=None,
                 out_features=None,
                 drop=0.):
        super().__init__()

        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.mlp = nn.ModuleDict(
            dict(
                c_fc=KAN(width=[in_features, hidden_features]),
                c_proj=KAN(width=[hidden_features, out_features]),
                act=NewGELU(),
                dropout=nn.Dropout(drop),
            )
        )
        m = self.mlp
        self.mlpf = lambda x: m.dropout(
            m.c_proj(m.act(m.c_fc(x)))
        )  # MLP forward

    def forward(self, x):
        x = self.mlpf(x)
        return x


net = KanMLP(1152, 1152 * 4).to("cuda")
x = torch.rand(size=(4, 4096 * 4, 1152)).to("cuda")
net(x)

When the number of tokens reaches a certain size, I get the following error:

CUDA out of memory.

Hello! Can you answer these questions?

  • What GPU are you using?
  • How much VRAM does the GPU have?
  • Which model are you running (if applicable)?

I dropped your code into Claude, and hopefully this gives you some indication:

The main reason you're running out of CUDA memory is the large size of your input tensor. Let's break down the memory usage:

Input tensor x:

  • Shape: (4, 4096*4, 1152) = (4, 16384, 1152)
  • Elements: 4 * 16384 * 1152 = 75,497,472
  • Assuming float32, this tensor alone requires about 302 MB of memory.

Network parameters:

  • Input size: 1152
  • Hidden size: 1152 * 4 = 4608
  • This results in two large KAN layers; each KAN edge carries spline coefficients, so these layers hold considerably more parameters (and activation memory) than equivalent nn.Linear layers.

Intermediate activations:

  • The forward pass creates several large intermediate tensors, further increasing memory usage.
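
To sanity-check these numbers on your own setup, here is a minimal sketch that reuses the net and x from the reproducer above (nothing else is assumed):

input_bytes = x.element_size() * x.nelement()   # 4 bytes * 75,497,472 elements ≈ 302 MB
param_bytes = sum(p.element_size() * p.nelement() for p in net.parameters())
print(f"input tensor: {input_bytes / 1e6:.0f} MB")
print(f"parameters:   {param_bytes / 1e6:.0f} MB")
print(torch.cuda.memory_summary())              # PyTorch allocator's own breakdown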

To address this issue, you can try the following approaches:

Reduce batch size:
Instead of processing 4 samples at once, try reducing it to 1 or 2:

x = torch.rand(size=(1, 4096*4, 1152)).to("cuda")

Use gradient accumulation:
If you need to process larger batches for training stability, you can use gradient accumulation. This involves processing smaller sub-batches and accumulating gradients before performing an optimization step.
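
As a rough sketch of such a loop (the names optimizer, loss_fn, batches, and targets are hypothetical placeholders for your own training setup, and the sub-batch count of 4 is arbitrary):

accum_steps = 4                       # number of sub-batches per optimizer step
optimizer.zero_grad()
for i, (inputs, targets) in enumerate(batches):
    out = net(inputs)
    loss = loss_fn(out, targets) / accum_steps   # scale so the accumulated gradient matches a full batch
    loss.backward()                               # gradients accumulate in .grad across sub-batches
    if (i + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
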
Use mixed precision training:
Utilize float16 (half-precision) computations to reduce memory usage. You can use NVIDIA's Apex library or PyTorch's native AMP (Automatic Mixed Precision):
from torch.cuda.amp import autocast

with autocast():
    output = net(x)
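
For training (rather than a single forward pass), autocast is usually paired with a gradient scaler so float16 gradients don't underflow. A minimal sketch, again assuming a hypothetical optimizer, loss_fn, and targets:

from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()
optimizer.zero_grad()
with autocast():
    out = net(x)
    loss = loss_fn(out, targets)
scaler.scale(loss).backward()   # scale the loss so fp16 gradients stay representable
scaler.step(optimizer)          # unscales gradients, then runs optimizer.step()
scaler.update()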

Optimize your model architecture:
Consider whether you can reduce the size of your hidden layers or use more memory-efficient architectures.

Use gradient checkpointing:
This technique trades computation for memory by recomputing intermediate activations during the backward pass instead of storing them all.
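
In PyTorch this is available via torch.utils.checkpoint. A sketch applied to the KanMLP above (memory is only saved when a backward pass follows, so the input must require grad):

from torch.utils.checkpoint import checkpoint

def checkpointed_forward(net, x):
    # Activations inside net.mlpf are recomputed during backward instead of being stored.
    return checkpoint(net.mlpf, x, use_reentrant=False)

x = torch.rand(4, 4096 * 4, 1152, device="cuda", requires_grad=True)
out = checkpointed_forward(net, x)
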
If possible, process your data in smaller chunks:
Instead of processing the entire 4096*4 sequence length at once, you might be able to process it in smaller segments.
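
Since the KanMLP above is applied token-wise (there is no mixing across sequence positions), splitting along the sequence dimension gives the same result. A sketch for inference, reusing net and x from the reproducer (the chunk size of 2048 is arbitrary):

chunks = x.split(2048, dim=1)             # split the 16384-token sequence into 2048-token chunks
with torch.no_grad():                     # inference only; no activations kept for backward
    out = torch.cat([net(c) for c in chunks], dim=1)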