RF weight dropout and variational noise
albertz opened this issue · 9 comments
Currently we don't have weight dropout in the RF. We should add it.
(I thought there was an issue already about it but I don't find it.)
Related:
- Weight dropout in RETURNN common: `returnn_common.nn.utils.weight_dropout.weight_dropout`. Related issues: rwth-i6/returnn_common#59, rwth-i6/returnn_common#250
- Weight dropout for TF net-dict. See: #735
- Related discussion (on weight norm): #1264 (comment)
In general, there is a whole class of similar features based on parameter reparameterization which would require a similar mechanism, e.g. weight norm.
Regarding implementation:
I think we could follow the PyTorch implementation of similar logic (e.g. weight norm) by using a forward-pre-hook. We already have support for hooks via `rf.Module.register_forward_hook` / `rf.hooks.setup_post_hook_on_method`.
Regarding PyTorch:
`torch.nn.utils.weight_norm` uses `register_forward_pre_hook`. However, the doc says: "This function is deprecated. Use `torch.nn.utils.parametrizations.weight_norm()` which uses the modern parametrization API."

`torch.nn.utils.parametrizations.weight_norm` uses the modern parametrization API, i.e. `torch.nn.utils.parametrize.register_parametrization`.

On `register_parametrization` (it's somewhat ugly (although our own hooks mechanism is also a bit complicated), but this here looks even worse; a short usage sketch follows after this list):

- It uses `_inject_new_class` to replace the original module class by a new dummy one, which is going to be extended by adding a property for the parameter. (We cannot add the property to the object. A property can only be added to a class. See the descriptor guide doc.)
- It deletes the parameter (`delattr`) and stores it in a `ParametrizationList`.
- It injects a property (`_inject_property`) for the parameter. This will call the `parametrization` to get it.
- Note that this implementation of parametrization is not working with scripting. (I'm not exactly sure why though?)
- Even tracing is not working together with caching? (Again not sure why?)
- There is a caching mechanism which is disabled by default. The user needs to explicitly enable it temporarily, like so:
  ```python
  import torch.nn.utils.parametrize as P

  ...
  with P.cached():
      output = model(inputs)
  ```
  When leaving the `cached()` context, it will clear the (global) cache.
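For illustration, here is a minimal sketch of how weight dropout could look via this parametrization API (plain PyTorch, not RF code; the `WeightDropout` class and the hyperparameters are made up for this example):

```python
import torch
from torch import nn
import torch.nn.utils.parametrize as parametrize


class WeightDropout(nn.Module):
    """Hypothetical parametrization: dropout applied to the weight itself."""

    def __init__(self, p: float = 0.1):
        super().__init__()
        self.p = p

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        # Only active in training mode, like normal dropout.
        return nn.functional.dropout(weight, p=self.p, training=self.training)


linear = nn.Linear(512, 512)
# Replaces linear.weight by a property; accessing it applies WeightDropout
# to the underlying parameter, which now lives in the ParametrizationList.
parametrize.register_parametrization(linear, "weight", WeightDropout(p=0.1))

x = torch.randn(8, 512)
with parametrize.cached():  # optional: compute the dropped-out weight only once here
    y = linear(x)
```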
So, comparing the modern parametrization API `register_parametrization` to the way it was done in `torch.nn.utils.weight_norm` via `register_forward_pre_hook`:
- `register_parametrization` also works for any other module methods, while `register_forward_pre_hook` only works for `forward`.
- `register_forward_pre_hook` will create a temporary buffer to store the calculated parameter. This buffer is not freed after `forward`; it is just updated in the next pre-hook. Via `register_parametrization`, the calculated parameter is freed after its usage (and with caching enabled, after you leave the `cached()` context scope). So `register_parametrization` requires less memory.
- `register_parametrization` will recompute the parameter several times when it is accessed multiple times (whenever the property is accessed), unless the user uses the parametrization `cached` mechanism explicitly. But you usually would want to avoid the redundant computation, i.e. you want to use the `cached` context. But that means you also need to modify the way you call your model by installing the `cached` context somewhere. This is not automatic. On the other side, `register_forward_pre_hook` is always automatic. (Although, if you call `forward` multiple times, it would also cause redundant recomputations of the parameter.)
- `register_forward_pre_hook` is conceptually much simpler and less ugly than `register_parametrization` (a rough sketch of that hook style is below).
- PyTorch parametrization API introduction and discussions: PR pytorch/pytorch#33344, issue pytorch/pytorch#28937, issue pytorch/pytorch#7313.
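For comparison, this is roughly how the forward-pre-hook style works (a sketch only, analogous to the deprecated `torch.nn.utils.weight_norm`, here for weight dropout; the helper name `apply_weight_dropout` and the `_raw` suffix are made up for illustration):

```python
import torch
from torch import nn


def apply_weight_dropout(module: nn.Module, name: str = "weight", p: float = 0.1) -> nn.Module:
    raw = getattr(module, name)
    delattr(module, name)
    # Keep the real parameter under a different name...
    module.register_parameter(name + "_raw", raw)
    # ...and store the computed weight as a plain attribute (the "buffer" from above).
    setattr(module, name, raw.data)

    def _pre_hook(mod: nn.Module, inputs):
        # Recomputed on every forward call; the result stays attached to the
        # module until the next forward overwrites it, i.e. it is not freed.
        w = nn.functional.dropout(getattr(mod, name + "_raw"), p=p, training=mod.training)
        setattr(mod, name, w)

    module.register_forward_pre_hook(_pre_hook)
    return module


linear = apply_weight_dropout(nn.Linear(512, 512), p=0.1)
y = linear(torch.randn(8, 512))  # the pre-hook runs automatically here
```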
Concluding from that, I'm a bit unsure what way to go for RF... Using `register_forward_pre_hook` looks too error-prone (it only covers the hooked function, nothing else)... but the other approach looks too complicated? But maybe still better. The caching mechanism is maybe also not so important for now? For all use cases, I think it would not matter (e.g. `rf.Linear`, `rf.SelfAttention`, `rf.Conv`, etc.).
Further, we should also support this with gradient checkpointing, such that the weights are not stored twice in memory. In our existing TF implementation of variational noise, we already use gradient checkpointing, where only the random number generator state is stored and not the dropout mask nor the weight. Thus there is almost no memory overhead. See `gradient_checkpoint_scope` and co. For PyTorch, it is currently unclear how to do this. I moved this over to a separate issue: #1552
Btw, regarding gradient checkpointing, see this current code as an example for variational noise in our TF code:
```python
if param_variational_noise and param.dtype.is_floating and isinstance(param, tf.Variable):
    with default_control_flow_ctx():  # make independent from loop/cond
        with reuse_name_scope_of_tensor(param, postfix="_variational_noise", add_tensor_name=True):

            def _apply_var_noise():
                rnd_state = tf_util.StatelessRandomSeed.create(shape=tf_util.get_shape(param))
                with gradient_checkpoint_scope():
                    noise = rnd_state.normal(stddev=param_variational_noise, dtype=param.dtype.base_dtype)
                    return param + noise

            param = self.network.cond_on_train(
                fn_train=_apply_var_noise,
                fn_eval=lambda: param,
            )
```
Specifically, check the code of `gradient_checkpoint_scope` and `prepare_gradient_checkpointing`.
I know that people also do gradient checkpointing in PyTorch, but I don't know exactly how that works.
There is a gradient checkpointing API in PT: https://pytorch.org/docs/stable/checkpoint.html
It even saves/restores the RNG state so we could do Dropout in there. I'm not sure the RNG state there can be made explicit, but it seems suitable in all the other ways.
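For example (just a sketch with made-up names, and whether this alone is enough is exactly what is discussed below): one could wrap only the noise/dropout computation in `torch.utils.checkpoint.checkpoint`, so that its output is recomputed in the backward pass instead of being stored, with the RNG state saved/restored so the same mask is sampled again:

```python
import torch
from torch.utils.checkpoint import checkpoint


def _weight_with_dropout(weight: torch.Tensor) -> torch.Tensor:
    # The RNG state is saved/restored by checkpoint (preserve_rng_state=True
    # by default), so the same mask is sampled again during recomputation.
    return torch.nn.functional.dropout(weight, p=0.1, training=True)


weight = torch.nn.Parameter(torch.randn(512, 512))
x = torch.randn(8, 512)

w = checkpoint(_weight_with_dropout, weight, use_reentrant=False)
y = x @ w.t()
y.sum().backward()  # _weight_with_dropout is recomputed here
```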
I saw you asking in the PT issue about JAX: the RNG there is by definition stateless, and follows a design where the RNG seed has to be threaded through the code and explicitly "split" to make new seeds.
Copying from https://jax.readthedocs.io/en/latest/jax.random.html:
```python
seed = 1701
num_steps = 100
key = jax.random.key(seed)
for i in range(num_steps):
    key, subkey = jax.random.split(key)
    params = compiled_update(subkey, params, next(batches))
```
It seems to me the API PT exposes for gradient checkpointing could be used as the RF frontend API and for the associated TF-backed implementation as well?
> There is a gradient checkpointing API in PT: https://pytorch.org/docs/stable/checkpoint.html
Yea that is what I referred to when we talked about it. But I need to check more how it is done there. Specifically, I'm still not exactly sure how I get what I want: that the dropout outputs are not stored but recomputed.
> I saw you asking in the PT issue about JAX
No, I did not ask about JAX in there?
> Yea that is what I referred to when we talked about it. But I need to check more how it is done there. Specifically, I'm still not exactly sure how I get what I want: that the dropout outputs are not stored but recomputed.
Yeah it would seem to me like applying only the dropout operation within the gradient checkpointed context might not be enough, but one would have to move more of the layer functionality into the checkpointed/recomputed area? Is this what you're referring to?
> Yeah it would seem to me like applying only the dropout operation within the gradient checkpointed context might not be enough, but one would have to move more of the layer functionality into the checkpointed/recomputed area? Is this what you're referring to?
I don't know how this works. I don't want to recompute whatever comes after the dropout. I just don't want it to store the dropout output in memory for the backprop, i.e. the dropout itself should be recomputed.
(Note, I made a separate issue just for the gradient checkpointing aspect in PyTorch: #1552. So this issue here can just focus on the RF-specific question of how to implement weight dropout (or also weight noise / variational noise). Edit: This is implemented now. See #1559, `gradient_checkpoint_scope`.)
So, I tend to reimplement something very similar to the PyTorch parametrization API, also following some of its internal design choices.
- I don't want to extend `rf.Module`. It's also not easily possible anyway, as we don't have such a `_parameters` dict as in `torch.nn.Module`; our parameters are simply normal attributes.
- I want to have it as a property, not as a buffer. The buffer would take additional memory, which I want to avoid at all cost. (Also, all our effort for the gradient checkpointing is to save memory.)
- -> We also must inject a dummy class into the object, following similar logic as in `_inject_new_class` and then `_inject_property`. (A rough sketch of this follows after this list.)
  - Note that they don't allow this to be serialized, and have a custom deepcopy function. I think we don't need this. I think serialization of such classes can also work. I need to play around with that later. I don't think we need to care about it for now.
- I'm not sure yet about the caching logic. I think we can skip this for the beginning. In any case, I'm not sure I would follow the same PyTorch API for this. It could be more automatic in RF. But anyway, to save memory, maybe we don't want this.
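Roughly, the injection I have in mind could look like this (a hypothetical sketch, not the actual implementation; the function name and details are just for illustration, and a real version would need to respect the module's own attribute handling):

```python
def register_parametrization(module, name: str, parametrization):
    """Hypothetical sketch: keep the raw parameter in the instance dict,
    inject a dummy subclass and a property that applies the parametrization."""
    cls = module.__class__
    if not getattr(cls, "_rf_parametrized_dummy", False):
        # A property can only live on a class, not on an instance,
        # hence the injected dummy subclass (like _inject_new_class).
        cls = type("Parametrized" + cls.__name__, (cls,), {"_rf_parametrized_dummy": True})
        module.__class__ = cls

    def _get(self):
        # The property is a data descriptor, so it shadows the instance
        # attribute on normal access, but the raw parameter is still
        # in vars(self), i.e. named_parameters() would still find it.
        raw = self.__dict__[name]
        return parametrization(raw)

    def _set(self, value):
        self.__dict__[name] = value  # allow reassigning the raw parameter

    setattr(cls, name, property(_get, _set))
```

With this, the raw parameter stays in `vars(module)` under its original name, which is relevant for the first open question below.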
Some open questions:
- In case the parametrization doesn't change the underlying vars, but is just sth optional applied in training, like weight dropout or variational noise, I would not want the module parameter list to change. But is this possible?
  - I think yes. Currently `RFModuleAsPTModule` uses `named_parameters`, and that iterates through `vars(module).items()`, i.e. it should actually see the underlying var, if we did not remove it.
I also thought about deriving from or extending `rf.Parameter`. I'm not exactly sure how, though. It is currently also a `Tensor`, and I don't think we can make it dynamically evaluate on reads without changing `Tensor` itself. So I think this does not work.