Bug with Beta darts regularization on Vanilla search
Closed this issue · 3 comments
Hi,
The repository does not work when I run the vanilla search:
cd sota/cnn && python train_search.py
Traceback (most recent call last):
  File "train_search.py", line 279, in <module>
    main()
  File "train_search.py", line 180, in main
    train_acc, train_obj = train(train_queue, valid_queue, model, architect, criterion, optimizer, lr,
  File "train_search.py", line 213, in train
    architect.step(input, target, input_search, target_search, lr, optimizer, unrolled=args.unrolled, epoch=args.epochs)
  File "/home/prabhant/Beta-DARTS/sota/cnn/../../optimizers/darts/architect.py", line 45, in step
    self._backward_step(input_valid, target_valid, epoch)
  File "/home/prabhant/Beta-DARTS/sota/cnn/../../optimizers/darts/architect.py", line 82, in _backward_step
    ssr_normal = self.mlc_loss(self.model._arch_parameters)
  File "/home/prabhant/Beta-DARTS/sota/cnn/../../optimizers/darts/architect.py", line 57, in mlc_loss
    neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
TypeError: logsumexp() received an invalid combination of arguments - got (list, dim=int), but expected one of:
 * (Tensor input, tuple of ints dim, bool keepdim, *, Tensor out)
 * (Tensor input, tuple of names dim, bool keepdim, *, Tensor out)
Hi,
Thanks for your question.
It seems that you are searching on the DARTS search space, where self.model._arch_parameters is a list, not a Tensor.
The example command we give, "python ./nasbench201/train_search.py", searches on the nasbench201 search space, where self.model._arch_parameters is a Tensor.
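The difference can be reproduced in isolation. The following is a minimal sketch, not code from the repository: the tensor names mirror alphas_normal and alphas_reduce, and the 14 x 8 shape is an assumption chosen only for illustration.

```python
import torch

# Hypothetical architecture parameters; the (14, 8) shape is an
# illustrative assumption, not taken from the repository.
alphas_normal = torch.randn(14, 8)
alphas_reduce = torch.randn(14, 8)

# A single Tensor is accepted: one value per row.
per_edge = torch.logsumexp(alphas_normal, dim=-1)

# A Python list of Tensors is rejected with TypeError,
# which is the error shown in the traceback.
arch_parameters = [alphas_normal, alphas_reduce]
try:
    torch.logsumexp(arch_parameters, dim=-1)
    list_accepted = True
except TypeError:
    list_accepted = False

print(per_edge.shape, list_accepted)
```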
If you want to search on the DARTS search space, just use
ssr_reduce = self.mlc_loss(self.model.alphas_reduce)
ssr_normal = self.mlc_loss(self.model.alphas_normal)
loss = self.model._loss(input_valid, target_valid) + weights * ssr_reduce + weights * ssr_normal
and choose a suitable value for weights.
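As a rough, self-contained sketch of how these terms combine: only the torch.logsumexp(..., dim=-1) call is confirmed by the traceback, so the mean reduction inside mlc_loss, the scalar weights value, the tensor shapes, and the placeholder validation loss below are all assumptions for illustration.

```python
import torch

def mlc_loss(arch_param: torch.Tensor) -> torch.Tensor:
    # logsumexp over the last dimension, as in the traceback.
    # Averaging over rows is an assumption about the rest of the function.
    return torch.logsumexp(arch_param, dim=-1).mean()

# Hypothetical stand-ins for self.model.alphas_normal / alphas_reduce.
alphas_normal = torch.zeros(14, 8, requires_grad=True)
alphas_reduce = torch.zeros(14, 8, requires_grad=True)

# Placeholder for self.model._loss(input_valid, target_valid).
val_loss = torch.tensor(1.0)

weights = 0.5  # assumed scalar regularization weight; tune for your setup

# Regularize each architecture tensor separately, then add both terms.
ssr_reduce = mlc_loss(alphas_reduce)
ssr_normal = mlc_loss(alphas_normal)
loss = val_loss + weights * ssr_reduce + weights * ssr_normal
loss.backward()  # gradients flow into both architecture tensors

print(loss.item())
```

Calling mlc_loss once per tensor avoids passing the whole list to torch.logsumexp, which is what raised the TypeError.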
I am a little confused: where exactly should I add these lines of code?
I want to search on the DARTS search space here, not NAS-Bench-201.
Start from the original DARTS code. More specifically, replace line 40 of https://github.com/quark0/darts/blob/master/cnn/architect.py with the lines of code above.