cheind/autoregressive

Decouple input and output representations

cheind opened this issue · 2 comments

Currently we assume that what we get as input is what we will predict as output (just shifted by one step). However, thinking towards other research areas, it might make sense to rework this more generally:

model
  input: BxIxT
  output: BxQxT

where I may match Q but does not have to. In training we would then have code like the following:

def training_step(self, batch):
    inputs = batch['x']                  # BxIxT
    if 't' in batch:
        targets = batch['t']             # allows us to provide alternative targets
    elif I == Q:                         # input and output representations match
        targets = inputs[..., 1:]        # next-step targets
        inputs = inputs[..., :-1]        # drop the last step so shapes align
    else:
        raise ValueError('cannot derive targets when I != Q')

    logits = self.forward(inputs)
    loss = ce(logits, targets)
    return loss
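
For reference, here is a minimal, self-contained sketch of what ce amounts to for these shapes in PyTorch (the concrete numbers are placeholders, not values from this repo):

import torch
import torch.nn.functional as F

B, Q, T = 8, 256, 128                   # batch, output classes, time steps
logits = torch.randn(B, Q, T)           # BxQxT model output
targets = torch.randint(0, Q, (B, T))   # BxT integer class targets

# F.cross_entropy accepts (B, C, T) logits together with (B, T) targets
loss = F.cross_entropy(logits, targets)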

What's more, we need to think about input transforms. Currently, one-hot encoding is hardwired into the model. We might instead pass a differentiable input_transform to the model upon initialization. This would allow us to use differentiable embedding strategies.
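
A rough sketch of what such a pluggable input_transform could look like (class names and signatures are hypothetical, not existing API of this repo):

import torch.nn as nn
import torch.nn.functional as F

class OneHotTransform(nn.Module):
    # reproduces the current hardwired behaviour: BxT class indices -> BxIxT one-hot
    def __init__(self, num_classes: int):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, x):  # x: BxT long tensor
        return F.one_hot(x, self.num_classes).permute(0, 2, 1).float()  # BxIxT, I = num_classes

class EmbeddingTransform(nn.Module):
    # differentiable alternative: learned embeddings instead of one-hot
    def __init__(self, num_classes: int, embed_dim: int):
        super().__init__()
        self.embed = nn.Embedding(num_classes, embed_dim)

    def forward(self, x):  # x: BxT long tensor
        return self.embed(x).permute(0, 2, 1)  # BxIxT, I = embed_dim

The model would then receive one of these at construction time, e.g. Model(input_transform=EmbeddingTransform(Q, 64), ...) (again a hypothetical signature).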



dataset -> model -> loss

model:
    input: BxIxT
    input_transform: fn(BxKxT) -> BxIxT
    condition: BxCxT
    output: BxQxT


def loss(batch, outputs):
    if 't' in batch:
        targets = batch['t'][..., 1:]   # BxQxT or BxT
    else:
        targets = batch['x'][..., 1:]   # 'x' is either BxQxT or BxT
    logits = outputs[..., :-1]          # align predictions with next-step targets
    return ce(logits, targets)          # cross-entropy on logits; sampling is not differentiable


def training_step(self, batch):
    inputs = batch['x']      # BxIxT (or BxKxT before input_transform)
    condition = batch['c']   # BxCxT
    logits = self.forward(inputs, condition)
    return loss(batch, logits)

def forward(self, inputs, cond):
    inputs = self.input_transform(inputs)  # BxKxT -> BxIxT
    outputs = self.encode(inputs, cond)
    return outputs                         # BxQxT
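
Putting the pieces together, a rough skeleton under the above assumptions (the Conv1d backbone and the concatenation-based conditioning are placeholders for the actual causal stack):

import torch
import torch.nn as nn

class AutoregressiveModel(nn.Module):
    def __init__(self, input_transform: nn.Module, in_channels: int,
                 out_channels: int, cond_channels: int = 0):
        super().__init__()
        self.input_transform = input_transform  # fn(BxKxT) -> BxIxT
        # placeholder backbone; the real model would be the causal/dilated stack
        self.backbone = nn.Conv1d(in_channels + cond_channels, out_channels, kernel_size=1)

    def encode(self, x, cond=None):
        if cond is not None:
            x = torch.cat([x, cond], dim=1)  # naive conditioning, just for the sketch
        return self.backbone(x)              # BxQxT

    def forward(self, inputs, cond=None):
        inputs = self.input_transform(inputs)  # BxKxT -> BxIxT
        return self.encode(inputs, cond)       # BxQxT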

Would that also work for different model output interpretations, such as #24?