DeepGraphLearning/torchdrug

TypeError during validation

Joey-Xue opened this issue

I am using TorchProtein on a self-created dataset that contains protein structures and residue-level features, for a binary classification task.

import torch
from torchdrug import core, tasks

# toydataset (my own dataset), gearnet and graph_construction_model are defined earlier (not shown)
train_set, valid_set, test_set = toydataset.split()

task = tasks.MultipleBinaryClassification(gearnet, graph_construction_model=graph_construction_model,
                                          num_mlp_layer=3, task=[0], criterion='bce',
                                          metric=['auprc@macro', 'auprc@micro', 'f1_max'])

optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     gpus=[0], batch_size=16)
solver.train(10)
solver.evaluate('valid')
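The following dependency-free sketch shows what the auprc@macro and auprc@micro entries in the metric list compute. It rests on three assumptions. First, AUPRC is taken as average precision: the mean of the precision at the rank of each positive. Second, macro averages that score over task columns. Third, micro pools every (sample, task) pair into one ranking. This is my reading of the calls visible in the traceback below, not torchdrug's implementation. The function names are hypothetical.

```python
def average_precision(scores, labels):
    """AUPRC as average precision: mean precision at the rank of each positive."""
    # Rank predictions from highest to lowest score.
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    hits = 0
    precision_sum = 0.0
    for rank, (_, label) in enumerate(ranked, start=1):
        if label == 1:
            hits += 1
            precision_sum += hits / rank  # precision at this positive's rank
    return precision_sum / hits if hits else 0.0


def auprc_macro(pred, target):
    """Mean of per-task AUPRC; pred/target are lists of rows (samples x tasks)."""
    num_tasks = len(pred[0])
    per_task = [
        average_precision([row[t] for row in pred], [row[t] for row in target])
        for t in range(num_tasks)
    ]
    return sum(per_task) / num_tasks


def auprc_micro(pred, target):
    """AUPRC over all (sample, task) pairs pooled into one ranking."""
    flat_pred = [p for row in pred for p in row]
    flat_target = [y for row in target for y in row]
    return average_precision(flat_pred, flat_target)


# One task column, as with task=[0]: macro and micro give the same value.
pred = [[0.9], [0.2], [0.7], [0.4]]
target = [[1], [0], [0], [1]]
print(round(auprc_macro(pred, target), 4))  # 0.8333
print(round(auprc_micro(pred, target), 4))  # 0.8333
```

With a single task, as here, the two averages coincide. They diverge only when several task columns are pooled. For example, pred = [[0.9, 0.2], [0.5, 0.4]] with target = [[1, 0], [0, 1]] gives a macro score of 1.0 and a micro score of about 0.8333.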

Training ran without problems, but the evaluation step raised an error:

03:57:22   average binary cross entropy: 3.63139
03:57:22   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
03:57:22   Evaluate on valid
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[9], line 7
      4 solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
      5                      gpus=[0], batch_size= 16)
      6 solver.train(10)
----> 7 solver.evaluate('valid')

File ~/miniconda3/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/miniconda3/lib/python3.10/site-packages/torchdrug/core/engine.py:222, in Engine.evaluate(self, split, log)
    220     pred = comm.cat(pred)
    221     target = comm.cat(target)
--> 222 metric = model.evaluate(pred, target)
    223 if log:
    224     self.meter.log(metric, category="%s/epoch" % split)

File ~/miniconda3/lib/python3.10/site-packages/torchdrug/tasks/property_prediction.py:324, in MultipleBinaryClassification.evaluate(self, pred, target)
    322     score = metrics.area_under_prc(pred.flatten(), target.long().flatten())
    323 elif _metric == "auprc@macro":
--> 324     score = metrics.variadic_area_under_prc(pred, target.long(), dim=0).mean()
    325 elif _metric == "f1_max":
    326     score = metrics.f1_max(pred, target)

TypeError: variadic_area_under_prc() got an unexpected keyword argument 'dim'
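The error message says that the installed metrics.variadic_area_under_prc does not accept a dim keyword, while property_prediction.py passes one. A small helper built on Python's standard inspect module can check this kind of signature mismatch. This is a generic sketch. The names accepts_keyword, metric_without_dim and metric_with_dim are hypothetical stand-ins for illustration and are not part of torchdrug.

```python
import inspect


def accepts_keyword(func, name):
    """Return True if func can be called with the keyword argument `name`."""
    params = inspect.signature(func).parameters
    # A named parameter that may be passed by keyword.
    if name in params and params[name].kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    ):
        return True
    # A **kwargs catch-all also accepts arbitrary keywords.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


# Hypothetical stand-ins for two versions of a metric function.
def metric_without_dim(pred, target):
    return 0.0


def metric_with_dim(pred, target, dim=0):
    return 0.0


print(accepts_keyword(metric_without_dim, "dim"))  # False
print(accepts_keyword(metric_with_dim, "dim"))     # True
```

To apply this to a real install, run `from torchdrug import metrics` and then call accepts_keyword(metrics.variadic_area_under_prc, "dim"). Calling importlib.metadata.version("torchdrug") alongside it should show which signature and which release are actually installed.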

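Is something wrong with the metric definition? The failure comes from the auprc@macro branch of MultipleBinaryClassification.evaluate, which calls metrics.variadic_area_under_prc(pred, target.long(), dim=0).
Thanks in advance for any help!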