qdrant/quaterion

Implement cross-batch memory for losses

monatis opened this issue · 2 comments

How it works

  • XBM relies on the observation that embedding drift is slow during training, i.e., the embedding of a given object changes only gradually between iterations.
  • This lets us store embeddings and their targets in a ring buffer of a fixed size.
  • After a certain number of iterations, start using the buffer. From that point on, the final loss is the weighted sum of the actual mini-batch loss and the ring buffer loss.
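The mechanism above can be sketched roughly as follows. This is a simplified illustration, not Quaterion code: in practice the buffer would hold detached PyTorch tensors, but plain Python lists keep the ring-buffer logic visible.

```python
class RingBuffer:
    """Sketch of XBM's cross-batch memory: a fixed-size ring buffer of
    (embedding, target) pairs. Oldest entries are overwritten first."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.embeddings = [None] * capacity
        self.targets = [None] * capacity
        self.ptr = 0      # next slot to overwrite
        self.filled = 0   # number of valid entries so far

    def update(self, embeddings, targets):
        # In a real implementation, embeddings would be detached tensors
        # so the buffer does not keep old computation graphs alive.
        for emb, tgt in zip(embeddings, targets):
            self.embeddings[self.ptr] = emb
            self.targets[self.ptr] = tgt
            self.ptr = (self.ptr + 1) % self.capacity
        self.filled = min(self.filled + len(embeddings), self.capacity)

    def get(self):
        # Return only the slots that have been written at least once
        return self.embeddings[: self.filled], self.targets[: self.filled]
```

Because embeddings drift slowly, entries written a few hundred iterations ago are still close to what the current model would produce, so they can safely serve as extra negatives/positives for the loss.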

Suggested implementation

  1. Introduce an XBMConfig class to hold the configuration values such as buffer_size, start_iteration, xbm_weight.
  2. Add a configure_xbm() hook in TrainableModel and return None by default.
  3. If it returns an XBMConfig instance instead, create an XBMBuffer instance in the TrainableModel constructor.
  4. Implement the XBM logic in _common_step when the stage is training.
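The config class and the weighted-sum step could look roughly like this. All names and default values here are illustrative assumptions, not the final API:

```python
from dataclasses import dataclass


@dataclass
class XBMConfig:
    """Hypothetical configuration for cross-batch memory."""
    buffer_size: int = 1024      # number of (embedding, target) pairs kept
    start_iteration: int = 100   # warm-up iterations before the buffer is used
    xbm_weight: float = 1.0      # weight of the memory loss term


def combined_loss(batch_loss: float,
                  xbm_loss: float,
                  iteration: int,
                  config: XBMConfig) -> float:
    """Weighted sum of the mini-batch loss and the ring-buffer loss,
    applied only after the warm-up period."""
    if iteration < config.start_iteration:
        return batch_loss
    return batch_loss + config.xbm_weight * xbm_loss
```

A `configure_xbm()` hook returning `None` by default keeps the feature opt-in: only models that override it to return an `XBMConfig` pay the memory cost of the buffer.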

Notes

  1. We cannot re-use the existing Accumulator classes because they are not ring buffers.
  2. I don't think we need a mixin because the addition to TrainableModel will be only a few lines of code, and we need to update _common_step anyway.

Suggested implementation is in the issue. WDYT? @generall and @joein

Completed in #175