astramind-ai/BitMat

Corrupted data?


I'm getting corrupted data (I think) when testing the library. This only happens when I print the results of a forward() pass in a loop, but I think it is causing problems elsewhere as well (training collapses at some point).

Here's a minimal reproduction of the problem:

import torch

from transformers import LlamaConfig, LlamaForCausalLM
from bitmat import convert_hf_model

torch.set_default_device('cuda')

seed = 1713219988 

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

student_config = LlamaConfig(
  hidden_size=2048,
  intermediate_size=8192,
  num_hidden_layers=24,
  num_attention_heads=32,
  vocab_size=51200,
)

model = LlamaForCausalLM(student_config)

model = convert_hf_model(model)

with torch.no_grad():
  model.eval()

  for i in range(5):
    # torch.cuda.synchronize("cuda") # should fix if it's a sync issue, but does not

    x = model.forward(torch.tensor([[1,2,3]]).to('cuda')).logits # .to("cpu")
    print(i, x) # shows inconsistent results
    # print(i, x[:,:,0]) # no problem

When this runs, it prints different values for the output tensor in each pass of the loop, which should not happen. When I print only a small slice, the values are consistent. Also, if I move the tensor to the CPU, the problem does not appear.

This looks like a synchronization error, but calling torch.cuda.synchronize() does not help. Am I missing something?
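Eyeballing printed tensors is unreliable, because `print` truncates large tensors to a few corner values. A more direct check compares every element of each run against the first run. The helper below is a hypothetical sketch: the name `max_run_to_run_diff` is not part of BitMat or PyTorch, and it accepts any callable that returns a flat sequence of floats.

```python
from typing import Callable, Sequence


def max_run_to_run_diff(fn: Callable[[], Sequence[float]], runs: int = 5) -> float:
    """Call fn() `runs` times and return the largest element-wise deviation
    of any later run from the first one (0.0 means fully repeatable)."""
    reference = list(fn())
    worst = 0.0
    for _ in range(runs - 1):
        out = list(fn())
        if len(out) != len(reference):
            raise ValueError("output length changed between runs")
        # Compare every element, not just a printed slice.
        diff = max((abs(a - b) for a, b in zip(out, reference)), default=0.0)
        worst = max(worst, diff)
    return worst


if __name__ == "__main__":
    # A deterministic function: every run matches the first.
    print(max_run_to_run_diff(lambda: [1.0, 2.0, 3.0]))  # prints 0.0

    # A function whose output drifts between calls, standing in for a
    # forward pass that reads uninitialized memory.
    calls = {"n": 0}

    def drifting() -> list:
        calls["n"] += 1
        return [0.0, float(calls["n"])]

    print(max_run_to_run_diff(drifting))  # prints 4.0
```

With the model above, you could pass something like `lambda: model(ids).logits.flatten().tolist()`, where `ids` is the input tensor. A non-zero result means the forward pass is not repeatable, even when a printed slice looks stable.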


I got the same results on a 4080 (local) and an A100 (on datacrunch).

CUDA version 12.3.2
BitMat version 0.3.3, installed from GitHub

I've tested the kernel: it does fail on your example, but this is a matrix-size issue.
Triton often encounters difficulties when handling small matrices because coordinating thousands of threads for such tasks can be challenging. We are currently in the process of rewriting the backend in CUDA to enhance its reliability and performance significantly.
You can try this code to test it yourself:

import torch

from transformers import LlamaConfig, LlamaForCausalLM
from bitmat import convert_hf_model

torch.set_default_device('cuda')

seed = 1713219988

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

student_config = LlamaConfig(
  hidden_size=2048,
  intermediate_size=8192,
  num_hidden_layers=24,
  num_attention_heads=32,
  vocab_size=51200,
)

model = LlamaForCausalLM(student_config)

model = convert_hf_model(model)
with torch.no_grad():
  model.eval()
  rand_int = torch.randint(0, 51200, (1, 1024))
  x_old = None

  for i in range(10):
    x = model.forward(rand_int.to('cuda')).logits # .to("cpu")
    if x_old is not None: # skip the first iteration, then compare each output with the previous one
      assert torch.allclose(x, x_old)
      print(str(i) + " | passed")
    x_old = x

Thanks for looking at it. I will try it out.

I ran into some more stability issues while testing and looked into it a bit more. The problem is not threads. The problem is that the RMS norm layer is initialized with the wrong size: the output size instead of the input size.

This works if you test it with a square matrix, but if the input size is larger than the output size, it returns invalid results. Specifically, your _rms_layernorm_forward method reads memory beyond the end of the weight array. That data might be uninitialized, or it might be the contents of some other tensor.

To see what I mean, consider a layer with input size 8192 and output size 2048. _rms_layernorm_forward uses the same mask for the (strided) input row and the weights, so the mask admits column offsets up to 8191 while the weight only has 2048 elements. For every offset of 2048 or more, tl.load reads undefined data:

X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
W_row = tl.load(W + col_offsets, mask = mask, other = 0)#.to(tl.float32)
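The out-of-bounds read can be reproduced without a GPU. The sketch below is a pure-Python model, not Triton code. `masked_load` is a hypothetical stand-in for `tl.load(ptr + offsets, mask=mask, other=0)`. Memory is modelled as one flat list in which the weight is immediately followed by another tensor's data; on a real device the bytes past the weight could just as well be uninitialized. The mask is assumed to be built from the input width only (`col_offsets < n_cols`), and sizes are scaled down from 8192/2048 to 8/2.

```python
from typing import List

NEIGHBOUR = 123.0  # marker for "some other tensor's data" in the flat buffer


def masked_load(memory: List[float], base: int, offsets: List[int],
                mask: List[bool], other: float = 0.0) -> List[float]:
    """Model of tl.load(base + offsets, mask=mask, other=other) on a flat buffer.
    Masked-off lanes get `other`; unmasked lanes read whatever is at that address."""
    return [memory[base + o] if m else other for o, m in zip(offsets, mask)]


def load_weight_row(weight_len: int, n_cols: int, block_size: int) -> List[float]:
    # Flat "device memory": the norm weight (all ones) followed by a neighbouring tensor.
    memory = [1.0] * weight_len + [NEIGHBOUR] * block_size
    col_offsets = list(range(block_size))     # like tl.arange(0, BLOCK_SIZE)
    mask = [c < n_cols for c in col_offsets]  # mask derived from the input width only
    return masked_load(memory, 0, col_offsets, mask)


in_features, out_features = 8, 2  # stand-ins for 8192 and 2048

# Weight sized by out_features: lanes 2..7 read the neighbouring tensor.
print(load_weight_row(out_features, in_features, block_size=8))
# [1.0, 1.0, 123.0, 123.0, 123.0, 123.0, 123.0, 123.0]

# Weight sized by in_features: every unmasked lane stays inside the weight.
print(load_weight_row(in_features, in_features, block_size=8))
# [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
```

No error is raised in the first case. The kernel simply multiplies by whatever happens to sit after the weight, which matches the symptom of outputs that change from run to run.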

It also makes sense intuitively: you're normalizing the inputs, so the weights should have the input size. The fix is just

self.norm = RMSLayerNorm(in_features, eps) 

in bitlinear.py.
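For reference, here is a framework-free sketch of the corrected sizing. `RMSNormSketch` and `BitLinearSketch` are hypothetical names, not BitMat's classes. Weight quantization and the Triton kernel are left out entirely. The sketch only illustrates that the norm runs on the layer's input, so its weight needs `in_features` entries, and that a size mismatch should fail loudly rather than read past the end of the weight.

```python
import math
from typing import List


class RMSNormSketch:
    """RMS norm over one feature vector; the weight has one entry per feature."""

    def __init__(self, dim: int, eps: float = 1e-6):
        self.weight = [1.0] * dim
        self.eps = eps

    def __call__(self, x: List[float]) -> List[float]:
        # Fail loudly instead of silently reading past the end of the weight.
        if len(x) != len(self.weight):
            raise ValueError(f"expected {len(self.weight)} features, got {len(x)}")
        rms = math.sqrt(sum(v * v for v in x) / len(x) + self.eps)
        return [v / rms * w for v, w in zip(x, self.weight)]


class BitLinearSketch:
    """Hypothetical stand-in for a BitLinear layer: normalize the input, then project."""

    def __init__(self, in_features: int, out_features: int, eps: float = 1e-6):
        # The fix from this thread: size the norm by in_features, not out_features.
        self.norm = RMSNormSketch(in_features, eps)
        # Placeholder full-precision weights (out_features x in_features); no quantization.
        self.weight = [[1.0 / in_features] * in_features for _ in range(out_features)]

    def __call__(self, x: List[float]) -> List[float]:
        h = self.norm(x)
        return [sum(w * v for w, v in zip(row, h)) for row in self.weight]


if __name__ == "__main__":
    layer = BitLinearSketch(in_features=8, out_features=2)
    y = layer([float(i) for i in range(8)])
    print(len(layer.norm.weight), len(y))  # prints 8 2

    try:
        # Old sizing: norm weight built from out_features, fed an 8-feature input.
        RMSNormSketch(2)([1.0] * 8)
    except ValueError as err:
        print("size mismatch:", err)
```

With the old sizing, this sketch raises an error instead of returning plausible-looking but wrong numbers, which is the failure mode described above.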

Seems pretty stable now, so I'm going to try training again.

BTW, my last comment sounds pretty critical. I don't mean it to be: I really want this library to work out, and I'm just trying to be helpful.

@duncanwerner Thank you! I will take a look at it right away. In any case, if you want, you can open a pull request.

@duncanwerner don't worry we appreciate all the suggestions!
Thank you for pointing it out!