How does DEQ use less memory than an explicit network?
Hi, I have a question about the memory footprint of DEQ.
As far as I understand, DEQ does not need to store intermediate activations, and is thus able to approximate an infinite-layer model at the memory cost of only one layer (so training with NFE = 30 iterations should cost about as much as a single iteration).
However, in Table 3 of the first DEQ paper, the explicit 16-layer Transformer-XL consumes much more VRAM than DEQ-Transformer (medium).
They seem to have nearly the same architecture with nearly the same number of parameters. In this setting, as far as I understand, DEQ should perform better because it effectively models a much deeper network than its explicit counterpart while consuming the same amount of memory. Why does DEQ consume less VRAM? Shouldn't it be the same?
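To make sure my mental model is right, here is a rough sketch of where I think the memory goes (a toy layer and naive fixed-point iteration, not the actual code from this repo; the repo uses a quasi-Newton solver, I believe, but the memory argument should be the same):

```python
import torch
import torch.nn as nn

f = nn.Sequential(nn.Linear(64, 64), nn.Tanh())   # toy stand-in for f_theta
x = torch.randn(8, 64)

# Explicit weight-tied 16-layer stack: autograd keeps the activations of all
# 16 calls alive for the backward pass, so activation memory grows with depth.
z = torch.zeros_like(x)
for _ in range(16):
    z = f(z + x)

# DEQ-style forward: the fixed-point solve runs under no_grad, so no graph is
# built and activation memory stays at roughly one layer's worth, no matter
# how many solver iterations (NFE) are taken.
z_star = torch.zeros_like(x)
with torch.no_grad():
    for _ in range(30):          # NFE = 30
        z_star = f(z_star + x)
```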
I also noticed that the forward function of the DEQ transformer contains one regular explicit forward pass that tracks gradients:
deq/DEQ-Sequence/models/deq_transformer.py
Line 367 in 1fb7059
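For reference, the pattern I am talking about looks roughly like this in my own simplified sketch (toy layer, naive solvers, made-up names; not the actual code around line 367):

```python
import torch
import torch.nn as nn

f = nn.Sequential(nn.Linear(64, 64), nn.Tanh())   # toy stand-in for f_theta
x = torch.randn(8, 64)

# 1) Equilibrium solve without a graph (constant memory in the iteration count).
z = torch.zeros_like(x)
with torch.no_grad():
    for _ in range(30):
        z = f(z + x)

# 2) The single gradient-tracking forward: autograd records one application of
#    f at the equilibrium, which is all the backward pass needs.
z = z.detach().requires_grad_()
new_z = f(z + x)

# 3) Implicit backward: a hook that solves g = grad + g * (df/dz*) by
#    fixed-point iteration, instead of backpropagating through the 30 solver
#    steps above.
hook = None
def backward_hook(grad):
    hook.remove()                                   # avoid re-entering the hook
    g = grad
    for _ in range(30):
        g = torch.autograd.grad(new_z, z, g, retain_graph=True)[0] + grad
    return g
hook = new_z.register_hook(backward_hook)

new_z.sum().backward()   # f's parameters now hold the DEQ (implicit) gradients
```

My understanding is that this single tracked call is what the backward pass differentiates through via the implicit function theorem, which is why the solver iterations themselves never need to be stored.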