ayushkarnawat/profit

Attempting to train pytorch EmbeddedGCN model

Closed this issue · 2 comments

I am trying to run a forward pass through the EmbeddedGCN model using the following script:

from torch.utils.data import DataLoader
from profit.models.pytorch.egcn import EmbeddedGCN
from profit.utils.data_utils.serializers import LMDBSerializer

# Load data
data = LMDBSerializer.load('data/3gb1/processed/egcn_fitness/tertiary3.mdb', as_numpy=False)

# Init model
num_atoms, num_feats = data[0]['arr_0'].shape
model = EmbeddedGCN(num_atoms, num_feats, num_outputs=1, num_layers=1, units_conv=16, units_dense=16)

# Batch dataset
loader = DataLoader(data, batch_size=2)
for batch in loader:
    atoms, adjms, dists, labels = batch.values()
    out = model([atoms, adjms, dists])
    print(out)

Current behavior

Traceback (most recent call last):
  File "examples/3gb1/train.py", line 63, in <module>
    out = model([atoms, adjms, dists])
  File "/Users/ayushkarnawat/miniconda3/envs/chem/lib/python3.7/site-packages/torch/nn/modules/module.py", line 540, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/ayushkarnawat/Documents/dev/python_workspace/profit/profit/models/pytorch/egcn.py", line 1130, in forward
    sc_out = self.relu(self.s_dense(sc))
  File "/Users/ayushkarnawat/miniconda3/envs/chem/lib/python3.7/site-packages/torch/nn/modules/module.py", line 540, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/ayushkarnawat/miniconda3/envs/chem/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 87, in forward
    return F.linear(input, self.weight, self.bias)
  File "/Users/ayushkarnawat/miniconda3/envs/chem/lib/python3.7/site-packages/torch/nn/functional.py", line 1375, in linear
    if input.dim() == 2 and bias is not None:
AttributeError: 'torch.return_types.max' object has no attribute 'dim'

Expected behavior

The model runs a forward pass and outputs a real value.

After some testing to find where NaN values enter the network, it appears that the preprocessed dataset itself contains NaNs. As the data passes through the forward pass, these NaNs propagate and the affected outputs cannot be computed properly (specifically around the ReLU activation shown above), which would explain the error raised in that part of the code.

We still have to determine where exactly the NaN values are introduced into the preprocessed dataset, i.e. at which step in the computation of the mol features they first appear.
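One way to narrow this down is to scan the loaded samples for NaNs before they reach the model. The snippet below is only a sketch: `find_nan_samples` is a hypothetical helper (not part of profit), and it assumes each sample is a dict of arrays keyed like `arr_0`, as in the script above. The toy dataset stands in for the real LMDB data.

```python
import torch


def find_nan_samples(dataset):
    """Return (index, key) pairs for every array in the dataset that contains a NaN.

    Hypothetical helper: assumes each sample is a dict mapping names to
    arrays/tensors, as the samples in the script above appear to be.
    """
    hits = []
    for idx, sample in enumerate(dataset):
        for key, value in sample.items():
            tensor = torch.as_tensor(value)
            # Only floating-point tensors can hold NaN values.
            if tensor.is_floating_point() and torch.isnan(tensor).any():
                hits.append((idx, key))
    return hits


# Toy stand-in for the real dataset; the second sample has a NaN in 'arr_0'.
toy = [
    {"arr_0": torch.ones(3, 4), "arr_1": torch.eye(3)},
    {"arr_0": torch.tensor([[1.0, float("nan")]]), "arr_1": torch.eye(2)},
]
nan_hits = find_nan_samples(toy)
print(nan_hits)
```

Running the same check on the real dataset would show which samples and which feature arrays are affected, which should help point to the preprocessing step responsible.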

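The error comes from the torch.return_types.max object. When torch.max is called with a dim argument, it returns a (values, indices) named tuple rather than a Tensor, so the result cannot be passed directly to a linear layer. To obtain a Tensor, simply call .values on the object.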
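For reference, here is a minimal sketch of that fix in isolation. The tensor shapes and the `dense` layer are made up for illustration and are not the actual EmbeddedGCN code; the point is only that the `.values` field is what should be fed to the layer.

```python
import torch
from torch import nn

# Toy input of shape (batch=1, atoms=2, features=2); illustrative only.
x = torch.tensor([[[1.0, 5.0], [3.0, 2.0]]])

# With a dim argument, torch.max returns a named tuple (values, indices),
# not a Tensor.
reduced = torch.max(x, dim=1)

# Stand-in for a dense layer such as s_dense in the traceback above.
dense = nn.Linear(2, 4)

# Passing `reduced` itself to the layer fails; pass the values tensor instead.
out = dense(reduced.values)
print(reduced.values.shape, out.shape)
```

Tuple unpacking (`values, _ = torch.max(x, dim=1)`) is an equivalent alternative if the indices are not needed.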