locuslab/deq

Why does the Anderson solver sometimes fail?

File "/home/Maybe2021/Data/implicit_model/rootfind.py", line 193, in anderson
alpha = torch.solve(y[:, :n + 1], H[:, :n + 1, :n + 1])[0][:, 1:n + 1, 0] # (bsz x n)
RuntimeError: solve_cuda: For batch 31: U(7,7) is zero, singular U.

Hi @Maybe2022,

I have also encountered this, but only very rarely. I usually deal with it in one of the following ways:

  1. Slightly increase the lam factor in the anderson function.
  2. Catch the exception (i.e., use try and except) and switch to a stronger solver such as Broyden's method.
  3. Apply some form of DEQ stabilization (e.g., Jacobian regularization or auxiliary losses).
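A minimal NumPy sketch of options (1) and (2), assuming a single flattened sample rather than the batched PyTorch solver in rootfind.py. The anderson function builds the same bordered system as the line in the traceback (ones on the border, G G^T + lam*I inside); np.linalg.solve raises LinAlgError on an exactly singular matrix, which plays the role of the torch.solve RuntimeError. The names robust_solve and fallback_solver are hypothetical, and fallback_solver is plain fixed-point iteration standing in for Broyden's method.

```python
import numpy as np


def anderson(f, x0, m=5, lam=1e-4, max_iter=50, tol=1e-3, beta=1.0):
    """Anderson acceleration for x = f(x) on a 1-D array (sketch)."""
    d = x0.size
    X, F = np.zeros((m, d)), np.zeros((m, d))
    X[0], F[0] = x0, f(x0)
    X[1], F[1] = F[0], f(F[0])
    # Bordered system [[0, 1^T], [1, G G^T + lam*I]] [nu; alpha] = [1; 0]
    H = np.zeros((m + 1, m + 1))
    H[0, 1:] = H[1:, 0] = 1.0
    y = np.zeros(m + 1)
    y[0] = 1.0
    rel = np.inf
    for k in range(2, max_iter):
        n = min(k, m)
        G = F[:n] - X[:n]                                # residual history, shape (n, d)
        H[1:n + 1, 1:n + 1] = G @ G.T + lam * np.eye(n)  # lam regularizes this block
        alpha = np.linalg.solve(H[:n + 1, :n + 1], y[:n + 1])[1:n + 1]  # may raise LinAlgError
        X[k % m] = beta * (alpha @ F[:n]) + (1 - beta) * (alpha @ X[:n])
        F[k % m] = f(X[k % m])
        rel = np.linalg.norm(F[k % m] - X[k % m]) / (1e-5 + np.linalg.norm(F[k % m]))
        if rel < tol:
            break
    return X[k % m], rel


def fallback_solver(f, x0, max_iter=500, tol=1e-3):
    """Hypothetical stand-in for a stronger solver (plain fixed-point iteration here)."""
    x, rel = x0, np.inf
    for _ in range(max_iter):
        fx = f(x)
        rel = np.linalg.norm(fx - x) / (1e-5 + np.linalg.norm(fx))
        x = fx
        if rel < tol:
            break
    return x, rel


def robust_solve(f, x0, lams=(1e-4, 1e-2)):
    for lam in lams:                       # option (1): retry with a larger lam
        try:
            z, rel = anderson(f, x0, lam=lam)
            if np.all(np.isfinite(z)):
                return z, rel
        except np.linalg.LinAlgError:      # singular system, like the torch.solve error
            pass
    return fallback_solver(f, x0)          # option (2): switch to another solver
```

With lam > 0 the inner block is positive definite, so a singular system in exact arithmetic can only appear when lam is lost to rounding (e.g., huge entries in G in float32), which is why a slightly larger lam often helps.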

This Anderson error usually results from growing instability in the DEQ model. You probably want to first check whether the model converges properly most of the time (i.e., measure ||f(z)-z||/||f(z)||; if this is < 0.05 at the end of the solver iterations, the convergence is good). If stability is an issue, fix it; if not, I suggest you try approaches (1) and (2) first.
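A minimal NumPy sketch of that convergence check, assuming f maps a batch of shape (bsz, ...) to the same shape. Each sample is flattened so the same code works for (bsz, C, H, W) feature maps; relative_residual and fraction_converged are hypothetical helper names, and 0.05 is the rule of thumb from the comment above.

```python
import numpy as np


def relative_residual(f, z, eps=1e-8):
    """Per-sample ||f(z) - z|| / ||f(z)|| for a batch z of shape (bsz, ...)."""
    fz = f(z)
    bsz = z.shape[0]
    diff = (fz - z).reshape(bsz, -1)
    return np.linalg.norm(diff, axis=1) / (np.linalg.norm(fz.reshape(bsz, -1), axis=1) + eps)


def fraction_converged(f, z, threshold=0.05):
    """Share of samples whose relative residual is below the threshold."""
    return float(np.mean(relative_residual(f, z) < threshold))
```

Logging fraction_converged at the end of each solve during training makes it easy to spot when the model starts drifting toward instability.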

Let me know if this helps!

Thank you very much for your answer,

I'm working on combining SNNs and DEQs, which may be unstable. I've now switched to your Broyden's method.

I originally learned from your Colab tutorial, where the Anderson code above works directly on 4-D tensors. I don't understand the Broyden solver code; in particular, the reshaping of the data inside Broyden is hard to follow.

Hi @Maybe2022 ,

You probably want to consider some stabilization techniques, then. Besides Jacobian regularization, another method is to apply auxiliary losses on the fixed-point solving process. E.g., typically we compute L(z*, y), where z* = z^{30} is the result of the root solving; but we can additionally add 0.2 * L(z^{10}, y) as an extra loss term that encourages the model to converge earlier. Unlike the main loss, whose gradient comes from IFT, this extra loss is memory-inefficient to backpropagate through, so you can apply JFB (Jacobian-free backpropagation) to it (see this paper).
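To make the two gradient paths concrete, here is a toy scalar sketch, not the DEQ code: the layer is assumed to be z = tanh(w*z + x) with a squared-error loss, and the names f, iterate and loss_grads are hypothetical. The main term L(z*, y) with z* = z^{30} gets its gradient from the implicit function theorem, dz*/dw = (∂f/∂w) / (1 - ∂f/∂z). The auxiliary term 0.2 * L(z^{10}, y) uses JFB: only the last step z^{10} = f(z^{9}) is differentiated and z^{9} is treated as a constant (in an autograd framework, this corresponds to detaching z^{9}).

```python
import math


def f(z, w, x):
    """Toy scalar DEQ layer: z = tanh(w * z + x)."""
    return math.tanh(w * z + x)


def iterate(w, x, n):
    """Return [z^0, z^1, ..., z^n] starting from z^0 = 0."""
    zs = [0.0]
    for _ in range(n):
        zs.append(f(zs[-1], w, x))
    return zs


def loss_grads(w, x, y, n_main=30, n_aux=10, aux_weight=0.2):
    """d/dw of 0.5*(z* - y)^2 via IFT and of aux_weight*0.5*(z^{n_aux} - y)^2 via JFB."""
    zs = iterate(w, x, n_main)
    z_star, z_aux, z_prev = zs[n_main], zs[n_aux], zs[n_aux - 1]
    # Main term, IFT: dz*/dw = (df/dw) / (1 - df/dz), both evaluated at z*.
    s = 1.0 - z_star ** 2                       # tanh'(w z* + x), using f(z*) ~= z*
    g_main = (z_star - y) * (s * z_star) / (1.0 - s * w)
    # Aux term, JFB: differentiate only z^{n_aux} = f(z^{n_aux - 1}); z^{n_aux - 1} is constant.
    s_aux = 1.0 - z_aux ** 2                    # tanh'(w z^{n_aux - 1} + x)
    g_aux = aux_weight * (z_aux - y) * (s_aux * z_prev)
    return g_main, g_aux
```

Neither path needs the intermediate iterates stored for backprop: the IFT term only uses quantities at z*, and the JFB term only uses the last two iterates.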

As for Broyden's method, I'm really just following the Sherman-Morrison formula on Wikipedia (see here). Note that we should never form the J^{-1} matrix; we only keep a low-rank-update version of it (i.e., -I + u1v1^T + u2v2^T + ...).
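A minimal NumPy sketch of that idea for one flattened sample, assuming g(z) = f(z) - z and the initial inverse estimate -I. For a batch of 4-D feature maps one would reshape each sample to a vector first (e.g., z.reshape(bsz, -1)) and apply the same updates per sample; this is an assumption for illustration, not a description of how rootfind.py arranges its tensors. The function name and the variant switch are hypothetical; the two rank-one updates are the textbook "good" and "bad" Broyden inverse updates.

```python
import numpy as np


def broyden(f, z0, max_iter=50, tol=1e-6, variant="bad"):
    """Solve g(z) = f(z) - z = 0 for a 1-D array z (sketch).

    The approximate inverse Jacobian of g is never formed; it is stored as
    B = -I + sum_i u_i v_i^T, i.e. as the two lists U and V.
    """
    U, V = [], []

    def B_times(x):                      # B @ x
        out = -x
        for u, v in zip(U, V):
            out = out + u * (v @ x)
        return out

    def times_B(x):                      # x^T @ B, returned as a 1-D array
        out = -x
        for u, v in zip(U, V):
            out = out + (x @ u) * v
        return out

    z = z0.astype(float)
    g = f(z) - z
    for _ in range(max_iter):
        dz = -B_times(g)                 # quasi-Newton step; with B = -I this is z <- f(z)
        z_new = z + dz
        g_new = f(z_new) - z_new
        if np.linalg.norm(g_new) / (np.linalg.norm(g_new + z_new) + 1e-8) < tol:
            return z_new
        dg = g_new - g
        # Both variants satisfy the secant condition B_new @ dg = dz with a rank-one update.
        v = times_B(dz) if variant == "good" else dg
        denom = v @ dg
        if abs(denom) > 1e-12:           # skip degenerate updates
            U.append((dz - B_times(dg)) / denom)
            V.append(v)
        z, g = z_new, g_new
    return z
```

Each iteration appends two length-d vectors, so memory grows linearly with the number of steps instead of quadratically with d.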

Let me know if you want me to clarify further!

Thank you very much for your help. I'm trying out the techniques you introduced.

Another question: when I train on multiple GPUs, the speedup is small. For example, when I train on 8 GPUs, the training time is only cut in half. Is this caused by too much CPU involvement in the fixed-point iteration? How can I fix it?

Thank you very much.

Hi @Maybe2022 ,

There shouldn't be many CPU ops in the fixed-point solving process; the Anderson/Broyden solvers run entirely on the GPU. Note that, in general, using 2x more GPUs doesn't give you a 2x speedup. There are a few possible causes of (and fixes for) what you observed:

  1. The speed could be bottlenecked by things like data loading. You may want to check that.
  2. The JFB idea I previously mentioned should make the backward pass almost free. Assuming the model is stable, it should make the training 2x faster.
  3. Check your PyTorch and CUDA versions. The latest DEQ implements its backward pass with a hook, which is only fully supported in PyTorch v1.10.
  4. There is some cost as you use more GPUs (mostly GPU context and communication cost). Therefore, as you use more GPUs, you probably want to also adjust the batch size accordingly.
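For point 1, one quick way to see whether data loading is the bottleneck is to time how long each iteration waits for the next batch versus how long the training step takes. This is a framework-agnostic sketch; profile_loader is a hypothetical helper, and loader and step stand in for your DataLoader and training step.

```python
import time


def profile_loader(loader, step, sync=None):
    """Split wall-clock time into time spent waiting on data vs. time spent in `step`.

    `sync` is an optional callable run after each step; on GPU you would pass
    something like torch.cuda.synchronize, since CUDA kernels run asynchronously
    and the step would otherwise appear to finish early.
    """
    sync = sync or (lambda: None)
    data_time = compute_time = 0.0
    it = iter(loader)
    while True:
        t0 = time.perf_counter()
        try:
            batch = next(it)             # time spent waiting for the next batch
        except StopIteration:
            break
        t1 = time.perf_counter()
        step(batch)                      # forward, fixed-point solve, backward, update
        sync()
        t2 = time.perf_counter()
        data_time += t1 - t0
        compute_time += t2 - t1
    return data_time, compute_time
```

If data_time is a large share of the total on each GPU, adding more GPUs will not help much until loading keeps up.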

Thank you very much.

Another problem is overfitting. When I train only on CIFAR-10, the training set converges normally, but the test accuracy peaks at only about 87%.

Is this related to the eps and threshold of the fixed-point iteration?

Thank you very much for your help.

I took a shortcut: I only rewrote the computation graph and did not apply the hook.

Not sure what exactly went wrong, but if you just use cls_mdeq_LARGE_reg.yaml you should expect ~93% accuracy pretty consistently.

If you are using Jacobian-free backprop, then yes, you do want to monitor the stability of the fixed-point iterations. With such instability, and because you are using inexact gradients, overfitting can indeed happen. Could you check that the solver converges properly?

Hi @jerrybai1995,

Thanks for your excellent work! I really enjoy this series of papers. Regarding Broyden's method, I noticed that you seem to be using the "bad" Broyden's method. May I ask whether you chose this version intentionally? Why not use the "good" one? (That said, from the literature I looked up, the two versions seem to have similar performance.) Would you mind sharing some insights on this choice? Thanks in advance.
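For reference, the two variants differ only in the rank-one correction applied to the inverse-Jacobian estimate H_k ≈ J_g^{-1} (textbook forms, as in the Wikipedia article on Broyden's method), where Δz_k = z_{k+1} - z_k and Δg_k = g(z_{k+1}) - g(z_k):

```latex
\text{good:}\quad H_{k+1} = H_k + \frac{(\Delta z_k - H_k \Delta g_k)\,\Delta z_k^\top H_k}{\Delta z_k^\top H_k \Delta g_k}
\qquad
\text{bad:}\quad H_{k+1} = H_k + \frac{(\Delta z_k - H_k \Delta g_k)\,\Delta g_k^\top}{\Delta g_k^\top \Delta g_k}
```

Both satisfy the secant condition H_{k+1} Δg_k = Δz_k and add one outer product per step, so either variant fits the low-rank -I + u1v1^T + u2v2^T + ... storage mentioned earlier in the thread.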

Hi @liu-jc ,

Thanks for the question. There is no specific reason why good Broyden was not picked (it was more of an empirical choice: when I started working on this project I tried both good and bad Broyden, found bad Broyden to be a bit better, and went on to write a PyTorch version of that one). But as you can see, Anderson acceleration eventually works very well too; and if you apply Jacobian regularization, even naive fixed-point iteration works quite well. So I'm sure good Broyden would work as well 😆

Hope this clarifies things!

Hi @jerrybai1995 ,

Thanks for your prompt answer! Now I understand. Thanks for your help again :-)