Error in DPOptimizer: Inconsistency between batch_first argument of PrivacyEngine and DPMultiheadAttention
tklausen opened this issue · 2 comments
🐛 Bug
Context
Both PrivacyEngine and DPMultiheadAttention accept the bool argument batch_first, which indicates whether the batch dimension is the first or the second dimension. In the case of the PrivacyEngine, this argument is passed down to the GradSampleModule, which ensures that the batch dimension is always the first dimension in .grad_samples (= per-sample gradients) (see rearrange_grad_samples), so that the grad_samples can be used by DPOptimizer.
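The dimension swap performed by the GradSampleModule can be sketched as follows. This is a minimal illustration in plain PyTorch, not Opacus code; `rearrange_to_batch_first` is a hypothetical stand-in for the logic in rearrange_grad_samples:

```python
import torch

def rearrange_to_batch_first(captured: torch.Tensor, batch_first: bool) -> torch.Tensor:
    """Sketch of the GradSampleModule convention: tensors captured from a
    module that ran with batch_first=False have the batch dimension second,
    so dims 0 and 1 are swapped to make every captured tensor batch-first.
    (Hypothetical helper, not the Opacus API.)"""
    if not batch_first:
        return captured.transpose(0, 1)
    return captured

# Tensor captured from a sequence-first module: (seq_len=10, batch=16, features=4)
gs = torch.randn(10, 16, 4)
print(rearrange_to_batch_first(gs, batch_first=False).shape)  # torch.Size([16, 10, 4])
```

After this rearrangement, dimension 0 of every entry in .grad_samples is expected to index individual samples, which is what DPOptimizer relies on.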
Problem
Using PrivacyEngine and DPMultiheadAttention both with batch_first=True mixes up the batch dimension and can throw an error.
DPMultiheadAttention reorders the inputs to its forward method (query, key, value) so that the batch dimension is the second dimension (and the sequence dimension is the first) if batch_first=True. Therefore, the internal linear layers of DPMultiheadAttention are called with an input whose second dimension is the batch dimension. However, the GradSampleModule expects the batch dimension to be the first dimension (because batch_first was set to True in the PrivacyEngine). Thus, the computed gradients are not per-sample gradients. This even throws an error if the model uses an additional layer other than DPMultiheadAttention whose input is batch-first. The error is thrown during a torch.stack operation in the DPOptimizer's clip_and_accumulate method.
To Reproduce
See Colab.
- Initialize PrivacyEngine with batch_first=True
- Create a model that has at least:
  a. one DPMultiheadAttention layer with batch_first=True
  b. one other layer such as nn.Linear
- Ensure that the batch size != sequence length of the input to the DPMultiheadAttention layer
Stack trace:
[<ipython-input-2-e63f0910143c>](https://localhost:8080/#) in train(model, criterion, optimizer, train_loader, device)
12 loss.backward()
13
---> 14 optimizer.step()
15 optimizer.zero_grad()
16
[/content/opacus/opacus/optimizers/optimizer.py](https://localhost:8080/#) in step(self, closure)
518 closure()
519
--> 520 if self.pre_step():
521 return self.original_optimizer.step()
522 else:
[/content/opacus/opacus/optimizers/optimizer.py](https://localhost:8080/#) in pre_step(self, closure)
499 return True
500
--> 501 self.clip_and_accumulate()
502 if self._check_skip_next_step():
503 self._is_last_step_skipped = True
[/content/opacus/opacus/optimizers/optimizer.py](https://localhost:8080/#) in clip_and_accumulate(self)
404 g.reshape(len(g), -1).norm(2, dim=-1) for g in self.grad_samples
405 ]
--> 406 per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
407 per_sample_clip_factor = (
408 self.max_grad_norm / (per_sample_norms + 1e-6)
RuntimeError: stack expects each tensor to be equal size, but got [16] at entry 0 and [8] at entry 2
Expected Behavior
The per-sample gradients are computed correctly and no error is thrown if batch_first has the same value in both PrivacyEngine and DPMultiheadAttention.
For batch_first=False, no changes are required.
For batch_first=True, the DPMultiheadAttention layer should call its internal linear layers with an input whose first dimension is the batch dimension.
Environment
opacus: 1.4.1
pytorch: 2.2.1
Other packages should not be relevant as this is a pure Opacus bug.
Additional context
This issue may be related to #505, but I can't confirm this as the source code for this issue seems to have been deleted.
Thanks for contributing to Opacus! Great catch!
Let me work on a fix. We need to guarantee that the input to all the linear layers inside DPMultiheadAttention has batch_size as the first dimension when batch_first=True.
Closing the issue since we landed the fix in #651.