Implement cross-batch memory for losses
monatis opened this issue · 2 comments
monatis commented
- Paper: https://arxiv.org/pdf/1912.06798.pdf
- Reference for implementation: https://github.com/msight-tech/research-xbm/
How it works
- XBM relies on the observation that embedding drift is slow during training, i.e., embeddings of the same object change at a very slow pace.
- This lets us add embeddings and targets in a ring buffer of a certain size.
- After a certain number of iterations, start using the buffer. Now the final loss is the weighted sum of the actual mini-batch loss and the ring buffer loss.
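The mechanics above can be sketched in a few lines of Python. All names, the dummy loss, and the constants are illustrative assumptions; `collections.deque` with `maxlen` stands in for the ring buffer (oldest entries are evicted automatically when it is full):

```python
from collections import deque

BUFFER_SIZE = 4   # illustrative; real memory banks are much larger
XBM_WEIGHT = 0.5  # illustrative weight for the memory term
memory = deque(maxlen=BUFFER_SIZE)  # a deque with maxlen behaves as a ring buffer

def dummy_loss(batch, reference):
    # Stand-in for a pairwise metric-learning loss between two sets of embeddings.
    return sum(abs(b - r) for b in batch for r in reference) / (len(batch) * len(reference))

def xbm_loss(batch_embeddings, start_reached: bool):
    batch_term = dummy_loss(batch_embeddings, batch_embeddings)
    memory.extend(batch_embeddings)  # enqueue batch; oldest entries are evicted when full
    if not start_reached:            # before start_iteration: plain mini-batch loss
        return batch_term
    # After start_iteration: weighted sum of the batch loss and the memory-bank loss.
    memory_term = dummy_loss(batch_embeddings, list(memory))
    return batch_term + XBM_WEIGHT * memory_term
```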
Suggested implementation
- Introduce an `XBMConfig` class to hold the configuration values such as `buffer_size`, `start_iteration`, `xbm_weight`.
- Add a `configure_xbm()` hook in `TrainableModel` that returns `None` by default.
- If it returns an `XBMConfig` instance instead, create an `XBMBuffer` instance in the `TrainableModel` constructor.
- Implement the XBM logic in `_common_step` if `stage` is training.
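A minimal sketch of the configuration and hook wiring, assuming the class and field names from the list above (the defaults, the `MyModel` subclass, and the use of a plain deque in place of `XBMBuffer` are invented for illustration):

```python
from collections import deque
from dataclasses import dataclass
from typing import Optional


@dataclass
class XBMConfig:
    """Holds the XBM settings; field names follow the suggestion above."""
    buffer_size: int = 8192      # illustrative default
    start_iteration: int = 1000  # illustrative default
    xbm_weight: float = 1.0      # illustrative default


class TrainableModel:
    """Simplified stand-in for the real base class."""

    def __init__(self):
        config = self.configure_xbm()
        self._xbm_config = config
        # Create the buffer only when a subclass opts in by returning a config.
        self._xbm_buffer = deque(maxlen=config.buffer_size) if config else None

    def configure_xbm(self) -> Optional[XBMConfig]:
        """Hook: return an XBMConfig to enable XBM; None (default) disables it."""
        return None


class MyModel(TrainableModel):
    """Example subclass that opts into XBM."""

    def configure_xbm(self) -> Optional[XBMConfig]:
        return XBMConfig(buffer_size=4096, xbm_weight=0.5)
```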
Notes
- We cannot re-use the existing `Accumulator` classes because they are not ring buffers.
- I don't think we need a mixin because the addition to `TrainableModel` will be only a few lines of code, and we need to update `_common_step` anyway.
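Putting it together, the training branch of `_common_step` could look roughly like the following. Here `loss_fn` is a placeholder returning a float so the sketch runs without a deep-learning framework, and every name apart from the weighted-sum logic is an assumption:

```python
from collections import deque


class Model:
    """Stand-in showing only the XBM branch of a _common_step-style method."""

    def __init__(self, buffer_size=8, start_iteration=2, xbm_weight=0.5):
        self.start_iteration = start_iteration
        self.xbm_weight = xbm_weight
        self.global_step = 0
        self._xbm_buffer = deque(maxlen=buffer_size)  # ring buffer of (embedding, target)

    def loss_fn(self, embeddings, targets, ref_embeddings=None, ref_targets=None):
        # Placeholder: a real implementation computes a pairwise metric loss
        # between the batch and (optionally) the memory bank.
        ref = ref_embeddings if ref_embeddings is not None else embeddings
        return float(len(embeddings) * len(ref))

    def _common_step(self, embeddings, targets, stage):
        loss = self.loss_fn(embeddings, targets)
        if stage == "train":
            # Enqueue the current batch; in practice the embeddings would be
            # detached so the buffer does not keep autograd graphs alive.
            self._xbm_buffer.extend(zip(embeddings, targets))
            if self.global_step >= self.start_iteration:
                ref_embeddings, ref_targets = zip(*self._xbm_buffer)
                # Final loss = batch loss + weight * loss against the memory bank.
                loss += self.xbm_weight * self.loss_fn(
                    embeddings, targets, ref_embeddings, ref_targets
                )
            self.global_step += 1
        return loss
```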