oxpig/CaLM

GPU device management

adrienchaton opened this issue · 3 comments

Hi, and thanks again for the great work,

I will fix this in my local copy of the repo, but as far as I can see, the CaLM class in pretrained.py, which is used for inference, doesn't support a device argument. Setting, e.g., model.model.cuda() results in a mismatch between the model's device and the device of the tensors passed to the forward method.

A quick fix is to add

device = next(self.parameters()).device
tokens = tokens.to(device)

at the start of the forward method, e.g. here: https://github.com/oxpig/CaLM/blob/main/calm/model.py#L107
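A minimal sketch of that quick fix, assuming a toy stand-in module (`TinyTokenModel` is hypothetical and is not CaLM's actual model class). The forward method looks up the device of the module's first parameter and moves the incoming token tensor there, so callers can keep passing CPU tensors after the model has been moved with `.cuda()` or `.to(...)`.

```python
import torch
import torch.nn as nn


class TinyTokenModel(nn.Module):
    """Hypothetical stand-in for a token-based model; not CaLM's actual class."""

    def __init__(self, vocab_size: int = 8, dim: int = 4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # Move the inputs to whatever device the parameters live on, so a
        # model moved with .cuda()/.to(...) still accepts CPU token tensors.
        device = next(self.parameters()).device
        tokens = tokens.to(device)
        return self.head(self.embed(tokens))


model = TinyTokenModel()
if torch.cuda.is_available():
    model = model.cuda()

tokens = torch.tensor([[1, 2, 3]])  # created on the CPU
logits = model(tokens)
print(logits.shape, logits.device)
```

Note that `tokens.to(device)` is a no-op when the tensor is already on the right device, and that `next(self.parameters())` assumes the module has at least one parameter.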

Hi @adrienchaton,

The quick fix you mention is likely correct; however, a better practice would be to instantiate the model on the CPU and then transfer it, along with the tokens, to the GPU. Here is some example code:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ProteinBertModel(args, alphabet)
model = model.to(device)

# ...
tokens = tokens.to(device)
logits = model(tokens)

Best wishes,
Carlos
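A self-contained, runnable version of the reply's pattern, assuming an `nn.Sequential` stand-in in place of `ProteinBertModel(args, alphabet)`, since `args` and `alphabet` are not shown in the thread. The `model.eval()` and `torch.no_grad()` calls are standard PyTorch inference additions, not part of the original snippet.

```python
import torch
import torch.nn as nn

# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hypothetical stand-in for ProteinBertModel(args, alphabet): instantiated on
# the CPU (the default), then moved to the target device in one call.
model = nn.Sequential(nn.Embedding(8, 4), nn.Linear(4, 8))
model = model.to(device)
model.eval()  # inference mode: disables dropout and similar training-only behavior

tokens = torch.tensor([[1, 2, 3]])
tokens = tokens.to(device)  # inputs must be on the same device as the model

with torch.no_grad():  # no autograd bookkeeping needed for inference
    logits = model(tokens)

print(logits.shape, logits.device)
```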

Thanks for your reply and for sharing the code snippet.
This is indeed what I would do in my own code (possibly also creating the tensors directly on the GPU, for speed and as best practice), but it isn't the usage your README recommends for inference with the pre-trained model.

I find it unfortunate that the proposed usage doesn't allow using a GPU, but the great work, and the right to ignore this, are yours.
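A sketch of what an inference wrapper with a device argument could look like. `DeviceAwareWrapper` is hypothetical and is not the actual `CaLM` class in pretrained.py, whose internals are not shown in this thread. It moves the wrapped model once at construction and builds input tensors directly on the target device, as suggested in the last comment.

```python
from typing import Optional, Sequence, Union

import torch
import torch.nn as nn


class DeviceAwareWrapper:
    """Hypothetical inference wrapper with an explicit device argument."""

    def __init__(self, model: nn.Module,
                 device: Optional[Union[str, torch.device]] = None):
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        # Move the wrapped model once, at construction time, and set eval mode.
        self.model = model.to(self.device).eval()

    @torch.no_grad()
    def __call__(self, token_ids: Sequence[Sequence[int]]) -> torch.Tensor:
        # Build the token tensor directly on the target device instead of
        # creating it on the CPU and copying it over afterwards.
        tokens = torch.tensor(token_ids, dtype=torch.long, device=self.device)
        return self.model(tokens)


wrapper = DeviceAwareWrapper(
    nn.Sequential(nn.Embedding(8, 4), nn.Linear(4, 8)), device='cpu')
out = wrapper([[1, 2, 3], [4, 5, 6]])
print(out.shape, out.device)  # torch.Size([2, 3, 8]) cpu
```

Passing `device=None` falls back to the GPU when one is available, so the same call works on machines with and without CUDA.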