Audio-AGI/AudioSep

multi-gpu support

eschmidbauer opened this issue · 4 comments

Thank you for sharing this project. I am wondering whether multi-GPU inference is supported, or will be. Currently I am unable to run inference on a single 3090; it fails with:
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 5.50 GiB (GPU 0; 23.69 GiB total capacity; 15.94 GiB already allocated; 2.85 GiB free; 19.24 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
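
The error message itself points at the allocator setting `max_split_size_mb`. A minimal sketch of how that option can be set from Python follows; the value 128 is only an illustrative assumption, not a recommendation from the AudioSep authors, and it may not help if the model simply needs more memory than the card has.

```python
import os

# Set the allocator option before PyTorch initializes CUDA.
# "128" is an arbitrary example value (assumption); tune it for your workload.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

# Import torch only after the variable is set, so the CUDA caching
# allocator sees it when it starts up.
# import torch
```

The same variable can be exported in the shell before starting Python instead.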

Alternatively, you can enable use_chunk to use chunk-based inference, which can save memory when processing long audio.

e.g., inference(model, audio_file, text, output_file, device, use_chunk=True)
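
To illustrate why chunking saves memory, here is a generic sketch of chunk-based processing. It is not AudioSep's `chunk_inference` implementation; `process_in_chunks` and the `process` callback are hypothetical names, and a real separator would typically overlap chunks to avoid artifacts at the boundaries.

```python
import numpy as np

def process_in_chunks(audio, chunk_size, process):
    """Run `process` on consecutive slices of `audio` and join the results.

    Only one chunk is handed to `process` at a time, so peak memory
    depends on `chunk_size` rather than on the full length of the audio.
    """
    outputs = []
    for start in range(0, len(audio), chunk_size):
        chunk = audio[start:start + chunk_size]  # the last chunk may be shorter
        outputs.append(process(chunk))
    return np.concatenate(outputs)

# Stand-in for a model call: halve the amplitude of every sample.
audio = np.arange(10, dtype=np.float32)
result = process_in_chunks(audio, chunk_size=4, process=lambda x: x * 0.5)
```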

Thanks! But now I have a different error:

>>> inference(model, audio_file, text, output_file, device, use_chunk=True)
Separate audio from [input.wav] with textual query [caller]

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/myuser/AudioSep/pipeline.py", line 45, in inference
    sep_segment = model.ss_model.chunk_inference(input_dict)
  File "/home/myuser/miniconda3/envs/AudioSep/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/myuser/AudioSep/models/resunet.py", line 661, in chunk_inference
    'RATE': self.sampling_rate
  File "/home/myuser/miniconda3/envs/AudioSep/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1269, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'ResUNet30' object has no attribute 'sampling_rate'

Now I get a different error:

>>> inference(model, audio_file, text, output_file, device, use_chunk=True)
Separate audio from [input.wav] with textual query [caller]

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/myser/AudioSep/pipeline.py", line 50, in inference
    write(output_file, 32000, np.round(sep_segment * 32767).astype(np.int16))
  File "/home/myser/miniconda3/envs/AudioSep/lib/python3.10/site-packages/scipy/io/wavfile.py", line 797, in write
    fmt_chunk_data = struct.pack('<HHIIHH', format_tag, channels, fs,
struct.error: ushort format requires 0 <= number <= (0x7fff * 2 + 1)
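
One plausible cause, not confirmed in this thread: `scipy.io.wavfile.write` expects data shaped `(n_samples,)` or `(n_samples, n_channels)`. If `sep_segment` comes back as `(1, n_samples)`, the sample count is read as the channel count, which no longer fits in the 16-bit header field. The sketch below reshapes the array before writing; `to_int16_pcm` is a hypothetical helper, and the assumed output shape of the model is a guess.

```python
import numpy as np

def to_int16_pcm(segment):
    """Convert float audio in [-1, 1] to int16 laid out for a WAV writer."""
    audio = np.asarray(segment, dtype=np.float32)
    # Drop singleton axes, e.g. (1, N) or (1, 1, N) becomes (N,).
    audio = np.squeeze(audio)
    # Assumption: a 2-D array whose first axis is the shorter one is
    # (channels, samples); transpose it to (samples, channels).
    if audio.ndim == 2 and audio.shape[0] < audio.shape[1]:
        audio = audio.T
    # Clip first so values slightly outside [-1, 1] do not wrap around in int16.
    audio = np.clip(audio, -1.0, 1.0)
    return np.round(audio * 32767).astype(np.int16)

# Hypothetical stand-in for the model output: 5 s of mono audio at 32 kHz.
sep_segment = np.zeros((1, 160000), dtype=np.float32)
pcm = to_int16_pcm(sep_segment)
```

The result could then be passed to the existing call in place of the inline conversion, e.g. `write(output_file, 32000, pcm)`.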