kjunelee/MetaOptNet

TypeError: btrisolve() takes 3 positional arguments but 4 were given

robotzheng opened this issue · 1 comment

Loading mini ImageNet dataset - phase train
Loading mini ImageNet dataset - phase val
using gpu: 0,1,2,3
{'num_epoch': 60, 'save_epoch': 10, 'train_shot': 5, 'val_shot': 5, 'train_query': 6, 'val_episode': 2000, 'val_query': 15, 'train_way': 5, 'test_way': 5, 'save_path': './experiments/miniImageNet_MetaOptNet_SVM', 'gpu': '0,1,2,3', 'network': 'ResNet', 'head': 'SVM', 'dataset': 'miniImageNet', 'episodes_per_batch': 8, 'eps': 0.1}
Train Epoch: 1 Learning Rate: 0.1000
0%| | 0/1000 [00:00<?, ?it/s]
/usr/local/python3/lib/python3.6/site-packages/qpth/solvers/pdipm/batch.py:14: UserWarning: torch.btrifact is deprecated in favour of torch.lu and will be removed in the next release. Please use torch.lu instead.
  return x.btrifact(pivot=not x.is_cuda)
0%| | 0/1000 [00:08<?, ?it/s]
Traceback (most recent call last):
  File "train.py", line 207, in <module>
    logit_query = cls_head(emb_query, emb_support, labels_support, opt.train_way, opt.train_shot)
  File "/usr/local/python3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/zzt/MetaOptNet/models/classification_heads.py", line 550, in forward
    return self.scale * self.head(query, support, support_labels, n_way, n_shot, **kwargs)
  File "/home/zzt/MetaOptNet/models/classification_heads.py", line 396, in MetaOptNetHead_SVM_CS
    qp_sol = QPFunction(verbose=False, maxIter=maxIter)(G, e.detach(), C.detach(), h.detach(), A.detach(), b.detach())
  File "/usr/local/python3/lib/python3.6/site-packages/qpth/qp.py", line 91, in forward
    self.Q_LU, self.S_LU, self.R = pdipm_b.pre_factor_kkt(Q, G, A)
  File "/usr/local/python3/lib/python3.6/site-packages/qpth/solvers/pdipm/batch.py", line 401, in pre_factor_kkt
    G_invQ_GT = torch.bmm(G, G.transpose(1, 2).btrisolve(*Q_LU))
TypeError: btrisolve() takes 3 positional arguments but 4 were given
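The error message itself says how many arguments arrived. A bound method called as `t.btrisolve(*Q_LU)` receives `t` as `self`, so "4 were given" means `Q_LU` unpacked to three items, while `btrisolve` only accepts `LU_data` and `LU_pivots` after `self`. What that third item is (plausibly an extra info tensor from the deprecated `btrifact` shim) is an assumption, not something the log confirms. The plain-Python sketch below reproduces the same arity mismatch with a hypothetical `FakeTensor` class; it does not touch PyTorch or qpth.

```python
class FakeTensor:
    """Hypothetical stand-in for a tensor; only the method signature matters here."""

    def btrisolve(self, LU_data, LU_pivots):
        # Same arity as the method in the traceback: self + 2 arguments = 3 positional.
        return (LU_data, LU_pivots)


def call_with_factorization(factorization):
    """Unpack a factorization tuple into btrisolve, like G.transpose(1, 2).btrisolve(*Q_LU)."""
    b = FakeTensor()
    try:
        b.btrisolve(*factorization)
        return "ok"
    except TypeError as exc:
        return str(exc)


print(call_with_factorization(("LU", "pivots")))           # -> ok
print(call_with_factorization(("LU", "pivots", "infos")))  # TypeError text ending in "...takes 3 positional arguments but 4 were given"
```

A two-item tuple goes through, while a three-item tuple fails with the same message as the traceback, which points to a mismatch between what the installed PyTorch returns and what this qpth version expects.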

This is the same issue as #10. Upgrading to the latest versions of PyTorch and qpth should resolve it. Let me know if that works.
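For readers who want to see what the failing line computes, here is a minimal sketch of the batched product G Q^{-1} G^T from `pre_factor_kkt`, written against `torch.linalg.lu_factor` and `torch.linalg.lu_solve`, the newer successors of `btrifact`/`btrisolve`. It assumes a recent PyTorch release that provides both functions; it is an illustration, not qpth's actual patch, and the helper name `g_invq_gt` is hypothetical.

```python
import torch


def g_invq_gt(G: torch.Tensor, Q: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: batched G @ Q^{-1} @ G^T without forming Q^{-1} explicitly.

    G: (batch, nineq, nz), Q: (batch, nz, nz).
    """
    # Factor every Q in the batch once (successor of the deprecated btrifact).
    LU, pivots = torch.linalg.lu_factor(Q)
    # Solve Q X = G^T, giving X = Q^{-1} G^T (successor of the deprecated btrisolve).
    invQ_GT = torch.linalg.lu_solve(LU, pivots, G.transpose(1, 2))
    return torch.bmm(G, invQ_GT)


if __name__ == "__main__":
    torch.manual_seed(0)
    batch, nz, nineq = 8, 5, 3
    A = torch.randn(batch, nz, nz, dtype=torch.float64)
    # A A^T + nz*I is symmetric positive definite, so every Q is invertible.
    Q = A @ A.transpose(1, 2) + nz * torch.eye(nz, dtype=torch.float64)
    G = torch.randn(batch, nineq, nz, dtype=torch.float64)
    out = g_invq_gt(G, Q)
    ref = G @ torch.linalg.inv(Q) @ G.transpose(1, 2)
    print(out.shape, torch.allclose(out, ref))
```

Keeping `LU` and `pivots` around matches the pre-factoring pattern implied by `pre_factor_kkt` storing `self.Q_LU`; when the factorization is not reused, `torch.linalg.solve(Q, G.transpose(1, 2))` is a simpler one-shot alternative.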