HuwCampbell/grenade

Better Wengert tape storage

HuwCampbell opened this issue · 1 comments

At the moment, during training, we keep each layer's input and hand it back to that layer for the reverse-mode pass.

This is pretty good and efficient for most layers, but for some layers, like LSTM, it's not as granular as it could be.

Maybe Layer should be something like:

class UpdateLayer x => Layer x (i :: Shape) (o :: Shape) where
  -- | The Wengert tape for this layer. Includes all that is required
  --   to generate the back propagated gradients efficiently. As a
  --   default, `S i` is fine.
  type Tape x i o :: *
  -- | Used in training and scoring. Take the input from the previous
  --   layer, and give the tape and the output from this layer.
  runForwards    :: x -> S i -> (Tape x i o, S o)
  -- | Back propagate a step. Takes the current layer, the tape that the
  --   layer recorded on the forward pass, and the back propagated
  --   derivatives from the layer above.
  --   Returns the gradient layer and the derivatives to push back further.
  runBackwards   :: x -> Tape x i o -> S o -> (Gradient x, S i)

My Fuse layer, which I haven't actually used, should be much nicer with this formulation.
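
To make the Fuse point concrete, here is a minimal sketch of the idea, not grenade's actual code. It assumes a cut-down class with the shape indices dropped (every activation is a plain Double) and with Gradient as an associated type, so it compiles without grenade. Tanh, Scale and this Fuse are hypothetical stand-ins. With an input-only tape, a fused pair would have to recompute the intermediate value on the way back; with a per-layer tape it can simply pair up the tapes of its two halves.

```haskell
{-# LANGUAGE TypeFamilies #-}

-- Simplified stand-in for the proposed class: no Shape indices, and
-- activations are plain Doubles (an assumption made for this sketch).
class Layer x where
  -- | Whatever the forward pass must remember for the backward pass.
  type Tape x
  -- | Gradient with respect to the layer's own parameters.
  type Gradient x
  runForwards  :: x -> Double -> (Tape x, Double)
  runBackwards :: x -> Tape x -> Double -> (Gradient x, Double)

-- | Hypothetical parameter-free tanh layer.
data Tanh = Tanh

instance Layer Tanh where
  -- Taping the output instead of the input avoids recomputing tanh.
  type Tape Tanh = Double
  type Gradient Tanh = ()
  runForwards Tanh x = let y = tanh x in (y, y)
  runBackwards Tanh y dy = ((), dy * (1 - y * y))

-- | Hypothetical scaling layer, y = w * x.
newtype Scale = Scale Double

instance Layer Scale where
  type Tape Scale = Double      -- the input is all this layer needs
  type Gradient Scale = Double  -- derivative with respect to w
  runForwards (Scale w) x = (x, w * x)
  runBackwards (Scale w) x dy = (dy * x, dy * w)

-- | Run layer a, then layer b.
data Fuse a b = Fuse a b

instance (Layer a, Layer b) => Layer (Fuse a b) where
  -- The fused tape is just the pair of the inner tapes.
  type Tape (Fuse a b) = (Tape a, Tape b)
  type Gradient (Fuse a b) = (Gradient a, Gradient b)
  runForwards (Fuse a b) x =
    let (ta, y) = runForwards a x
        (tb, z) = runForwards b y
    in ((ta, tb), z)
  runBackwards (Fuse a b) (ta, tb) dz =
    let (gb, dy) = runBackwards b tb dz   -- back through b first
        (ga, dx) = runBackwards a ta dy   -- then through a
    in ((ga, gb), dx)
```

For example, in GHCi `runForwards (Fuse (Scale 2) Tanh) 0` gives the tape `(0.0,0.0)` and the output `0.0`, and feeding that tape to `runBackwards` with an upstream derivative of `1` gives `((0.0,()),2.0)`. Neither half is evaluated a second time on the backward pass.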