Sunshine-Ye/Beta-DARTS

Bug with Beta-DARTS regularization on vanilla search

Closed this issue · 3 comments

Hi,

The repository does not work when I run a vanilla search:
cd sota/cnn && python train_search.py

Traceback (most recent call last):
  File "train_search.py", line 279, in <module>
    main() 
  File "train_search.py", line 180, in main
    train_acc, train_obj = train(train_queue, valid_queue, model, architect, criterion, optimizer, lr, 
  File "train_search.py", line 213, in train
    architect.step(input, target, input_search, target_search, lr, optimizer, unrolled=args.unrolled, epoch=args.epochs)
  File "/home/prabhant/Beta-DARTS/sota/cnn/../../optimizers/darts/architect.py", line 45, in step
    self._backward_step(input_valid, target_valid, epoch)
  File "/home/prabhant/Beta-DARTS/sota/cnn/../../optimizers/darts/architect.py", line 82, in _backward_step
    ssr_normal = self.mlc_loss(self.model._arch_parameters)
  File "/home/prabhant/Beta-DARTS/sota/cnn/../../optimizers/darts/architect.py", line 57, in mlc_loss
    neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
TypeError: logsumexp() received an invalid combination of arguments - got (list, dim=int), but expected one of:
 * (Tensor input, tuple of ints dim, bool keepdim, *, Tensor out)
 * (Tensor input, tuple of names dim, bool keepdim, *, Tensor out)

Hi,
Thanks for your question.
It seems that you are searching on the DARTS search space, where self.model._arch_parameters is a list rather than a Tensor.
The example code we provide, "python ./nasbench201/train_search.py", searches on the NAS-Bench-201 search space, where self.model._arch_parameters is a Tensor.
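To see why the call fails, the sketch below checks torch.logsumexp against both layouts. It assumes, as the reply states, that the DARTS search space stores _arch_parameters as a Python list of two tensors (normal-cell and reduction-cell alphas). The 14 x 8 shapes and the helper name logsumexp_accepts are illustrative assumptions, not code from the repository.

```python
import torch


def logsumexp_accepts(arch_parameters) -> bool:
    """Hypothetical helper: True if torch.logsumexp accepts this input."""
    try:
        torch.logsumexp(arch_parameters, dim=-1)
        return True
    except TypeError:
        # Same failure as in the traceback: a list is not a Tensor.
        return False


# Illustrative stand-ins: 14 edges x 8 candidate operations per cell type.
alphas_normal = 1e-3 * torch.randn(14, 8)
alphas_reduce = 1e-3 * torch.randn(14, 8)

# DARTS-style layout: a Python list of two tensors.
print(logsumexp_accepts([alphas_normal, alphas_reduce]))  # False
# A single tensor (as in the NAS-Bench-201 layout) is accepted.
print(logsumexp_accepts(alphas_normal))  # True
```

This is why the fix below passes alphas_normal and alphas_reduce to mlc_loss one at a time instead of passing the whole list.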
If you want to search on the DARTS search space, use
ssr_reduce = self.mlc_loss(self.model.alphas_reduce)
ssr_normal = self.mlc_loss(self.model.alphas_normal)
loss = self.model._loss(input_valid, target_valid) + weight * ssr_reduce + weight * ssr_normal
and tune the regularization weight (weight) to a suitable value.
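For concreteness, here is a minimal sketch of the weighted objective above. It assumes mlc_loss takes logsumexp over the candidate operations of each edge (the operation visible in the traceback) and then averages over edges; that averaging step, the helper name regularized_loss, and the single shared weight are assumptions rather than the repository's confirmed code.

```python
import torch


def mlc_loss(arch_param: torch.Tensor) -> torch.Tensor:
    # Assumed form: smooth maximum (logsumexp) over the candidate operations
    # of each edge, averaged over edges. Expects a Tensor, not a list.
    return torch.logsumexp(arch_param, dim=-1).mean()


def regularized_loss(task_loss: torch.Tensor,
                     alphas_normal: torch.Tensor,
                     alphas_reduce: torch.Tensor,
                     weight: float) -> torch.Tensor:
    # Hypothetical helper: regularize the normal-cell and reduction-cell
    # alphas separately, then add both terms to the validation loss.
    ssr_reduce = mlc_loss(alphas_reduce)
    ssr_normal = mlc_loss(alphas_normal)
    return task_loss + weight * ssr_reduce + weight * ssr_normal


# Uniform alphas over 8 operations give logsumexp = log(8) on every edge,
# so the total is 1 + 0.5 * log(8) + 0.5 * log(8) = 1 + log(8).
alphas = torch.zeros(14, 8, requires_grad=True)
loss = regularized_loss(torch.tensor(1.0), alphas, alphas, weight=0.5)
print(round(loss.item(), 4))  # 3.0794
```

Because both terms are differentiable in the alphas, calling loss.backward() sends the regularizer's gradient into the architecture parameters alongside the task-loss gradient.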

I am a little confused: where exactly should I add these lines of code?
I want to search on the DARTS search space, not NAS-Bench-201.

Start from the original DARTS code; specifically, replace line 40 of https://github.com/quark0/darts/blob/master/cnn/architect.py with the lines of code above.
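A sketch of what the edited step might look like, assuming line 40 is the validation-loss line inside Architect._backward_step (loss = self.model._loss(input_valid, target_valid)). ToyNetwork is a hypothetical stand-in exposing only alphas_normal, alphas_reduce and _loss; the weight attribute, the mlc_loss body, and returning the loss are all illustrative assumptions.

```python
import torch
import torch.nn.functional as F


class ToyNetwork:
    """Hypothetical stand-in for the DARTS search Network."""

    def __init__(self, num_edges: int = 14, num_ops: int = 8):
        self.alphas_normal = (1e-3 * torch.randn(num_edges, num_ops)).requires_grad_()
        self.alphas_reduce = (1e-3 * torch.randn(num_edges, num_ops)).requires_grad_()

    def _loss(self, input, target):
        # Toy logits that depend on both alpha tensors so gradients flow.
        weights = self.alphas_normal.mean(0) + self.alphas_reduce.mean(0)
        return F.cross_entropy(input * weights, target)


class Architect:
    """Sketch of the edited DARTS Architect; only _backward_step is shown."""

    def __init__(self, model, weight: float):
        self.model = model
        self.weight = weight  # hypothetical name for the regularization weight

    def mlc_loss(self, arch_param: torch.Tensor) -> torch.Tensor:
        # Assumed regularizer form: mean over edges of logsumexp over ops.
        return torch.logsumexp(arch_param, dim=-1).mean()

    def _backward_step(self, input_valid, target_valid):
        # Replaces the plain validation loss: each alpha tensor is passed to
        # mlc_loss separately, never the list of both.
        ssr_reduce = self.mlc_loss(self.model.alphas_reduce)
        ssr_normal = self.mlc_loss(self.model.alphas_normal)
        loss = (self.model._loss(input_valid, target_valid)
                + self.weight * ssr_reduce + self.weight * ssr_normal)
        loss.backward()
        return loss  # returned here only for inspection


torch.manual_seed(0)
model = ToyNetwork()
architect = Architect(model, weight=0.1)
x = torch.randn(4, 8)
y = torch.randint(0, 8, (4,))
loss = architect._backward_step(x, y)
print(model.alphas_normal.grad.shape)  # torch.Size([14, 8])
```

The rest of the DARTS Architect (the optimizer step and the unrolled variant) would stay unchanged in this sketch.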