f90/Wave-U-Net-Pytorch

Support for Apple Metal (MPS) backend

Archie3d opened this issue · 1 comment

Please add support for the MPS backend, as you already do for CUDA:

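import torch
import model.utils as model_utils  # this repo's model helper module

# Use Apple's Metal Performance Shaders backend when it is present
if torch.backends.mps.is_available():
    mps = torch.device("mps")
    model = model_utils.DataParallel(model)
    model.to(mps)

# ... and so on...
x = x.to(mps)  # every input tensor must be moved to the same device as the model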

Has anyone looked into implementing MPS support for this yet?
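One possible shape for this, sketched under assumptions: rather than branching on CUDA alone, pick a single device up front (CUDA first, then MPS, then CPU) and move both the model and every input tensor to it. `torch.backends.mps` exists from PyTorch 1.12 onward; the helper names `choose_device_name` and `get_device` are hypothetical and not part of Wave-U-Net-Pytorch. The selection logic is kept as a pure function so it can be checked without a GPU.

```python
def choose_device_name(cuda_available: bool, mps_available: bool) -> str:
    """Return the preferred device name: CUDA first, then Apple MPS, then CPU.

    Hypothetical helper, not part of the Wave-U-Net-Pytorch codebase.
    """
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def get_device():
    """Build a torch.device from the backends actually present on this machine."""
    import torch  # imported lazily so the selection logic above has no torch dependency

    # torch.backends.mps was added in PyTorch 1.12; treat it as unavailable on older builds.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_available = mps_backend is not None and mps_backend.is_available()
    return torch.device(choose_device_name(torch.cuda.is_available(), mps_available))
```

In the training or prediction script this would become `device = get_device()`, followed by `model.to(device)` and `x = x.to(device)`, so the same code path serves CUDA, MPS, and CPU without separate branches.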