apache/mxnet

gather_nd vs. take

mureva opened this issue

I have an observation and I'm hoping someone can advise.

I have a scenario where I maintain a large table of vectors: a basic (n, m) array holding n vectors of size m. Some other part of the system generates indices into this table, and I want to pull out the rows at those indices. (More background: we're building a hash-table version of NeRF.)

So I have a set of indices, and I want to gather the corresponding rows out of the table for use elsewhere. There are two operators in MXNet that will do the job: gather_nd and take.
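For row lookup the two operators expect the indices in different layouts: take uses the flat index vector along axis 0, while gather_nd treats the first axis of its index array as "which data dimension", so a row lookup needs the indices reshaped to (1, k). Assuming NDArray inputs, the MXNet calls would be roughly `mx.nd.take(table, idx, axis=0)` and `mx.nd.gather_nd(table, idx.reshape((1, -1)))`. The sketch below is only a NumPy reference of the two semantics (the helper names `take_rows` and `gather_nd_ref` are hypothetical), not MXNet's implementation, and says nothing about performance.

```python
import numpy as np

def take_rows(table, indices):
    # Same semantics as take(table, indices, axis=0): one output row per index.
    return np.take(table, indices, axis=0)

def gather_nd_ref(data, indices):
    # Reference semantics of gather_nd: indices has shape (M, Y0, ..., Yk) and
    # output[y0, ..., yk] = data[indices[0, y0, ...], ..., indices[M-1, y0, ...]].
    # Indexing a NumPy array with a tuple of M index arrays does exactly this.
    return data[tuple(indices)]

n, m, k = 1000, 16, 5000
rng = np.random.default_rng(0)
table = rng.standard_normal((n, m)).astype(np.float32)
idx = rng.integers(0, n, size=k)

# take() consumes the flat index vector directly; gather_nd needs it as (1, k),
# i.e. M = 1 leading axis that indexes only the first (row) dimension.
rows_take = take_rows(table, idx)
rows_gather = gather_nd_ref(table, idx.reshape(1, -1))

assert rows_take.shape == (k, m)
assert np.array_equal(rows_take, rows_gather)
```

Both paths return the same (k, m) block of rows; only the index layout differs.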

I could have more than 100k, or even 1000k, indices:

  • At 100k indices, take does a forward pass in less than 1 ms, but its backward pass takes about 45 ms. Meanwhile, gather_nd does a forward pass in about 16 ms and a backward pass in under 1 ms.
  • At 1000k indices, take is 4 ms forward and 400 ms backward; gather_nd is 170 ms forward and 1 ms backward.
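Mathematically, both operators have the same gradient with respect to the table: each output row's gradient is scatter-added back into the table row it was gathered from, and repeated indices must be summed rather than overwritten. The NumPy sketch below (hypothetical helper `gather_rows_backward`) shows only that reference computation; it does not model how MXNet's take or gather_nd backward kernels implement it, so it cannot explain the timing gap above.

```python
import numpy as np

def gather_rows_backward(grad_out, indices, n):
    # Gradient of out = table[indices] with respect to table (shape (n, m)).
    # np.add.at accumulates contributions for repeated indices; a plain
    # grad_table[indices] += grad_out would keep only one of them.
    grad_table = np.zeros((n, grad_out.shape[1]), dtype=grad_out.dtype)
    np.add.at(grad_table, indices, grad_out)
    return grad_table

# Tiny example: row 2 is gathered twice, so its gradient is the sum of both.
indices = np.array([2, 0, 2])
grad_out = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
grad_table = gather_rows_backward(grad_out, indices, n=4)
print(grad_table)
# [[2. 2.]
#  [0. 0.]
#  [4. 4.]
#  [0. 0.]]
```

With many duplicate indices (common when a hash table is hit repeatedly), this accumulation is the part a backward kernel has to get right.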

So, the obvious question: is there a way to get the best of both worlds here, i.e. the fast forward pass of take and the fast backward pass of gather_nd?

Is there a better operator for gathering rows from the table? I also tried Embedding: in my test it looked like the best of both worlds, but in the real app it was slow on the backward pass.