theislab/scarches

Error: don't know how to restore data location of torch.storage.UntypedStorage while trying to load HLCA reference

Hi, I'm trying to run your tutorial on mapping data to the HLCA, and I get an error while loading the reference model:

surgery_model = sca.models.SCANVI.load_query_data(
    adata_query,
    ref_model_dir,
    freeze_dropout=True,
)

INFO     File /Users/bapoorva/Desktop/HLCA/HLCA_reference_model/model.pt already downloaded
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/bapoorva/miniconda3/lib/python3.11/site-packages/scvi/model/base/_archesmixin.py", line 89, in load_query_data
    attr_dict, var_names, load_state_dict = _get_loaded_data(
                                            ^^^^^^^^^^^^^^^^^
  File "/Users/bapoorva/miniconda3/lib/python3.11/site-packages/scvi/model/base/_archesmixin.py", line 330, in _get_loaded_data
    attr_dict, var_names, load_state_dict, _ = _load_saved_files(
                                               ^^^^^^^^^^^^^^^^^^
  File "/Users/bapoorva/miniconda3/lib/python3.11/site-packages/scvi/model/base/_utils.py", line 69, in _load_saved_files
    model = torch.load(model_path, map_location=map_location)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/bapoorva/miniconda3/lib/python3.11/site-packages/torch/serialization.py", line 809, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/bapoorva/miniconda3/lib/python3.11/site-packages/torch/serialization.py", line 1172, in _load
    result = unpickler.load()
             ^^^^^^^^^^^^^^^^
  File "/Users/bapoorva/miniconda3/lib/python3.11/pickle.py", line 1213, in load
    dispatch[key[0]](self)
  File "/Users/bapoorva/miniconda3/lib/python3.11/pickle.py", line 1254, in load_binpersid
    self.append(self.persistent_load(pid))
                ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/bapoorva/miniconda3/lib/python3.11/site-packages/torch/serialization.py", line 1142, in persistent_load
    typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/bapoorva/miniconda3/lib/python3.11/site-packages/torch/serialization.py", line 1116, in load_tensor
    wrap_storage=restore_location(storage, location),
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/bapoorva/miniconda3/lib/python3.11/site-packages/torch/serialization.py", line 1086, in restore_location
    return default_restore_location(storage, str(map_location))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/bapoorva/miniconda3/lib/python3.11/site-packages/torch/serialization.py", line 220, in default_restore_location
    raise RuntimeError("don't know how to restore data location of "
RuntimeError: don't know how to restore data location of torch.storage.UntypedStorage (tagged with mps:0)

Any ideas on what might be causing the error and how to fix it?

It turns out this error occurs when running on an M1 Mac. I tried it on a different Linux machine and it ran without errors.
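
For anyone who has to stay on the Mac: the `(tagged with mps:0)` at the end of the message suggests the loader was asked to place the tensors on Apple's MPS device, which the installed PyTorch apparently could not restore to. One possible workaround, untested against this exact setup, is to force the CPU when loading. The sketch below is only a guess at how that could look. `pick_accelerator` is a hypothetical helper (not part of scArches or scvi-tools), and it assumes the installed scvi-tools accepts an `accelerator` argument in `load_query_data`; older releases may expect `use_gpu=False` instead.

```python
import platform


def pick_accelerator(system=None, machine=None):
    """Return "cpu" on Apple Silicon macOS, otherwise "auto".

    Hypothetical helper for illustration only; not part of scArches
    or scvi-tools. The arguments exist so the check can be exercised
    on any machine; by default the current platform is inspected.
    """
    system = platform.system() if system is None else system
    machine = platform.machine() if machine is None else machine
    # Native Python on an M1/M2 Mac reports Darwin + arm64.
    on_apple_silicon = system == "Darwin" and machine == "arm64"
    return "cpu" if on_apple_silicon else "auto"


# Usage sketch (assumption: this scvi-tools version accepts
# `accelerator` in load_query_data; check your installed version):
#
# surgery_model = sca.models.SCANVI.load_query_data(
#     adata_query,
#     ref_model_dir,
#     freeze_dropout=True,
#     accelerator=pick_accelerator(),
# )
```

Passing `accelerator="cpu"` directly works just as well if the script only ever runs on the Mac; the helper is only useful when the same script is shared between the Mac and a Linux machine.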