bowang-lab/U-Mamba

Selective Scan Import Issue

MarioPaps opened this issue · 4 comments

Hello,

I made a new environment and installed PyTorch 2.1 with CUDA 11.8, along with the recommended causal-conv1d and mamba-ssm.
However, the model does not train because of a 'selective_scan' error. Could you help me with this?

This is the full error:
Traceback (most recent call last):
  File "/rds/general/user/kp4718/home/code/MedMamba/trainpynew.py", line 129, in <module>
    main()
  File "/rds/general/user/kp4718/home/code/MedMamba/trainpynew.py", line 88, in main
    outputs = net(images)
              ^^^^^^^^^^^
  File "/rds/general/user/kp4718/home/anaconda3/envs/cleanmamba/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/rds/general/user/kp4718/home/anaconda3/envs/cleanmamba/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/rds/general/user/kp4718/home/code/MedMamba/MedMamba.py", line 734, in forward
    x = self.forward_backbone(x)
        ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/rds/general/user/kp4718/home/code/MedMamba/MedMamba.py", line 730, in forward_backbone
    x = layer(x)
        ^^^^^^^^
  File "/rds/general/user/kp4718/home/anaconda3/envs/cleanmamba/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/rds/general/user/kp4718/home/anaconda3/envs/cleanmamba/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/rds/general/user/kp4718/home/code/MedMamba/MedMamba.py", line 570, in forward
    x = blk(x)
        ^^^^^^
  File "/rds/general/user/kp4718/home/anaconda3/envs/cleanmamba/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/rds/general/user/kp4718/home/anaconda3/envs/cleanmamba/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/rds/general/user/kp4718/home/code/MedMamba/MedMamba.py", line 503, in forward
    x = input_right + self.drop_path(self.self_attention(self.ln_1(input_right)))
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/rds/general/user/kp4718/home/anaconda3/envs/cleanmamba/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/rds/general/user/kp4718/home/anaconda3/envs/cleanmamba/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/rds/general/user/kp4718/home/code/MedMamba/MedMamba.py", line 464, in forward
    y1, y2, y3, y4 = self.forward_core(x)
                     ^^^^^^^^^^^^^^^^^^^^
  File "/rds/general/user/kp4718/home/code/MedMamba/MedMamba.py", line 379, in forward_corev0
    self.selective_scan = selective_scan_fn
                          ^^^^^^^^^^^^^^^^^
NameError: name 'selective_scan_fn' is not defined
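
The NameError itself looks like a symptom: MedMamba.py presumably wraps the selective-scan import in a try/except, so the real failure only surfaces later as an unbound name. A minimal check that shows the underlying error directly (using the module path mamba-ssm itself uses for selective_scan_fn):

import traceback

try:
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
    print("selective_scan_fn imported OK")
except Exception:
    traceback.print_exc()  # prints the real ImportError (typically a CUDA kernel/build problem)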

Could you please test Mamba first?

import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape
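
If that snippet fails, it may also help to check whether the causal-conv1d CUDA extension is the part that is broken. A rough sketch (causal_conv1d_fn is the package's top-level function; shapes follow the example above, with channels first):

import torch
from causal_conv1d import causal_conv1d_fn

batch, dim, seqlen, width = 2, 16, 64, 4
x = torch.randn(batch, dim, seqlen).to("cuda")  # (batch, channels, length)
weight = torch.randn(dim, width).to("cuda")     # one depthwise filter per channel
bias = torch.randn(dim).to("cuda")
y = causal_conv1d_fn(x, weight, bias, activation="silu")
assert y.shape == x.shape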

Hi, thanks for the prompt reply. This import failed.

I believe the issue is with causal_conv1d and CUDA.
Are we supposed to use PyTorch 1.13 instead of 2.x?
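
For reference, the torch build that actually ended up in the environment can be checked with the standard attributes (the commented values are just what I would expect for the setup described above):

import torch

print(torch.__version__)          # e.g. 2.1.x+cu118
print(torch.version.cuda)         # CUDA version torch was built against, e.g. 11.8
print(torch.cuda.is_available())  # should be True on the training node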

I think I have the same problem.
When training the U-Mamba Enc model on the BraTS2021 dataset, the error arises during the forward pass through the Mamba layer in the network architecture. It appears that the arguments passed to causal_conv1d_fwd() do not match the expected types or structure. The configuration used and the full error are below:

This is the configuration used by this training:
Configuration name: 3d_fullres
 {'data_identifier': 'nnUNetPlans_3d_fullres', 'preprocessor_name': 'DefaultPreprocessor', 'batch_size': 2, 'patch_size': [128, 128, 128], 'median_image_size_in_voxels': [140.0, 171.0, 137.0], 'spacing': [1.0, 1.0, 1.0], 'normalization_schemes': ['ZScoreNormalization', 'ZScoreNormalization', 'ZScoreNormalization', 'ZScoreNormalization'], 'use_mask_for_norm': [True, True, True, True], 'UNet_class_name': 'PlainConvUNet', 'UNet_base_num_features': 32, 'n_conv_per_stage_encoder': [2, 2, 2, 1, 1, 1], 'n_conv_per_stage_decoder': [2, 2, 2, 1, 1], 'num_pool_per_axis': [5, 5, 5], 'pool_op_kernel_sizes': [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]], 'conv_kernel_sizes': [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]], 'unet_max_num_features': 320, 'resampling_fn_data': 'resample_data_or_seg_to_shape', 'resampling_fn_seg': 'resample_data_or_seg_to_shape', 'resampling_fn_data_kwargs': {'is_seg': False, 'order': 3, 'order_z': 0, 'force_separate_z': None}, 'resampling_fn_seg_kwargs': {'is_seg': True, 'order': 1, 'order_z': 0, 'force_separate_z': None}, 'resampling_fn_probabilities': 'resample_data_or_seg_to_shape', 'resampling_fn_probabilities_kwargs': {'is_seg': False, 'order': 1, 'order_z': 0, 'force_separate_z': None}, 'batch_dice': False} 

These are the global plan.json settings:
 {'dataset_name': 'Dataset137_BraTS2021', 'plans_name': 'nnUNetPlans', 'original_median_spacing_after_transp': [1.0, 1.0, 1.0], 'original_median_shape_after_transp': [140, 171, 137], 'image_reader_writer': 'SimpleITKIO', 'transpose_forward': [0, 1, 2], 'transpose_backward': [0, 1, 2], 'experiment_planner_used': 'ExperimentPlanner', 'label_manager': 'LabelManager', 'foreground_intensity_properties_per_channel': {'0': {'max': 95242.25, 'mean': 871.816650390625, 'median': 407.0, 'min': 0.10992202162742615, 'percentile_00_5': 55.0, 'percentile_99_5': 5825.0, 'std': 2023.5313720703125}, '1': {'max': 1905559.25, 'mean': 1698.2144775390625, 'median': 552.0, 'min': 0.0, 'percentile_00_5': 47.0, 'percentile_99_5': 8322.0, 'std': 18787.4140625}, '2': {'max': 4438107.0, 'mean': 2141.349365234375, 'median': 738.0, 'min': 0.0, 'percentile_00_5': 110.0, 'percentile_99_5': 10396.0, 'std': 45159.37890625}, '3': {'max': 580014.3125, 'mean': 995.436279296875, 'median': 512.3143920898438, 'min': 0.0, 'percentile_00_5': 108.0, 'percentile_99_5': 11925.0, 'std': 4629.87939453125}}} 

2024-03-29 22:38:59.103773: unpacking dataset...
2024-03-29 22:38:59.715880: unpacking done...
2024-03-29 22:38:59.716621: do_dummy_2d_data_aug: False
2024-03-29 22:38:59.732308: Unable to plot network architecture:
2024-03-29 22:38:59.732390: No module named 'hiddenlayer'
2024-03-29 22:38:59.737705: 
2024-03-29 22:38:59.737803: Epoch 0
2024-03-29 22:38:59.737955: Current learning rate: 0.01
using pin_memory on device 0
Traceback (most recent call last):
  File "/usr/local/bin/nnUNetv2_train", line 33, in <module>
    sys.exit(load_entry_point('nnunetv2', 'console_scripts', 'nnUNetv2_train')())
  File "/content/U-Mamba/umamba/nnunetv2/run/run_training.py", line 268, in run_training_entry
    run_training(args.dataset_name_or_id, args.configuration, args.fold, args.tr, args.p, args.pretrained_weights,
  File "/content/U-Mamba/umamba/nnunetv2/run/run_training.py", line 204, in run_training
    nnunet_trainer.run_training()
  File "/content/U-Mamba/umamba/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py", line 1258, in run_training
    train_outputs.append(self.train_step(next(self.dataloader_train)))
  File "/content/U-Mamba/umamba/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py", line 900, in train_step
    output = self.network(data)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/U-Mamba/umamba/nnunetv2/nets/UMambaEnc_3d.py", line 478, in forward
    skips = self.encoder(x)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/U-Mamba/umamba/nnunetv2/nets/UMambaEnc_3d.py", line 287, in forward
    x = self.mamba_layers[s](x)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
    return func(*args, **kwargs)
  File "/content/U-Mamba/umamba/nnunetv2/nets/UMambaEnc_3d.py", line 89, in forward
    out = self.forward_patch_token(x)
  File "/content/U-Mamba/umamba/nnunetv2/nets/UMambaEnc_3d.py", line 63, in forward_patch_token
    x_mamba = self.mamba(x_norm)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/modules/mamba_simple.py", line 146, in forward
    out = mamba_inner_fn(
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/selective_scan_interface.py", line 317, in mamba_inner_fn
    return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/torch/cuda/amp/autocast_mode.py", line 98, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/selective_scan_interface.py", line 187, in forward
    conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
TypeError: causal_conv1d_fwd(): incompatible function arguments. The following argument types are supported:
    1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: Optional[torch.Tensor], arg3: Optional[torch.Tensor], arg4: bool) -> torch.Tensor
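
This TypeError usually means the installed mamba-ssm and causal-conv1d disagree on the causal_conv1d_fwd signature, i.e. the two packages are out of sync rather than missing. A quick way to see which versions actually got installed (standard library only):

import importlib.metadata as md

for pkg in ("torch", "mamba-ssm", "causal-conv1d"):
    try:
        print(pkg, md.version(pkg))
    except md.PackageNotFoundError:
        print(pkg, "not installed")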

Hi all,

Sorry for the inconvenience. The code only supports the torch 2.0 series, but this issue can be solved by installing the latest causal-conv1d (>=1.2.0).