thoglu/jammy_flows

NaN during training

chrhck opened this issue · 5 comments

this gives NaN after a few epochs:

pdf = jammy_flows.pdf("e1+s2", "ggg+v", conditional_input_dim=4, hidden_mlp_dims_sub_pdfs="64-128-64")
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In [20], line 20
     17 w = data[:, 3] *data.shape[0]/ sum(data[:, 3])
     18 labels = labels.to(device)
---> 20 log_pdf, _, _ = pdf(inp, conditional_input=labels) 
     21 neg_log_loss = (-log_pdf * w).mean()
     22 neg_log_loss.backward()

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.local/lib/python3.10/site-packages/jammy_flows/flows.py:975, in pdf.forward(self, x, conditional_input, amortization_parameters, force_embedding_coordinates, force_intrinsic_coordinates)
    968 tot_log_det = torch.zeros(x.shape[0]).type_as(x)
    970 base_pos, tot_log_det=self.all_layer_inverse(x, tot_log_det, conditional_input, amortization_parameters=amortization_parameters, force_embedding_coordinates=force_embedding_coordinates, force_intrinsic_coordinates=force_intrinsic_coordinates)
    972 log_pdf = torch.distributions.MultivariateNormal(
    973     torch.zeros_like(base_pos).to(x),
    974     covariance_matrix=torch.eye(self.total_base_dim).type_as(x).to(x),
--> 975 ).log_prob(base_pos)
    978 return log_pdf + tot_log_det, log_pdf, base_pos

File ~/.local/lib/python3.10/site-packages/torch/distributions/multivariate_normal.py:210, in MultivariateNormal.log_prob(self, value)
    208 def log_prob(self, value):
    209     if self._validate_args:
--> 210         self._validate_sample(value)
    211     diff = value - self.loc
    212     M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)

File ~/.local/lib/python3.10/site-packages/torch/distributions/distribution.py:293, in Distribution._validate_sample(self, value)
    291 valid = support.check(value)
    292 if not valid.all():
--> 293     raise ValueError(
    294         "Expected value argument "
    295         f"({type(value).__name__} of shape {tuple(value.shape)}) "
    296         f"to be within the support ({repr(support)}) "
    297         f"of the distribution {repr(self)}, "
    298         f"but found invalid values:\n{value}"
    299     )

ValueError: Expected value argument (Tensor of shape (200, 3)) to be within the support (IndependentConstraint(Real(), 1)) of the distribution MultivariateNormal(loc: torch.Size([200, 3]), covariance_matrix: torch.Size([200, 3, 3])), but found invalid values:
tensor([[    nan,  0.1067, -2.2454],
        [    nan, -0.4479, -1.3993],
        [    nan,  1.1414, -0.2839],
        [    nan,  0.2720, -0.9769],
        [    nan,  0.4975,  0.5888],
        [    nan,  0.3729,  0.7307],
        [    nan, -0.5783, -0.6921],
        [    nan, -0.0498,  1.1616],
        [    nan,  1.1821, -1.6822],
        [    nan,  1.7657,  1.9744],
        [    nan, -1.0785,  1.1321],
....

This typically comes about in Gaussianization flows when gradients are too large and the optimization ends up in regions of parameter space where certain parametrizations no longer work properly.

A typical case is when the conditional_input is not normalized to lie within the sigmoidal region, combined with a batch size that is too small or a learning rate that is too large.

I guess three things could potentially help here:

  1. Reduce the hidden structure complexity (e.g. from 64-128-64 to 256).
  2. Make sure conditional_input is normalized to lie mostly within -1 to 1.
  3. Also call pdf_obj.init_params(data=label_batch) to initialize the Gaussianization flow parameters such that the PDF roughly follows the label distribution (a sketch of 2. and 3. is shown below).
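
A minimal sketch of 2. and 3., using your inp and labels tensors from the snippet above. The min-max rescaling is just one way to get the conditional input into roughly -1 to 1, and I am assuming that the data argument of init_params takes a batch of samples in the space the PDF models (inp here):

import torch
import jammy_flows

pdf = jammy_flows.pdf("e1+s2", "ggg+v", conditional_input_dim=4, hidden_mlp_dims_sub_pdfs="64-128-64")

# 2) rescale the conditional input so it lies mostly within [-1, 1]
lo = labels.min(dim=0, keepdim=True).values
hi = labels.max(dim=0, keepdim=True).values
labels = 2.0 * (labels - lo) / (hi - lo) - 1.0

# 3) initialize the Gaussianization flow parameters with a batch of targets,
#    so the PDF roughly follows that distribution before training starts
pdf.init_params(data=inp)

log_pdf, _, _ = pdf(inp, conditional_input=labels)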

Let me know if some of those helped.

Step 3 isn't clear to me. When should this be called? When I try, I get an error that the pdf has no init attribute.

Sorry, I meant init_params instead of init. It is now corrected in the other comment.

@thoglu I think, I also have the same problem with:
pdf = jammy_flows.pdf("e1", "gg", conditional_input_dim=5, amortization_mlp_dims='64')

1%|▌ | 70/12585 [00:00<01:56, 107.81it/s]

ValueError                                Traceback (most recent call last)
Cell In[30], line 8
      5 optimizer.zero_grad()
      7 # Calculate log PDF
----> 8 log_pdf, _, _ = pdf(target.unsqueeze(1), conditional_input=conditional_input)
     10 # Negative log-likelihood loss
     11 neg_log_loss = -log_pdf.mean()

File /scratch/users/baranh/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File /scratch/users/baranh/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File /scratch/users/baranh/anaconda3/lib/python3.11/site-packages/jammy_flows/main/default.py:1084, in pdf.forward(self, x, conditional_input, amortization_parameters, force_embedding_coordinates, force_intrinsic_coordinates)
   1078 base_pos, tot_log_det=self.all_layer_inverse(x, tot_log_det, conditional_input, amortization_parameters=amortization_parameters, force_embedding_coordinates=force_embedding_coordinates, force_intrinsic_coordinates=force_intrinsic_coordinates)
   1080 ## must faster calculation based on std normal
   1081 other=torch.distributions.Normal(
   1082     0.0,
   1083     1.0,
-> 1084 ).log_prob(base_pos)
   1086 log_pdf=other.sum(dim=-1)
   1088 return log_pdf + tot_log_det, log_pdf, base_pos

File /scratch/users/baranh/anaconda3/lib/python3.11/site-packages/torch/distributions/normal.py:79, in Normal.log_prob(self, value)
     77 def log_prob(self, value):
     78     if self._validate_args:
---> 79         self._validate_sample(value)
     80     # compute the variance
     81     var = self.scale**2

File /scratch/users/baranh/anaconda3/lib/python3.11/site-packages/torch/distributions/distribution.py:312, in Distribution._validate_sample(self, value)
    310 valid = support.check(value)
    311 if not valid.all():
--> 312     raise ValueError(
    313         "Expected value argument "
    314         f"({type(value).__name__} of shape {tuple(value.shape)}) "
    315         f"to be within the support ({repr(support)}) "
    316         f"of the distribution {repr(self)}, "
    317         f"but found invalid values:\n{value}"
    318     )

ValueError: Expected value argument (Tensor of shape (64, 1)) to be within the support (Real()) of the distribution Normal(loc: 0.0, scale: 1.0), but found invalid values:
tensor([[nan],
        [nan],
        [nan],
        [nan],
        ...
        [nan]], grad_fn=<...>)

With pdf = jammy_flows.pdf("e1", "g", conditional_input_dim=5, amortization_mlp_dims='64') I don't face this problem, but when I add layers I get this error. I tried both min-max scaled and standardised conditional_inputs, and still get the error.

@Baran-phys Maybe try the following:

  1. IMPORTANT: Make sure you use double precision (set the model to double via pdf.double() and feed the inputs to the model as double tensors). This is especially important for Gaussianization flows with multiple layers (points 1 and 4 are sketched right after this list).

  2. IMPORTANT: Use init_params, i.e. pdf_obj.init_params(data=label_batch), before the optimization to initialize the pdf in the region of the labels. That should solve the problem of scaling the target dimension (I have used it to describe PDFs over a target space that spans from -100s to +100s without normalizing the labels).

  3. Make sure the target distribution is not degenerate (e.g. a delta peak in target space).

  4. Make sure you have no large outliers in your labels / conditional_input that create very large gradients (or address that with gradient clipping if necessary).

  5. It can be stabilizing to switch off "fit_normalization" in the Gaussianization flow and to restrict the width of the sigmoids (see the configs at the end). I use this in all settings, even though it makes the flow a little less flexible in principle; not fitting the normalization is actually the default in the original Gaussianization flow paper.

  6. It can be very helpful to add an affine layer ("t") at the end. If you do that, it is also wise for > 1-d to change the cov_type parameter. This was used e.g. in https://arxiv.org/pdf/2309.16380 to fit the 1-d energy PDF model.
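
A minimal sketch of points 1 and 4 in a single training step, assuming your target and conditional_input tensors and an already constructed optimizer (max_norm=1.0 is just an example value):

import torch

pdf.double()  # 1) run the whole model in double precision
target = target.double()
conditional_input = conditional_input.double()

optimizer.zero_grad()
log_pdf, _, _ = pdf(target.unsqueeze(1), conditional_input=conditional_input)
neg_log_loss = -log_pdf.mean()
neg_log_loss.backward()

# 4) clip overly large gradients before the optimizer step
torch.nn.utils.clip_grad_norm_(pdf.parameters(), max_norm=1.0)
optimizer.step()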

for 1-d:

opt_dict=dict()

opt_dict["g"]=dict()
opt_dict["g"]["fit_normalization"]=0
opt_dict["g"]["upper_bound_for_widths"]=1.0
opt_dict["g"]["lower_bound_for_widths"]=0.01

pdf=jammy_flows.pdf("e1", "gggt", options_overwrite=opt_dict)

for n-d:

opt_dict=dict()
opt_dict["t"]=dict()
opt_dict["t"]["cov_type"]="full"
opt_dict["g"]=dict()
opt_dict["g"]["fit_normalization"]=0
opt_dict["g"]["upper_bound_for_widths"]=1.0
opt_dict["g"]["lower_bound_for_widths"]=0.01

pdf=jammy_flows.pdf("e3", "gggt", options_overwrite=opt_dict)
The double precision and init_params should hopefully solve your NaN problem, or at least make it occur only once per epoch or so. If that is the case, you can also catch NaNs manually in the training loop and set the gradient for that iteration to 0 (a sketch follows below).
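
A minimal sketch of that kind of guard, using the variable names from your snippet; skipping backward/step for the bad batch is the simplest way to set its gradient contribution to zero:

import torch

optimizer.zero_grad()
log_pdf, _, _ = pdf(target.unsqueeze(1), conditional_input=conditional_input)
neg_log_loss = -log_pdf.mean()

if torch.isfinite(neg_log_loss):
    neg_log_loss.backward()
    optimizer.step()
else:
    # NaN/inf in this batch: drop it instead of stepping on a broken gradient
    optimizer.zero_grad()

If the NaN already raises the ValueError from the distribution's argument validation (as in your traceback), the forward call can be wrapped in a try/except and the batch skipped in the same way.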

I should maybe add a best-practices section to the docs, since these issues come up for pretty much everyone using the tool.