fakufaku/fast_bss_eval

Compatibility problem with torch >= 1.8.0 when torch_complex package is not installed

adriengossesonos opened this issue · 2 comments

Hello,
With version 0.1.3 of the package, I run into compatibility issues when passing torch.Tensor inputs to bss_eval_sources, because the torch_complex package is not installed. However, torch_complex should not be required in this case, since I am using torch 1.10.2.

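This happens because, in the __init__.py file, has_torch ends up set to False: the import marked below raises an ImportError, so the except branch runs even though torch itself is installed.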

try:
    import torch as pt
    has_torch = True

    from . import torch as torch     # --> this line fails
    from .torch import sdr_pit_loss, si_sdr_pit_loss   
except ImportError:
    has_torch = False

    # dummy pytorch module
    class pt:
        class Tensor:
            def __init__(self):
                pass

    # dummy torch submodule
    class torch:
        bss_eval_sources = None
        sdr = None
        sdr_loss = None

from . import numpy as numpy
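
To illustrate why has_torch ends up False even when torch is installed, here is a minimal sketch of the same try/except pattern in isolation. The helper try_enable_backend and the module name hypothetical_missing_dependency_xyz are made up for this illustration and are not part of fast_bss_eval; "json" simply stands in for a package that is installed.

```python
import importlib


def try_enable_backend(module_names):
    """Mimic the try/except in __init__.py: any ImportError disables the backend.

    Returns (enabled, error), where error is the exception that was swallowed.
    """
    try:
        for name in module_names:
            importlib.import_module(name)
        return True, None
    except ImportError as err:
        # The handler cannot tell which import failed unless it inspects err.
        return False, err


# "json" plays the role of torch (installed); the second name plays the role
# of a missing transitive dependency such as torch_complex. Both are
# placeholders chosen for this illustration.
enabled, error = try_enable_backend(["json", "hypothetical_missing_dependency_xyz"])
print("backend enabled:", enabled)
print("module that actually failed:", getattr(error, "name", None))
```

The backend is reported as disabled although the first import succeeded. Inspecting the name attribute of the swallowed exception is one way to find out which module was really missing.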

Apparently, the failing line ends up importing the file torch/compatibility.py, which imports torch_complex unconditionally at the top:

try:
    from packaging.version import Version
except [ImportError, ModuleNotFoundError]:
    from distutils.version import LooseVersion as Version

from torch_complex import ComplexTensor # --> this line causes the problem when torch_complex is not installed 

import torch

is_torch_1_8_plus = Version(torch.__version__) >= Version("1.8.0")

if not is_torch_1_8_plus:
    try:
        import torch_complex
    except ImportError:
        raise ImportError(
            "When using torch<=1.7, the package torch_complex is required."
            " Install it as `pip install torch_complex`"
        )
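
A side note on the file quoted above: the clause `except [ImportError, ModuleNotFoundError]:` uses a list, and as far as I can tell Python only accepts an exception class or a tuple of classes there. The sketch below simulates a missing packaging module to show the difference; the function names are invented for this illustration.

```python
def fallback_with_list():
    try:
        raise ImportError("simulated: packaging is not installed")
    except [ImportError, ModuleNotFoundError]:  # list: not a valid except target
        return "fallback used"


def fallback_with_tuple():
    try:
        raise ImportError("simulated: packaging is not installed")
    except (ImportError, ModuleNotFoundError):  # tuple: valid
        return "fallback used"


def outcome(func):
    """Run func and report either its result or the TypeError it raised."""
    try:
        return func()
    except TypeError as err:
        return f"TypeError: {err}"


print("list :", outcome(fallback_with_list))
print("tuple:", outcome(fallback_with_tuple))
```

The list version only misbehaves when the import actually fails, so it goes unnoticed as long as packaging is installed. Since ModuleNotFoundError is a subclass of ImportError, a plain `except ImportError:` would also be enough.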

If I understand correctly, the fix would simply be to move the torch_complex import inside the version check, so that it only runs for torch < 1.8:

try:
    from packaging.version import Version
except ImportError:  # ModuleNotFoundError is a subclass of ImportError
    from distutils.version import LooseVersion as Version

import torch

is_torch_1_8_plus = Version(torch.__version__) >= Version("1.8.0")

if not is_torch_1_8_plus:
    try:
        from torch_complex import ComplexTensor 
    except ImportError:
        raise ImportError(
            "When using torch<=1.7, the package torch_complex is required."
            " Install it as `pip install torch_complex`"
        )
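
The same idea can be sketched without torch installed, which makes it easy to check on its own. This is only an illustration under assumptions: parse_release is a simplified stand-in for packaging.version.Version, and load_complex_tensor and its module_name parameter are hypothetical helpers, not part of fast_bss_eval.

```python
import importlib
import re


def parse_release(version):
    """Simplified stand-in for Version: '1.10.2+cu113' -> (1, 10, 2)."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"cannot parse version string: {version!r}")
    return tuple(int(group) for group in match.groups() if group is not None)


def needs_torch_complex(torch_version):
    """Only torch older than 1.8 lacks native complex tensor support."""
    return parse_release(torch_version) < (1, 8)


def load_complex_tensor(torch_version, module_name="torch_complex"):
    """Import the optional dependency only when the torch version requires it."""
    if not needs_torch_complex(torch_version):
        return None  # recent torch: the optional package is never imported
    try:
        module = importlib.import_module(module_name)
    except ImportError as err:
        raise ImportError(
            "When using torch<1.8, the package torch_complex is required."
            " Install it as `pip install torch_complex`"
        ) from err
    return module.ComplexTensor


# With a recent torch version, a missing optional package is not a problem.
print(load_complex_tensor("1.10.2", module_name="hypothetical_missing_dependency_xyz"))
```

Comparing parsed tuples rather than raw strings matters here: as strings, "1.10.2" sorts before "1.8.0", which would wrongly classify torch 1.10.2 as older than 1.8.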

Thank you very much for the detailed bug report! I also had some issues with torch_complex import failing, but had not looked into it! The bugfix is simple enough: I'll try it right now!

@adriengossesonos, thanks to you I fixed the bug and bumped to 1.4! Please let me know if there are more problems!