carpedm20/ENAS-pytorch

Controller.encoder seems much too large

philtomson opened this issue · 2 comments

From the Controller constructor:

class Controller(torch.nn.Module):
    def __init__(self, args):
        torch.nn.Module.__init__(self)
        self.args = args
        self.forward_evals = 0
        if self.args.network_type == 'rnn':
            # NOTE(brendan): `num_tokens` here is just the activation function
            # for every even step,
            self.num_tokens = [len(args.shared_rnn_activations)]
            for idx in range(self.args.num_blocks):
                self.num_tokens += [idx + 1,
                                    len(args.shared_rnn_activations)]
            self.func_names = args.shared_rnn_activations
        elif self.args.network_type == 'cnn':
            self.num_tokens = [len(args.shared_cnn_types),
                               self.args.num_blocks]
            self.func_names = args.shared_cnn_types

        num_total_tokens = sum(self.num_tokens)  # why sum the tokens here?
        # Shouldn't this be:
        # num_total_tokens = len(args.shared_rnn_activations) + self.args.num_blocks
        self.encoder = torch.nn.Embedding(num_total_tokens,
                                          args.controller_hid)

It seems like num_total_tokens doesn't need to be the summation of self.num_tokens: in the case where self.args.num_blocks = 6, that sum is 49. Yet from what I can tell, the largest index the embedding is ever given in Controller.forward() is bounded above by len(args.shared_rnn_activations) + self.args.num_blocks (in this case, 10).
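To make the arithmetic concrete, here is a minimal sketch of the construction above, assuming num_blocks = 6 and four shared RNN activations (which matches the 0-to-3 range discussed below):

# Sketch of the num_tokens construction, under the assumptions stated above.
num_blocks = 6
num_activations = 4  # assumed len(args.shared_rnn_activations)

num_tokens = [num_activations]
for idx in range(num_blocks):
    num_tokens += [idx + 1, num_activations]

print(num_tokens)                    # [4, 1, 4, 2, 4, 3, 4, 4, 4, 5, 4, 6, 4]
print(sum(num_tokens))               # 49 -> rows allocated in the Embedding
print(num_activations + num_blocks)  # 10 -> rows that seem actually reachable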

You can treat the same activation function in a different place as having the same semantics (i.e., the same embedding) or not. I assumed they are different, because an activation in a different location may play a separate role.
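If that is the intent, a hypothetical per-position lookup would offset each decision step by the sizes of all earlier decisions, so every (step, token) pair maps to its own row and all 49 rows get used. Note this is not what the current code does (it offsets by sum(self.num_tokens[:mode]) with mode in {0, 1}), and embedding_index below is not a function in the repo; it only illustrates that reading:

# Hypothetical per-position lookup, shown only to illustrate the reading above.
def embedding_index(step, token, num_tokens):
    # step 0 -> rows 0..3, step 1 -> row 4, step 2 -> rows 5..8, ...,
    # last step -> rows 45..48, covering all sum(num_tokens) == 49 rows.
    return sum(num_tokens[:step]) + token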

I've been stepping through this section of the code in the debugger trying to understand what's going on. When mode is 0 (the case where an activation function is being selected), sum(self.num_tokens[:mode]) is 0. So the line:

inputs = utils.get_variable(
    action[:, 0] + sum(self.num_tokens[:mode]),
    requires_grad=False)

Is always just the action[:,0] component which is a value from 0 to 3 (the size of the activation function list)

And when mode is 1, sum(self.num_tokens[:mode]) is always 4, so I'm not sure how you can get anything higher than len(args.shared_rnn_activations) + self.args.num_blocks here; mode can only take on the values 0 and 1. Either I'm missing something, or maybe it's a bug?
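For what it's worth, enumerating every decision position (same assumptions as the sketch above: num_blocks = 6, four activations) seems to bear this out:

# Track the largest index that action + sum(num_tokens[:mode]) can produce
# across all decision positions. mode alternates 0, 1, 0, 1, ...
num_tokens = [4]
for idx in range(6):
    num_tokens += [idx + 1, 4]

max_index = 0
for step, n_choices in enumerate(num_tokens):
    mode = step % 2                  # 0: pick activation, 1: pick previous node
    offset = sum(num_tokens[:mode])  # 0 when mode == 0, 4 when mode == 1
    max_index = max(max_index, offset + n_choices - 1)

print(max_index)  # 9 -> only rows 0..9 of the 49-row embedding are ever used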