kuprel/min-dalle

Add explicit device

neverix opened this issue · 3 comments

Right now, in the model, the individual loading functions control where the encoder and decoder are stored (CPU or GPU) by checking torch.cuda.is_available(). This doesn't let the user run the model on CPU when a GPU is available.

I propose adding a device= parameter that defaults to None, in which case the device is chosen automatically.
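A minimal sketch of what that resolution could look like. The function name resolve_device is hypothetical, and the availability checks are passed in as plain booleans (standing in for torch.cuda.is_available() and torch.backends.mps.is_available()) so the sketch runs without torch installed:

```python
def resolve_device(device=None, cuda_available=False, mps_available=False):
    """Pick a device string for the model.

    An explicit "cpu"/"cuda"/"mps" always wins, so the user can force
    CPU even when a GPU is present. device=None keeps the current
    automatic behavior. The *_available flags are stand-ins for
    torch.cuda.is_available() / torch.backends.mps.is_available().
    """
    if device is not None:
        return device  # user override: trust it as given
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"
```

The key point is that the explicit argument is checked before any hardware probing, so resolve_device("cpu", cuda_available=True) still returns "cpu".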

makes sense, could support TPU and MPS too

Ok, now MinDalle can be initialized with device="cuda" or device="cpu". I got MPS to work, but it was actually slower than CPU on my M1, and I had to change a lot of the operations to work with MPS. I reverted them because the original ops were faster.

could you share the diff? I'd like to try it out locally.