Attempting to train PyTorch EmbeddedGCN model
I am trying to run a forward pass through the EmbeddedGCN model using:

```python
from torch.utils.data import DataLoader
from profit.models.pytorch.egcn import EmbeddedGCN
from profit.utils.data_utils.serializers import LMDBSerializer

# Load data
data = LMDBSerializer.load('data/3gb1/processed/egcn_fitness/tertiary3.mdb', as_numpy=False)

# Init model
num_atoms, num_feats = data[0]['arr_0'].shape
model = EmbeddedGCN(num_atoms, num_feats, num_outputs=1, num_layers=1, units_conv=16, units_dense=16)

# Batch dataset
loader = DataLoader(data, batch_size=2)
for batch in loader:
    atoms, adjms, dists, labels = batch.values()
    out = model([atoms, adjms, dists])
    print(out)
```
Current behavior
```
Traceback (most recent call last):
  File "examples/3gb1/train.py", line 63, in <module>
    out = model([atoms, adjms, dists])
  File "/Users/ayushkarnawat/miniconda3/envs/chem/lib/python3.7/site-packages/torch/nn/modules/module.py", line 540, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/ayushkarnawat/Documents/dev/python_workspace/profit/profit/models/pytorch/egcn.py", line 1130, in forward
    sc_out = self.relu(self.s_dense(sc))
  File "/Users/ayushkarnawat/miniconda3/envs/chem/lib/python3.7/site-packages/torch/nn/modules/module.py", line 540, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/ayushkarnawat/miniconda3/envs/chem/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 87, in forward
    return F.linear(input, self.weight, self.bias)
  File "/Users/ayushkarnawat/miniconda3/envs/chem/lib/python3.7/site-packages/torch/nn/functional.py", line 1375, in linear
    if input.dim() == 2 and bias is not None:
AttributeError: 'torch.return_types.max' object has no attribute 'dim'
```
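The same failure can be reproduced in isolation, without the dataset or the EmbeddedGCN code. This is a minimal sketch, assuming (from the traceback) that `sc` is the result of a `torch.max(..., dim=...)` call; the tensor shape is illustrative only. The exact exception depends on the PyTorch version: older releases raise the `AttributeError` shown above from the Python-level `F.linear`, while newer releases raise a `TypeError` because `F.linear` rejects non-Tensor input.

```python
import torch
import torch.nn as nn

x = torch.randn(2, 5, 16)      # (batch, num_atoms, features), illustrative shape
dense = nn.Linear(16, 16)

pooled = torch.max(x, dim=1)   # with `dim`, returns a (values, indices) named tuple
print(type(pooled))

raised = None
try:
    dense(pooled)              # passes the named tuple instead of a Tensor
except (AttributeError, TypeError) as err:
    raised = err
    print(type(err).__name__, err)
```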
Expected behavior
The model runs a forward pass and outputs a real value.
After some testing to determine where NaN values enter the network, it appears that the preprocessed dataset itself contains NaNs. As the data propagates through the network in the forward pass, the NaN values cannot be computed properly (specifically by the ReLU activation shown above), which is why the error is raised in that part of the code.
We still need to determine where exactly the NaN values are introduced into the preprocessed dataset, i.e., at which step of computing the molecular features they appear.
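A first step toward locating the NaNs is to scan every sample and record which entries contain them. The helper below is a hypothetical sketch, not part of profit; it assumes the loaded dataset supports `len()` and integer indexing, and that each sample is a dict mapping keys such as `'arr_0'` to tensors or array-likes, as the reproduction code suggests.

```python
import torch


def find_nan_samples(dataset):
    """Hypothetical helper: return (index, key) pairs whose values contain NaN.

    Assumes `dataset[i]` is a dict of tensors or array-likes.
    Integer-valued entries (e.g. labels stored as ints) cannot hold NaN and are skipped.
    """
    bad = []
    for idx in range(len(dataset)):
        sample = dataset[idx]
        for key, value in sample.items():
            tensor = torch.as_tensor(value)
            if tensor.is_floating_point() and torch.isnan(tensor).any():
                bad.append((idx, key))
    return bad


# Synthetic stand-in for the LMDB dataset (illustrative values only).
toy = [
    {"arr_0": torch.zeros(3, 4), "arr_1": torch.ones(3, 3)},
    {"arr_0": torch.tensor([[1.0, float("nan")]]), "arr_1": torch.ones(1, 1)},
]
print(find_nan_samples(toy))  # [(1, 'arr_0')]
```

Running it on the real data (`find_nan_samples(data)`) would narrow the search to specific samples and feature arrays, which can then be traced back to the featurization step that produced them.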
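The issue is related to the `torch.return_types.max` object: when `torch.max` is called with a `dim` argument, it returns a named tuple of `(values, indices)` rather than a Tensor, so it cannot be passed directly to `nn.Linear`. To obtain a Tensor, simply call `.values` on the object.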
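For illustration, here is a sketch of a max-pooling readout followed by a dense layer with the fix applied. It is not the actual `egcn.py` code: the class name `MaxPoolReadout` is hypothetical, and the attribute names `s_dense` and `relu` are borrowed from the traceback line only to make the correspondence visible.

```python
import torch
import torch.nn as nn


class MaxPoolReadout(nn.Module):
    """Hypothetical readout: max over atoms, then Linear + ReLU."""

    def __init__(self, num_feats: int, units_dense: int):
        super().__init__()
        self.s_dense = nn.Linear(num_feats, units_dense)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, num_atoms, num_feats)
        # torch.max(..., dim=1) returns (values, indices); keep only the
        # values Tensor of shape (batch, num_feats).
        sc = torch.max(x, dim=1).values
        return self.relu(self.s_dense(sc))


readout = MaxPoolReadout(num_feats=16, units_dense=8)
out = readout(torch.randn(2, 5, 16))
print(out.shape)  # torch.Size([2, 8])
```

Indexing the result with `[0]` is equivalent to `.values`; newer PyTorch releases also provide `torch.amax`, which returns only the values Tensor.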