"N"-type flows don't work on GPUs
chrhck opened this issue · 3 comments
chrhck commented
```python
pdf = jammy_flows.pdf("e1+s2", "gggg+n", conditional_input_dim=4, hidden_mlp_dims_sub_pdfs="128-128").to("cuda")
inp = inp.to("cuda")
cond = cond.to("cuda")
pdf(inp, conditional_input=cond)
```
```
File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.local/lib/python3.10/site-packages/jammy_flows/flows.py:970, in pdf.forward(self, x, conditional_input, amortization_parameters, force_embedding_coordinates, force_intrinsic_coordinates)
    966 assert(x.shape[0]==conditional_input.shape[0]), "Evaluating input x and condititional input shape must be similar!"
    968 tot_log_det = torch.zeros(x.shape[0]).type_as(x)
--> 970 base_pos, tot_log_det=self.all_layer_inverse(x, tot_log_det, conditional_input, amortization_parameters=amortization_parameters, force_embedding_coordinates=force_embedding_coordinates, force_intrinsic_coordinates=force_intrinsic_coordinates)
    972 log_pdf = torch.distributions.MultivariateNormal(
    973     torch.zeros_like(base_pos).to(x),
    974     covariance_matrix=torch.eye(self.total_base_dim).type_as(x).to(x),
...
    104 #extra_input_counter+=self.num_householder_params
    105 else:
    106     mat_pars=mat_pars.repeat(x.shape[0],1,1)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```
thoglu commented
Thanks, will fix this ASAP. There might be other errors like this in different flows, as I have only worked on the CPU thus far ^^ - more GPU stress testing is very welcome. I will also update the tests to catch all of those.
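For context, this failure mode is usually caused by a tensor that is stored as a plain attribute on a module rather than registered as a parameter or buffer: `Module.to(...)` only moves registered parameters and buffers, so the plain attribute stays on the CPU and later collides with CUDA inputs. The sketch below illustrates the mechanism in plain PyTorch (it is not the actual jammy_flows code; the attribute name `mat_pars` is borrowed from the traceback for illustration). It uses a dtype conversion instead of a CUDA transfer so it can run on a CPU-only machine, but `.to("cuda")` behaves the same way:

```python
import torch


class Plain(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Plain tensor attribute: Module.to() will NOT convert/move it.
        self.mat_pars = torch.eye(3)


class Registered(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Registered buffer: Module.to() converts/moves it with the module.
        self.register_buffer("mat_pars", torch.eye(3))


# Dtype conversion as a CPU-friendly stand-in for .to("cuda"):
plain = Plain().to(torch.float64)
reg = Registered().to(torch.float64)

print(plain.mat_pars.dtype)  # stays torch.float32 -> device-mismatch bug
print(reg.mat_pars.dtype)    # converted to torch.float64 -> follows the module
```

An alternative in-place fix, which the library already uses elsewhere in the traceback, is to coerce constants onto the input's device at call time with `tensor.type_as(x)` or `tensor.to(x)`.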
chrhck commented
Works now, thanks!