Add explicit device
neverix opened this issue · 3 comments
Right now, in the model, the individual loading functions decide where the encoder and decoder are stored (CPU or GPU) by checking torch.cuda.is_available(). This doesn't let the user run the model on CPU when a GPU is available.
I propose adding a device= parameter that defaults to None, in which case the device is chosen automatically.
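Here is a minimal sketch of the proposed default. It assumes a hypothetical helper, resolve_device, that each loading function would call instead of checking torch.cuda.is_available() itself. The availability check is passed in as a parameter, so the fallback logic can be exercised without a GPU or even without torch installed.

```python
from typing import Callable, Optional


def _torch_cuda_available() -> bool:
    # Imported lazily so the helper still works where torch is missing.
    try:
        import torch
    except ImportError:
        return False
    return torch.cuda.is_available()


def resolve_device(
    device: Optional[str] = None,
    cuda_available: Callable[[], bool] = _torch_cuda_available,
) -> str:
    """Hypothetical helper: pick where the encoder and decoder are loaded.

    An explicit device (e.g. "cpu") is returned unchanged, even when a GPU
    is present. None keeps the current behaviour: CUDA if available, else CPU.
    """
    if device is not None:
        return device
    return "cuda" if cuda_available() else "cpu"
```

With this helper, resolve_device("cpu") forces CPU on a GPU machine, and resolve_device() behaves the way the loaders do today.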
Makes sense; this could support TPU and MPS too.
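One way to extend the automatic default to more backends is an ordered list of availability checks. This is a hypothetical sketch: pick_device, DEFAULT_BACKENDS, and the cuda-before-mps order are assumptions, not something the thread settled on. torch.backends.mps.is_available() exists in recent PyTorch releases. TPU support would need the separate torch_xla package and is left out here.

```python
from typing import Callable, Optional, Sequence, Tuple

Backend = Tuple[str, Callable[[], bool]]


def _cuda_available() -> bool:
    try:
        import torch
    except ImportError:  # without torch there is only the CPU fallback
        return False
    return torch.cuda.is_available()


def _mps_available() -> bool:
    try:
        import torch
    except ImportError:
        return False
    # torch.backends.mps only exists in newer PyTorch releases.
    mps = getattr(torch.backends, "mps", None)
    return mps is not None and mps.is_available()


# Assumed preference order: the first available backend wins.
DEFAULT_BACKENDS: Sequence[Backend] = (
    ("cuda", _cuda_available),
    ("mps", _mps_available),
)


def pick_device(
    device: Optional[str] = None,
    backends: Sequence[Backend] = DEFAULT_BACKENDS,
) -> str:
    """Hypothetical resolver: explicit device first, then backends in order, then CPU."""
    if device is not None:  # an explicit choice always wins
        return device
    for name, is_available in backends:
        if is_available():
            return name
    return "cpu"
```

Keeping the order in data makes it easy to drop a backend from automatic selection. For example, passing backends=(("cuda", _cuda_available),) leaves MPS reachable only through an explicit device="mps".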
Ok, now MinDalle can be initialized with device="cuda" or device="cpu". I got MPS to work, but it was actually slower than CPU on my M1, and I had to change a lot of the operations to make them work on MPS. I reverted those changes because the original ops were faster.
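The sketch below shows how a single device argument can replace the per-loader torch.cuda.is_available() checks. The device is resolved once in the constructor, passed as map_location to torch.load, and each module is then moved with .to(). TinyDalle, its one-layer "encoder", and the in-memory checkpoint are hypothetical stand-ins. This is not MinDalle's actual code.

```python
import io
from typing import Optional

import torch
from torch import nn


class TinyDalle:
    """Hypothetical stand-in for MinDalle: one device flag drives every loader."""

    def __init__(self, checkpoint: bytes, device: Optional[str] = None):
        if device is None:
            # Same automatic choice the loaders made on their own before.
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.encoder = self._load(nn.Linear(4, 4), checkpoint)

    def _load(self, module: nn.Module, checkpoint: bytes) -> nn.Module:
        # map_location puts the tensors on the chosen device regardless of
        # where they were saved, so no loader needs its own CUDA check.
        state = torch.load(io.BytesIO(checkpoint), map_location=self.device)
        module.load_state_dict(state)
        return module.to(self.device).eval()


def save_checkpoint(module: nn.Module) -> bytes:
    # Serialize a state_dict to bytes to simulate a weights file on disk.
    buffer = io.BytesIO()
    torch.save(module.state_dict(), buffer)
    return buffer.getvalue()


ckpt = save_checkpoint(nn.Linear(4, 4))
model = TinyDalle(ckpt, device="cpu")  # forced onto CPU even if a GPU exists
print(next(model.encoder.parameters()).device)  # cpu
```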
Could you share the diff? I'd like to try it out locally.