jbloomAus/DecisionTransformerInterpretability

Explore Improvements to DT Training Procedure

jbloomAus opened this issue · 6 comments

Just wanted to have a meta card to track progress on these things, with links:

  • LayerNorm (I'll probably only try layernorm pre) (#52)
  • AdamW optimizer
  • Add a warmup stage with a LambdaLR scheduler or cosine annealing (see the optimizer/scheduler sketch after this list)
  • Implement gated MLPs (https://arxiv.org/pdf/2002.05202.pdf); might need to be done in TransformerLens (see the gated MLP sketch after this list)
  • Make it possible to use GELU instead of ReLU (and try that out as well)
  • Better state encoding (#61)
  • Look into current init ranges for all the model components and consider proper init ranges
  • Look into where all the parameters are and consider how we can make a sparser model
  • Implement wandb sweeps for DT training (there's likely already a card for this, so I should find it)
  • Implement attention masking rather than just using different tokens during padding. Might be important? (See the padding-mask sketch after this list.)
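
For the AdamW and warmup items, here is a minimal sketch of AdamW plus a linear warmup followed by cosine decay via LambdaLR; the model, step counts, and learning rate below are placeholders rather than the actual training config:

```python
import math

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

# Placeholder model and schedule lengths -- in practice these come from the run config.
model = torch.nn.Linear(64, 64)
total_steps = 10_000
warmup_steps = 1_000

optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

def warmup_then_cosine(step: int) -> float:
    # Ramp linearly from 0 to the base LR over warmup_steps,
    # then decay with a half-cosine over the remaining steps.
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

scheduler = LambdaLR(optimizer, lr_lambda=warmup_then_cosine)

# In the training loop, call optimizer.step() and then scheduler.step() once per batch.
```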
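
For the gated MLP item, here is a sketch of the SwiGLU variant from that paper. The module and argument names (`GatedMLP`, `d_model`, `d_mlp`) are mine, not TransformerLens's, and swapping `F.silu` for `F.gelu` or `F.relu` would also cover the GELU/ReLU item:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedMLP(nn.Module):
    """Gated feed-forward block (SwiGLU-style) in place of the usual Linear -> ReLU -> Linear."""

    def __init__(self, d_model: int, d_mlp: int):
        super().__init__()
        self.W_gate = nn.Linear(d_model, d_mlp, bias=False)  # gating projection
        self.W_in = nn.Linear(d_model, d_mlp, bias=False)    # value projection
        self.W_out = nn.Linear(d_mlp, d_model, bias=False)   # project back to the residual stream

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: elementwise product of the activated gate and the value path.
        return self.W_out(F.silu(self.W_gate(x)) * self.W_in(x))
```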
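
For the padding-mask item, the idea is to carry an explicit boolean mask alongside padded trajectories and use it in both attention and the loss, rather than relying on the model to ignore a special padding token. A minimal sketch (the helper name and shapes are illustrative):

```python
import torch

def make_padding_mask(seq_lens: torch.Tensor, max_len: int) -> torch.Tensor:
    """Boolean mask of shape (batch, max_len): True for real tokens, False for padding."""
    positions = torch.arange(max_len, device=seq_lens.device)
    return positions[None, :] < seq_lens[:, None]

# Two trajectories of length 3 and 5, padded to a context of 5 timesteps.
mask = make_padding_mask(torch.tensor([3, 5]), max_len=5)
# mask -> [[True, True, True, False, False],
#          [True, True, True, True,  True ]]
# Padded positions should be masked out of the attention scores and excluded from the loss.
```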

If we've implemented all of those and still have no success training on the memory env, we can try much longer training runs or more varied sampling methods, ask for advice, or go bug hunting.

LN and AdamW are done. No clear benefit on the smaller model. I think I'll get everything implemented, then set off some sweeps tomorrow or the next day on the memory env.

I'm going to add a task here for setting up wandb sweeps. Given the stuff I've added, I think it's important to get a better sense of the right hyperparameters.
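
As a sketch of what that could look like (the metric and parameter names below are placeholders, not the actual keys used by the DT training script):

```python
import wandb

sweep_config = {
    "method": "bayes",
    "metric": {"name": "eval/success_rate", "goal": "maximize"},  # placeholder metric name
    "parameters": {
        "learning_rate": {"distribution": "log_uniform_values", "min": 1e-5, "max": 1e-3},
        "weight_decay": {"values": [0.0, 0.01, 0.1]},
        "n_layers": {"values": [1, 2, 3]},
        "d_model": {"values": [64, 128, 256]},
    },
}

# Project name is a placeholder; use whatever project the existing runs log to.
sweep_id = wandb.sweep(sweep_config, project="DecisionTransformerInterpretability")
# wandb.agent(sweep_id, function=train, count=20)  # `train` = existing training entry point
```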

I just had a lightbulb moment relating to #61, so I'm going to do that really quickly before I attempt wandb sweeps.

Converting "Implement attention masking rather than just using different tokens during padding" to its own card.

Closing this. Got working agents!