state-spaces/s4

Dropout2d and residual

AliVard opened this issue · 2 comments

Dear authors and contributors,

I have an observation that I would be happy to get your confirmation on :-)
In all of the model hierarchy (SequenceModel, SequenceResidualBlock, and S4), you are using Dropout2d, which zeroes along the batch dimension, i.e. drops entire samples. Without a residual link, stacking N layers with per-layer dropout probability p means a sample passes through the whole model untouched only with probability (1-p)^N, which quickly becomes negligible (with p = 0.25 and 6 layers, about 0.18). Consequently, the model does not see the inputs and will not train!
In SequenceResidualBlock, dropout is applied only if a residual link is present, and that residual link also compensates for the dropout inside S4.
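For reference, here is a minimal reproduction of what I'm seeing (a sketch; it assumes torch==1.11 and an input shaped (batch, length, dim) with no explicit channel dimension):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout2d(p=0.25)
drop.train()

x = torch.ones(8, 100, 64)  # (batch, length, dim), a 3D tensor
out = drop(x)

# In torch 1.11, a 3D input is interpreted as an unbatched (C, H, W) tensor,
# so dim 0 -- here the batch dimension -- is treated as channels and
# entire samples are zeroed out.
sample_is_zeroed = out.abs().sum(dim=(1, 2)) == 0
print(sample_is_zeroed)  # roughly a quarter of the samples are all zeros
```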
So my issue is two-fold:

  • When using dropout > 0, we should never set residual = None in the parameters of SequenceResidualBlock, right? Would it be possible to add a check in the initialization to avoid this misconfiguration (see the sketch after this list)?
  • The dropinp input of SequenceModel should not be used, as there is no residual link there. All of the configs I've seen set dropinp: 0.0, so why is it there at all?
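If it helps, a guard along these lines could catch the first misconfiguration at construction time (a sketch only; the parameter names are assumed from SequenceResidualBlock's signature):

```python
# Hypothetical check inside SequenceResidualBlock.__init__
if dropout > 0.0 and residual is None:
    raise ValueError(
        "SequenceResidualBlock: dropout > 0 with residual=None means dropped "
        "samples have no bypass path and the block's output is zero for them; "
        "either set a residual connection or use dropout=0."
    )
```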

Thanks and regards,

There is a bug in PyTorch 1.11 that causes the Dropout2d behavior you've observed: pytorch/pytorch#77081

We will add a warning and a fix for this.

dropinp is a hyperparameter that people sometimes use; we also used it in earlier experiments on WikiText-103.

The READMEs have been updated to mention this issue, and we have implemented a custom dropout function to avoid problems with the PyTorch implementation. Perhaps in the far future when everyone is using torch 1.12 or later we can switch back to using the official functions.
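The idea behind such a custom dropout is to draw one Bernoulli mask per (batch, channel) pair and broadcast it over the length dimension, so the semantics no longer depend on how a given torch version interprets 3D inputs. A minimal sketch (not necessarily the repo's exact implementation):

```python
import torch
import torch.nn as nn

class ChannelDropout(nn.Module):
    """Drops entire channels of a (batch, channels, length) tensor.

    Unlike nn.Dropout2d on a 3D input, the masked dimension is explicit,
    so the behavior is the same across torch versions.
    """
    def __init__(self, p: float = 0.5):
        super().__init__()
        assert 0.0 <= p < 1.0, "dropout probability must be in [0, 1)"
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.p == 0.0:
            return x
        # One keep/drop decision per (batch, channel), broadcast over length
        mask = torch.rand(x.shape[0], x.shape[1], 1, device=x.device) > self.p
        # Inverted-dropout scaling keeps the expected activation unchanged
        return x * mask / (1.0 - self.p)
```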