Compatibility problem with torch >= 1.8.0 when torch_complex package is not installed
adriengossesonos opened this issue · 2 comments
Hello,
I noticed that when using the package (version 0.1.3), I get compatibility issues when passing `torch.Tensor` inputs to the method `bss_eval_sources`, because I did not have the `torch_complex` package installed. However, `torch_complex` shouldn't be required in this case, since I use torch 1.10.2.
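This happens because, in the `__init__.py` file, the flag `has_torch` ends up `False` even though torch is installed: the line `from . import torch` raises an `ImportError`, and the `except` branch then replaces the real submodule with dummies: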
```python
try:
    import torch as pt

    has_torch = True

    from . import torch as torch  # --> this line fails
    from .torch import sdr_pit_loss, si_sdr_pit_loss

except ImportError:
    has_torch = False

    # dummy pytorch module
    class pt:
        class Tensor:
            def __init__(self):
                pass

    # dummy torch submodule
    class torch:
        bss_eval_sources = None
        sdr = None
        sdr_loss = None

from . import numpy as numpy
```
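To see why the flag flips even though `import torch as pt` itself succeeds, here is a minimal, self-contained sketch of the same single-`try` pattern. It uses the standard-library `importlib.import_module`; the helper `first_missing` is hypothetical, `math` stands in for torch (installed), and `surely_missing_pkg_xyz` stands in for `torch_complex` (not installed).

```python
import importlib
from typing import Optional


def first_missing(*module_names: str) -> Optional[str]:
    """Import modules in order inside one try block, like __init__.py does.

    Returns None if every import succeeds; otherwise returns the name of the
    module whose import raised ImportError. The quoted __init__.py discards
    this information and only sets has_torch = False.
    """
    try:
        for name in module_names:
            importlib.import_module(name)
    except ImportError as err:
        # ModuleNotFoundError (a subclass of ImportError) records the missing name
        return err.name
    return None


# The first import succeeds, yet the whole block still fails.
print(first_missing("math"))  # → None
print(first_missing("math", "surely_missing_pkg_xyz"))  # → surely_missing_pkg_xyz
```

Because the `except ImportError:` branch in `__init__.py` does not keep the exception, the user only ends up with the dummy `torch` submodule and never sees which package was actually missing.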
Apparently, this happens because the failing line ends up importing the file `torch/compatibility.py`, which imports `torch_complex` unconditionally at module level:
```python
try:
    from packaging.version import Version
except [ImportError, ModuleNotFoundError]:
    from distutils.version import LooseVersion as Version

from torch_complex import ComplexTensor  # --> this line causes the problem when torch_complex is not installed

import torch

is_torch_1_8_plus = Version(torch.__version__) >= Version("1.8.0")

if not is_torch_1_8_plus:
    try:
        import torch_complex
    except ImportError:
        raise ImportError(
            "When using torch<=1.7, the package torch_complex is required."
            " Install it as `pip install torch_complex`"
        )
```
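Separately from the `torch_complex` issue, the `except [ImportError, ModuleNotFoundError]:` clause in this snippet looks like a latent bug: Python only accepts an exception class or a tuple of classes in an `except` clause, so if `packaging` were ever missing, the fallback would raise `TypeError` instead of importing `LooseVersion`. The small sketch below (the helper `catch_with` is hypothetical) shows the difference.

```python
def catch_with(handler_spec):
    """Raise a simulated ImportError and try to catch it with `handler_spec`."""
    try:
        raise ImportError("simulated: packaging is not installed")
    except handler_spec:
        return "caught"


# A tuple of exception classes works as intended.
print(catch_with((ImportError, ModuleNotFoundError)))  # → caught

# A list is rejected, but only when an exception actually reaches the clause,
# which is why the bug stays hidden while packaging is installed.
try:
    catch_with([ImportError, ModuleNotFoundError])
except TypeError as err:
    print("TypeError:", err)

# ModuleNotFoundError already derives from ImportError, so one class is enough.
print(issubclass(ModuleNotFoundError, ImportError))  # → True
```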
If I understand correctly, the fix would simply be to move the `torch_complex` import inside the version check:
```python
try:
    from packaging.version import Version
except ImportError:  # ModuleNotFoundError is a subclass of ImportError
    from distutils.version import LooseVersion as Version

import torch

is_torch_1_8_plus = Version(torch.__version__) >= Version("1.8.0")

if not is_torch_1_8_plus:
    try:
        # torch_complex is only required for torch < 1.8
        from torch_complex import ComplexTensor
    except ImportError:
        raise ImportError(
            "When using torch<=1.7, the package torch_complex is required."
            " Install it as `pip install torch_complex`"
        )
```
Thank you very much for the detailed bug report! I also had some issues with torch_complex import failing, but had not looked into it! The bugfix is simple enough: I'll try it right now!
@adriengossesonos, thanks to you I fixed the bug and bumped to 1.4! Please let me know if there are more problems!