Controller.encoder seems much too large
philtomson opened this issue · 2 comments
From the Controller constructor:
class Controller(torch.nn.Module):
    def __init__(self, args):
        torch.nn.Module.__init__(self)
        self.args = args
        self.forward_evals = 0

        if self.args.network_type == 'rnn':
            # NOTE(brendan): `num_tokens` here is just the activation function
            # for every even step,
            self.num_tokens = [len(args.shared_rnn_activations)]
            for idx in range(self.args.num_blocks):
                self.num_tokens += [idx + 1,
                                    len(args.shared_rnn_activations)]
            self.func_names = args.shared_rnn_activations
        elif self.args.network_type == 'cnn':
            self.num_tokens = [len(args.shared_cnn_types),
                               self.args.num_blocks]
            self.func_names = args.shared_cnn_types

        num_total_tokens = sum(self.num_tokens)  # why sum the tokens here?
        # Shouldn't this be: num_total_tokens = len(args.shared_rnn_activations) + self.args.num_blocks
        self.encoder = torch.nn.Embedding(num_total_tokens,
                                          args.controller_hid)
It seems like num_total_tokens doesn't need to be the sum of self.num_tokens: in the case where self.args.num_blocks = 6, that sum is 49. Yet from what I can tell, the indices fed to the embedding in Controller.forward() never get any higher than len(args.shared_rnn_activations) + self.args.num_blocks (in this case, 10).
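For concreteness, here is the arithmetic the RNN branch of the constructor performs, assuming the four shared_rnn_activations and num_blocks = 6 (plain Python, no torch needed):

    # Sketch of how num_tokens is built above, assuming 4 activations, 6 blocks.
    num_activations = 4
    num_blocks = 6

    num_tokens = [num_activations]
    for idx in range(num_blocks):
        num_tokens += [idx + 1, num_activations]

    print(num_tokens)                     # [4, 1, 4, 2, 4, 3, 4, 4, 4, 5, 4, 6, 4]
    print(sum(num_tokens))                # 49 -> rows allocated in the Embedding
    print(num_activations + num_blocks)   # 10 -> rows the forward pass seems to reach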
You can assume the same activation in a different place to have the same semantics (embedding), or not. I assumed it's different, because an activation in different locations may play a separate role.
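If that per-location reading is what the 49-row table is sized for, each decision step would need its own offset into the embedding, something like the hypothetical sketch below. This is not the repository's code, just an illustration of indexing that would actually reach all 49 rows; the names are illustrative only:

    # Hypothetical indexing: offset each step's token id by the rows consumed by
    # all previous steps, so every (step, token) pair maps to a distinct row.
    def embedding_index(num_tokens, step, token_id):
        return sum(num_tokens[:step]) + token_id

    num_tokens = [4, 1, 4, 2, 4, 3, 4, 4, 4, 5, 4, 6, 4]   # num_blocks = 6 case
    # Last step, last activation: 45 + 3 = 48, i.e. the 49th and final row.
    print(embedding_index(num_tokens, len(num_tokens) - 1, 3))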
I've been running through this section of the code in the debugger trying to understand what's going on... when mode is 0 (the case where the activation function is being selected), sum(self.num_tokens[:mode]) is 0. So the line:
    inputs = utils.get_variable(
        action[:, 0] + sum(self.num_tokens[:mode]),
        requires_grad=False)
is always just the action[:, 0] component, which is a value from 0 to 3 (an index into the four-entry activation function list).
And when mode is 1, sum(self.num_tokens[:mode]) is always 4, so I don't see how you can get anything higher than len(args.shared_rnn_activations) + self.args.num_blocks here. mode can only take on the values 0 or 1. Either I'm missing something, or maybe it's a bug?
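To make that bound concrete, here is a small check under the same assumptions (4 activations, num_blocks = 6, mode alternating between 0 and 1 across the decision steps as described above, and the choice at each step ranging over num_tokens[step] options). It suggests the encoder is only ever indexed with values 0..9, leaving the remaining 39 of the 49 rows unused:

    # Sketch (illustrative, not repository code): largest index fed to the encoder
    # when the offset is sum(num_tokens[:mode]) with mode alternating 0/1.
    num_activations = 4
    num_blocks = 6
    num_tokens = [num_activations]
    for idx in range(num_blocks):
        num_tokens += [idx + 1, num_activations]

    max_index = 0
    for step, n_choices in enumerate(num_tokens):
        mode = step % 2                   # 0 = pick an activation, 1 = pick a previous node
        offset = sum(num_tokens[:mode])   # 0 when mode == 0, 4 when mode == 1
        max_index = max(max_index, offset + n_choices - 1)

    print(max_index)                      # 9, i.e. only indices 0..9 of the 49 rows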