Multi-GPU support
eschmidbauer opened this issue · 4 comments
Thank you for sharing this project. I am wondering if there is, or will be, support for multi-GPU inference. Currently I am unable to run inference on a 3090:
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 5.50 GiB (GPU 0; 23.69 GiB total capacity; 15.94 GiB already allocated; 2.85 GiB free; 19.24 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
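For reference, the fragmentation workaround that the error message itself suggests is setting the allocator config before the first CUDA allocation. A minimal sketch (the 512 MB value is just a guess to tune, not a recommendation from the project):

import os
# Must be set before any CUDA tensors are created (i.e. before loading the model).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"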
Alternatively, you can enable use_chunk to use chunk-based inference, which can save memory when processing long audio.
e.g., inference(model, audio_file, text, output_file, device, use_chunk=True)
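For context, a complete call might look like the sketch below. It assumes the build_audiosep helper and the default config/checkpoint names from the repo README; the paths and query strings are placeholders to adapt to your setup.

import torch
from pipeline import build_audiosep, inference

device = torch.device('cuda')

# Placeholder paths: point these at your own config and checkpoint.
model = build_audiosep(
    config_yaml='config/audiosep_base.yaml',
    checkpoint_path='checkpoint/audiosep_base_4M_steps.ckpt',
    device=device)

audio_file = 'input.wav'
text = 'caller'
output_file = 'separated_audio.wav'

# use_chunk=True runs chunk-based inference to reduce peak GPU memory.
inference(model, audio_file, text, output_file, device, use_chunk=True)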
Thanks!
But now I have a different error...
>>> inference(model, audio_file, text, output_file, device, use_chunk=True)
Separate audio from [input.wav] with textual query [caller]
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/myuser/AudioSep/pipeline.py", line 45, in inference
sep_segment = model.ss_model.chunk_inference(input_dict)
File "/home/myuser/miniconda3/envs/AudioSep/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/home/myuser/AudioSep/models/resunet.py", line 661, in chunk_inference
'RATE': self.sampling_rate
File "/home/myuser/miniconda3/envs/AudioSep/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1269, in __getattr__
raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'ResUNet30' object has no attribute 'sampling_rate'
I have fixed the bug; it should work now.
https://github.com/Audio-AGI/AudioSep/blob/main/models/resunet.py#L661
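For anyone still on an older checkout, a hedged local workaround (not necessarily the upstream fix) is simply to give the model the attribute that chunk_inference reads; 32000 Hz matches the rate the pipeline uses when writing the separated audio.

# Hypothetical stopgap for older checkouts: define the attribute that
# chunk_inference expects. 32000 matches the sample rate used when
# pipeline.py writes the output WAV.
model.ss_model.sampling_rate = 32000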
Now I get a different error...
>>> inference(model, audio_file, text, output_file, device, use_chunk=True)
Separate audio from [input.wav] with textual query [caller]
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/myser/AudioSep/pipeline.py", line 50, in inference
write(output_file, 32000, np.round(sep_segment * 32767).astype(np.int16))
File "/home/myser/miniconda3/envs/AudioSep/lib/python3.10/site-packages/scipy/io/wavfile.py", line 797, in write
fmt_chunk_data = struct.pack('<HHIIHH', format_tag, channels, fs,
struct.error: ushort format requires 0 <= number <= (0x7fff * 2 + 1)
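The channel count being packed into the WAV header is overflowing the 16-bit field, which usually means the array handed to scipy is laid out as (channels, samples) or still carries a singleton batch axis, so scipy reads the sample count as the channel count. A hedged workaround in pipeline.py, with names taken from the traceback rather than a confirmed fix, would be to squeeze and reorient the array before writing:

import numpy as np
from scipy.io.wavfile import write

# Drop singleton batch/channel axes and put samples first so the WAV header
# sees a sane channel count.
sep_segment = np.squeeze(sep_segment)
if sep_segment.ndim == 2 and sep_segment.shape[0] < sep_segment.shape[1]:
    sep_segment = sep_segment.T  # (channels, samples) -> (samples, channels)
write(output_file, 32000, np.round(sep_segment * 32767).astype(np.int16))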