Decouple input and output representations
cheind opened this issue · 2 comments
cheind commented
Currently we assume that what we get as input is also what we predict as output (just shifted). However, thinking towards other research areas, it might make sense to rework this more generally:
model:
    input:  BxIxT
    output: BxQxT
where I may match Q but does not have to. In training we would then have code like the following:
def training_step(self, batch):
    inputs = batch['x']
    if 't' in batch:
        # allows us to provide alternative targets
        targets = batch['t']
    elif I == Q:
        # otherwise fall back to next-step prediction on the inputs themselves
        targets = inputs[..., 1:]
        inputs = inputs[..., :-1]
    else:
        raise ValueError(...)
    logits = self.forward(inputs)
    loss = ce(logits, targets)  # cross-entropy between BxQxT logits and targets
What's more, we need to think about input transforms. Currently, one-hot encoding is hardwired into the model. We might instead consider a differentiable input_transform that is passed to the model upon initialization. This would allow us to use differentiable embedding strategies; see the sketch below.
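A rough sketch of what a pluggable input_transform could look like; the class names (OneHotTransform, EmbeddingTransform, Model) and the toy encoder are only illustrative, not existing code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class OneHotTransform(nn.Module):
    # non-learnable baseline: BxT integer codes -> one-hot channels BxCxT
    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, x):
        return F.one_hot(x, self.num_classes).permute(0, 2, 1).float()

class EmbeddingTransform(nn.Module):
    # differentiable alternative: BxT integer codes -> learned embeddings BxIxT
    def __init__(self, num_classes, dim):
        super().__init__()
        self.emb = nn.Embedding(num_classes, dim)

    def forward(self, x):
        return self.emb(x).permute(0, 2, 1)

class Model(nn.Module):
    def __init__(self, input_transform, encoder):
        super().__init__()
        self.input_transform = input_transform
        self.encoder = encoder

    def forward(self, x):
        return self.encoder(self.input_transform(x))

# one-hot keeps today's behaviour; the embedding variant learns the representation
model = Model(EmbeddingTransform(num_classes=256, dim=64),
              encoder=nn.Conv1d(64, 256, kernel_size=1))
logits = model(torch.randint(0, 256, (8, 100)))  # BxQxT = 8x256x100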
cheind commented
dataset -> model -> loss
model:
    input: BxIxT
    input_transform: fn(BxKxT) -> BxIxT
    condition: BxCxT
    output: BxQxT
def loss(batch, outputs):
    if 't' in batch:
        targets = batch['t'][..., 1:]  # alternative targets, BxQxT or BxT
    else:
        targets = batch['x'][..., 1:]  # 'x' either BxQxT or BxT
    logits = outputs[..., :-1]         # drop last step to align with shifted targets
    # sampling (preds = sample(logits)) is only needed for generation, not for the loss
    return ce(logits, targets)         # cross-entropy on BxQxT logits vs. targets
def training_step(self, batch):
    inputs = batch['x']     # BxKxT, becomes BxIxT via input_transform in forward
    condition = batch['c']  # BxCxT
    outputs = self.forward(inputs, condition)  # BxQxT
    return loss(batch, outputs)
def forward(self, inputs, cond):
    inputs = self.input_transform(inputs)  # fn(BxKxT) -> BxIxT
    outputs = self.encode(inputs, cond)    # encoder consumes the condition, BxQxT
    return outputs
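To sanity-check that the shapes line up end to end, here is a hypothetical, self-contained example with I != Q. The stand-in transform/encoder and all sizes are made up, and the condition is simply concatenated along channels, which is just one possible way to consume it:

import torch
import torch.nn as nn
import torch.nn.functional as F

B, T = 8, 100
I, C, Q = 64, 16, 256  # channels after input_transform, condition channels, output classes

input_transform = nn.Embedding(Q, I)          # BxT codes -> BxTxI (differentiable)
encode = nn.Conv1d(I + C, Q, kernel_size=1)   # stand-in encoder -> BxQxT

batch = {
    'x': torch.randint(0, Q, (B, T)),  # raw input codes, BxT
    'c': torch.randn(B, C, T),         # conditioning signal, BxCxT
}

inputs = input_transform(batch['x']).permute(0, 2, 1)     # BxIxT
outputs = encode(torch.cat([inputs, batch['c']], dim=1))  # BxQxT
targets = batch['x'][..., 1:]                             # next-step targets, Bx(T-1)
loss = F.cross_entropy(outputs[..., :-1], targets)        # matches the loss() sketch above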