TypeError during validation
Joey-Xue opened this issue · 0 comments
Joey-Xue commented
I was using TorchProtein on a self-created dataset with protein structures and residue-level features for a binary classification task.
train_set, valid_set, test_set = toydataset.split()

from torchdrug import tasks
task = tasks.MultipleBinaryClassification(
    gearnet, graph_construction_model=graph_construction_model, num_mlp_layer=3,
    task=[0], criterion="bce", metric=["auprc@macro", "auprc@micro", "f1_max"])

import torch
from torchdrug import core
optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     gpus=[0], batch_size=16)
solver.train(10)
solver.evaluate("valid")
While the training went well, I got an error in the evaluate step:
03:57:22 average binary cross entropy: 3.63139
03:57:22 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
03:57:22 Evaluate on valid
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[9], line 7
4 solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
5 gpus=[0], batch_size= 16)
6 solver.train(10)
----> 7 solver.evaluate('valid')
File ~/miniconda3/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator..decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
File ~/miniconda3/lib/python3.10/site-packages/torchdrug/core/engine.py:222, in Engine.evaluate(self, split, log)
220 pred = comm.cat(pred)
221 target = comm.cat(target)
--> 222 metric = model.evaluate(pred, target)
223 if log:
224 self.meter.log(metric, category="%s/epoch" % split)
File ~/miniconda3/lib/python3.10/site-packages/torchdrug/tasks/property_prediction.py:324, in MultipleBinaryClassification.evaluate(self, pred, target)
322 score = metrics.area_under_prc(pred.flatten(), target.long().flatten())
323 elif _metric == "auprc@macro":
--> 324 score = metrics.variadic_area_under_prc(pred, target.long(), dim=0).mean()
325 elif _metric == "f1_max":
326 score = metrics.f1_max(pred, target)
TypeError: variadic_area_under_prc() got an unexpected keyword argument 'dim'
It seems like something is wrong with the metric definition?
Thanks in advance for any help!
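The TypeError itself shows that the `variadic_area_under_prc` actually imported at runtime does not accept a `dim` keyword, even though `property_prediction.py` passes one; that usually points to a mismatch between the installed task code and metrics code. A generic way to confirm which keywords the installed function accepts is to inspect its signature. Below is a minimal, self-contained sketch; `supports_kwarg` and the stand-in `variadic_area_under_prc` are hypothetical illustrations, not part of torchdrug:

```python
import inspect

def supports_kwarg(func, name):
    # True if `func` accepts a keyword argument called `name`,
    # either explicitly or via a **kwargs catch-all.
    params = inspect.signature(func).parameters
    return name in params or any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())

# Stand-in mirroring an older metrics signature without `dim`;
# with the real library you would import the function from torchdrug.metrics.
def variadic_area_under_prc(pred, target, size):
    pass

print(supports_kwarg(variadic_area_under_prc, "dim"))   # False -> the dim=0 call raises TypeError
print(supports_kwarg(variadic_area_under_prc, "size"))  # True
```

If the check returns False for `dim` in your environment, the task and metrics modules likely come from different torchdrug versions.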