cvignac/DiGress

NaN values when computing test/X_logp, test/E_logp, and test/y_logp with "SumExceptBatchMetric" in abstract_metrics.py

vincenttsai2015 opened this issue · 0 comments

When running the test step, the following warning appears:

~/.local/lib/python3.9/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric SumExceptBatchMetric was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
  warnings.warn(*args, **kwargs)  # noqa: B028

When this warning appears, self.total_value and self.total_samples are both 0 at the end of the test epoch:

SumExceptBatchMetric
self.total_value:  tensor(0., device='cuda:0')
self.total_samples:  tensor(0., device='cuda:0')
SumExceptBatchMetric
self.total_value:  tensor(0., device='cuda:0')
self.total_samples:  tensor(0., device='cuda:0')
SumExceptBatchMetric
self.total_value:  tensor(0., device='cuda:0')
self.total_samples:  tensor(0., device='cuda:0')

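I guess that's why test/X_logp, test/E_logp, and test/y_logp become nan. If compute() returns self.total_value / self.total_samples, a metric that was never updated evaluates 0 / 0, which gives nan for a float tensor.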
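To check my understanding, here is a minimal pure-Python sketch of a sum/count metric. SumExceptBatchSketch is a hypothetical stand-in, not the actual class in abstract_metrics.py, and it assumes compute() returns total_value / total_samples. Plain Python raises ZeroDivisionError on 0.0 / 0.0, whereas a PyTorch float tensor gives nan, so the sketch returns math.nan explicitly to mimic the tensor behaviour.

```python
import math


class SumExceptBatchSketch:
    """Hypothetical pure-Python stand-in for a sum/count metric.

    Assumes the real metric's compute() is total_value / total_samples;
    this is not DiGress's code.
    """

    def __init__(self) -> None:
        self.total_value = 0.0
        self.total_samples = 0

    def update(self, values: list[list[float]]) -> None:
        # One row per graph in the batch: sum every entry,
        # but count only the batch dimension.
        self.total_value += sum(sum(row) for row in values)
        self.total_samples += len(values)

    def compute(self) -> float:
        # tensor(0.) / tensor(0.) is nan in PyTorch; plain Python would
        # raise ZeroDivisionError, so mimic the tensor result here.
        if self.total_samples == 0:
            return math.nan
        return self.total_value / self.total_samples


never_updated = SumExceptBatchSketch()
print(never_updated.compute())  # nan, matching the zero states printed above

updated = SumExceptBatchSketch()
updated.update([[1.0, 2.0], [3.0, 4.0]])  # batch of 2 graphs, total 10.0
print(updated.compute())  # 5.0
```

In this model the batch size only changes how quickly total_samples grows; the nan appears only when update() is never called during the epoch.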

Is this a problem with the data batch size setting, or do I need to modify abstract_metrics.py? Any suggestions would be greatly appreciated!
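If a modification in abstract_metrics.py turns out to be the way to go, one option I am considering is to count update() calls and fail loudly when compute() runs on a metric that was never updated. This is a hypothetical sketch, not the repository's code: GuardedSumExceptBatch and its num_updates counter are my own names, and the sum/count logic assumes compute() is total_value / total_samples. It would not remove the nan by itself, but it would name the metric that the test loop never feeds.

```python
import math


class GuardedSumExceptBatch:
    """Hypothetical sum/count metric that reports an empty epoch.

    Not DiGress's code; assumes compute() = total_value / total_samples.
    """

    def __init__(self, name: str) -> None:
        self.name = name
        self.total_value = 0.0
        self.total_samples = 0
        self.num_updates = 0  # how many times update() ran this epoch

    def update(self, values: list[list[float]]) -> None:
        self.total_value += sum(sum(row) for row in values)
        self.total_samples += len(values)
        self.num_updates += 1

    def compute(self) -> float:
        if self.num_updates == 0:
            # Turn the silent nan into an error that names the metric.
            raise RuntimeError(
                f"{self.name}: compute() was called but update() never ran"
            )
        if self.total_samples == 0:
            # update() ran, but only on empty batches.
            return math.nan
        return self.total_value / self.total_samples


x_logp = GuardedSumExceptBatch("test/X_logp")
try:
    x_logp.compute()
except RuntimeError as err:
    print(err)  # test/X_logp: compute() was called but update() never ran
```

Since the warning says update() was never called, changing the batch size alone seems unlikely to help; batch size would only affect how total_samples grows once update() actually runs.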