thoglu/jammy_flows

"N"-type flows don't work on GPUs

chrhck opened this issue · 3 comments

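import torch
import jammy_flows

# inp: evaluation points, cond: conditional inputs of shape (batch, 4); both defined beforehand
pdf = jammy_flows.pdf("e1+s2", "gggg+n", conditional_input_dim=4, hidden_mlp_dims_sub_pdfs="128-128").to("cuda")
inp = inp.to("cuda")
cond = cond.to("cuda")
pdf(inp, conditional_input=cond)  # raises the RuntimeError below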


File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.local/lib/python3.10/site-packages/jammy_flows/flows.py:970, in pdf.forward(self, x, conditional_input, amortization_parameters, force_embedding_coordinates, force_intrinsic_coordinates)
    966     assert(x.shape[0]==conditional_input.shape[0]), "Evaluating input x and condititional input shape must be similar!"
    968 tot_log_det = torch.zeros(x.shape[0]).type_as(x)
--> 970 base_pos, tot_log_det=self.all_layer_inverse(x, tot_log_det, conditional_input, amortization_parameters=amortization_parameters, force_embedding_coordinates=force_embedding_coordinates, force_intrinsic_coordinates=force_intrinsic_coordinates)
    972 log_pdf = torch.distributions.MultivariateNormal(
    973     torch.zeros_like(base_pos).to(x),
    974     covariance_matrix=torch.eye(self.total_base_dim).type_as(x).to(x),
...
    104     #extra_input_counter+=self.num_householder_params
    105 else:
    106     mat_pars=mat_pars.repeat(x.shape[0],1,1)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
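The failing line repeats `mat_pars` across the batch, and the error says one operand is still on the CPU. One common way this happens (an assumption here, not the confirmed cause addressed in 725c73a) is a tensor held as a plain module attribute: `Module.to()` only moves parameters, registered buffers and submodules. The two toy layers below are hypothetical and are not jammy_flows code; `PlainAttribute` reproduces the failure mode on a GPU, while `RegisteredBuffer` avoids it.

```python
import torch
import torch.nn as nn


class PlainAttribute(nn.Module):
    """Hypothetical layer: the matrix is a plain attribute, so .to() leaves it behind."""

    def __init__(self, dim: int):
        super().__init__()
        self.mat_pars = torch.eye(dim)  # not a parameter/buffer: stays on the CPU

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mats = self.mat_pars.repeat(x.shape[0], 1, 1)  # (batch, dim, dim), still on the CPU
        return torch.bmm(mats, x.unsqueeze(-1)).squeeze(-1)  # fails if x is on CUDA


class RegisteredBuffer(nn.Module):
    """Same layer, but the matrix is a registered buffer that .to() moves along."""

    def __init__(self, dim: int):
        super().__init__()
        self.register_buffer("mat_pars", torch.eye(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mats = self.mat_pars.repeat(x.shape[0], 1, 1)  # same device as the module
        return torch.bmm(mats, x.unsqueeze(-1)).squeeze(-1)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(5, 3, device=device)
    print(RegisteredBuffer(3).to(device)(x).device)  # same device as x
    if device == "cuda":
        try:
            PlainAttribute(3).to(device)(x)
        except RuntimeError as err:
            print("PlainAttribute failed:", err)
```

Creating such tensors inside `forward` with `device=x.device`, or with `.to(x)` / `.type_as(x)` as `pdf.forward` already does in the traceback, is the other usual way to keep everything on one device.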

Thanks, will fix this asap. There might be other errors like this in different flows, as I have only worked on the CPU so far ^^ More GPU stress testing is very welcome. Will update the tests as well to catch all of those.
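A test that catches this class of bug only needs to run one forward pass per available device and check that nothing ends up on the wrong device. The helper below is a hypothetical sketch, not the test that was added to jammy_flows; it takes factories for the module and its keyword inputs so it can wrap any flow.

```python
from typing import Callable, Dict, List

import torch
import torch.nn as nn


def available_devices() -> List[str]:
    """CPU always, plus every visible CUDA device."""
    return ["cpu"] + [f"cuda:{i}" for i in range(torch.cuda.device_count())]


def check_forward_on_devices(
    make_module: Callable[[], nn.Module],
    make_inputs: Callable[[str], Dict[str, torch.Tensor]],
) -> Dict[str, str]:
    """Run one forward pass per device; map each device to "ok" or the error text."""
    results = {}
    for device in available_devices():
        module = make_module().to(device)
        inputs = make_inputs(device)  # inputs must already live on `device`
        try:
            out = module(**inputs)
            outs = out if isinstance(out, (tuple, list)) else (out,)
            # Every returned tensor should live on the requested device.
            for t in outs:
                if isinstance(t, torch.Tensor) and t.device != torch.device(device):
                    raise RuntimeError(f"output on {t.device}, expected {device}")
            results[device] = "ok"
        except RuntimeError as err:  # e.g. "Expected all tensors to be on the same device"
            results[device] = str(err)
    return results


if __name__ == "__main__":
    report = check_forward_on_devices(
        lambda: nn.Linear(3, 2),
        lambda d: {"input": torch.randn(4, 3, device=d)},
    )
    print(report)
```

For the report above, `make_module` would build `jammy_flows.pdf("e1+s2", "gggg+n", conditional_input_dim=4, hidden_mlp_dims_sub_pdfs="128-128")` and `make_inputs` would return `{"x": ..., "conditional_input": ...}` created on the given device; the input shapes are left to the caller. On a CPU-only machine the check runs, but it cannot expose a device mismatch.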

This should now be fixed in 725c73a. Can you check if it works now?

works!