redotvideo/mamba-chat

Error during training

Closed this issue · 6 comments

I can get the model to run inference just fine, but in my Colab environment this is what I run into:

```
2023-12-21 14:14:22.093031: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-21 14:14:22.093134: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-21 14:14:22.162482: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-21 14:14:24.060295: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Tokenizing dataset...
100% 1000/1000 [00:02<00:00, 383.23it/s]
0% 0/750 [00:00<?, ?it/s]Traceback (most recent call last):
File "/content/mamba-chat/train_mamba.py", line 60, in
run(args)
File "/content/mamba-chat/train_mamba.py", line 45, in run
trainer.train()
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1555, in train
return inner_training_loop(
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1860, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2725, in training_step
loss = self.compute_loss(model, inputs)
File "/content/mamba-chat/trainer/mamba_trainer.py", line 9, in compute_loss
lm_logits = model(input_ids).logits
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 221, in forward
hidden_states = self.backbone(input_ids, inference_params=inference_params)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 152, in forward
hidden_states, residual = layer(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/modules/mamba_simple.py", line 341, in forward
hidden_states, residual = fused_add_norm_fn(
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 478, in rms_norm_fn
return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 539, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 411, in forward
y, mean, rstd, residual_out = _layer_norm_fwd(
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 155, in _layer_norm_fwd
_layer_norm_fwd_1pass_kernel[(M,)](
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 100, in run
timings = {config: self._bench(*args, config=config, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 100, in
timings = {config: self._bench(*args, config=config, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 83, in _bench
return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
File "/usr/local/lib/python3.10/dist-packages/triton/testing.py", line 104, in do_bench
fn()
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 81, in kernel_call
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
File "", line 63, in _layer_norm_fwd_1pass_kernel
File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 476, in compile
next_module = compile_kernel(module)
File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 351, in
lambda src: ptx_to_cubin(src, arch))
File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 150, in ptx_to_cubin
return compile_ptx_to_cubin(ptx, ptxas, arch)
RuntimeError: Internal Triton PTX codegen error:
ptxas /tmp/compile-ptx-src-eed3b0, line 984; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 984; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 986; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 986; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 988; error : Feature '.bf16' requires .target sm_80 or higher
....
ptxas /tmp/compile-ptx-src-eed3b0, line 2801; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2803; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2803; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2805; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2805; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2807; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-eed3b0, line 2807; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas fatal : Ptx assembly aborted due to errors

0% 0/750 [00:02<?, ?it/s]
```

I don't know if it's too late, but changing the dtype from bfloat16 to float16 here does seem to fix the problem on Colab. bfloat16 isn't supported on the T4, I think (the ptxas errors say bf16 needs sm_80 or higher, and the T4 is sm_75).

Edit: setting the dtype to float16 broke the loss, but after switching to float32 everything worked fine.
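For anyone else hitting this on a T4, here's a minimal sketch of the change, assuming the model is loaded through mamba_ssm's `MambaLMHeadModel.from_pretrained` (the model name below is illustrative, not what the repo necessarily uses):

```python
# Sketch: load the Mamba model in float32 instead of bfloat16, since the bf16
# kernels require sm_80+ (Ampere) and the Colab T4 is only sm_75.
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-130m",  # illustrative; use whatever model you pass to train_mamba.py
    device="cuda",
    dtype=torch.float32,        # float16 compiled but broke the loss; float32 worked
)
```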

I'll give this a try later. Thanks!

It works. After installing the requirements, I still have to add the commands below to run it on Colab, but I'm pretty sure that's a separate problem:
```
!export LC_ALL="en_US.UTF-8"
!export LD_LIBRARY_PATH="/usr/lib64-nvidia"
!export LIBRARY_PATH="/usr/local/cuda/lib64/stubs"
!ldconfig /usr/lib64-nvidia
```

Also, adding `fp16=True` to the training args results in faster training on Colab.
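For reference, that's roughly what it looks like with the standard Hugging Face `TrainingArguments`; everything here other than `fp16` is a placeholder:

```python
# Sketch: enabling fp16 mixed precision in the TrainingArguments passed to the trainer.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",               # placeholder
    per_device_train_batch_size=4,  # placeholder
    num_train_epochs=1,             # placeholder
    fp16=True,                      # float16 mixed precision; faster on the T4
    # bf16=True would need an sm_80+ GPU, which the T4 is not
)
```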

> It works. After installing the requirements, I still have to add the commands below to run it on Colab, but I'm pretty sure that's a separate problem: `!export LC_ALL="en_US.UTF-8"` `!export LD_LIBRARY_PATH="/usr/lib64-nvidia"` `!export LIBRARY_PATH="/usr/local/cuda/lib64/stubs"` `!ldconfig /usr/lib64-nvidia`

Ya, that looks like a Colab issue?

It is. Honestly, I got that solution from here; I haven't looked up why it works yet.

> It is. Honestly, I got that solution from here; I haven't looked up why it works yet.

Found an open issue over at PyTorch; it looks like a version mismatch between Triton and PyTorch(?). I don't think the maintainers will be happy with us discussing that here.