gather_nd vs. take
mureva opened this issue
I have an observation and I'm hoping someone can advise.
I have a scenario where I maintain a large table of vectors: a basic (n, m) array holding n vectors of size m. Some other part of the system generates indices into this table, and I want to pull out the rows at those indices. (More background: we're building a hash-table version of NeRF.)
So I have a set of indices, and I want to gather the corresponding rows out of the table to use elsewhere. There are two operators in MXNet that do the job: `gather_nd` and `take`.
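For reference, here is a minimal NumPy sketch of what the two operators compute in this row-gather case. The helpers `take_rows` and `gather_nd_rows` are hypothetical and only mimic the index layouts, assuming MXNet's `take(table, idx, axis=0)` gathers rows from a flat index vector, while `gather_nd(table, idx)` expects indices of shape (1, k) when only the first axis is indexed. NumPy stands in for MXNet here, so this shows semantics, not speed.

```python
import numpy as np

def take_rows(table, idx):
    # take-style (hypothetical helper): a flat vector of row indices,
    # gathered along axis 0.
    return np.take(table, idx, axis=0)

def gather_nd_rows(table, idx):
    # gather_nd-style (hypothetical helper): indices shaped (M, k), where the
    # first axis M says how many leading axes of `table` are indexed.
    # Gathering whole rows means M == 1.
    nd_idx = idx.reshape(1, -1)
    return table[tuple(nd_idx)]

n, m = 8, 4
table = np.arange(n * m, dtype=np.float32).reshape(n, m)
idx = np.array([3, 0, 3, 7])  # repeated indices are allowed

a = take_rows(table, idx)
b = gather_nd_rows(table, idx)
print(a.shape, np.array_equal(a, b))
```

Both calls return one row of `table` per index, so the output shape is (k, m) in either case; the only difference is how the indices are laid out.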
The number of indices can exceed 100k, and even reach 1000k:
- At 100k indices, `take` does a forward pass in under 1 ms, but its backward pass takes about 45 ms. Meanwhile, `gather_nd` does a forward pass in about 16 ms and a backward pass in under 1 ms.
- At 1000k indices, `take` needs 4 ms forward and 400 ms backward, while `gather_nd` needs 170 ms forward and 1 ms backward.
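As a rough check on how the slow paths scale, the per-index cost can be worked out from the numbers quoted above (entries given only as "under 1 ms" are skipped because they have no exact value). Going from 100k to 1000k indices multiplies `take`'s backward time by about 8.9 and `gather_nd`'s forward time by about 10.6, so both grow roughly linearly: about 0.4 to 0.45 µs per index for `take` backward and 0.16 to 0.17 µs per index for `gather_nd` forward. The `scaling` helper is hypothetical, just arithmetic on the reported timings.

```python
# Timings reported above, in milliseconds, keyed by number of indices.
# Entries reported only as "under 1 ms" are omitted.
timings_ms = {
    ("take", "backward"): {100_000: 45.0, 1_000_000: 400.0},
    ("gather_nd", "forward"): {100_000: 16.0, 1_000_000: 170.0},
}

def scaling(t):
    # Growth factor when the index count goes from 100k to 1000k,
    # plus the cost per index in microseconds at each size.
    growth = t[1_000_000] / t[100_000]
    per_index_us = {k: 1000.0 * v / k for k, v in t.items()}
    return growth, per_index_us

for (op, direction), t in timings_ms.items():
    growth, per_index_us = scaling(t)
    print(op, direction, f"growth x{growth:.1f}", per_index_us)
```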
So, the obvious question: is there a way to get the best of both worlds here, i.e. the fast forward pass of `take` together with the fast backward pass of `gather_nd`?
Is there a better operator for gathering rows from a table? I also tried `Embedding`: in my test it looked like the best of both worlds, but in the real application its backward pass was slow.
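Mathematically, `take`, `gather_nd` and `Embedding` all have the same backward with respect to the table: for `out = table[idx]`, the gradient is an (n, m) array built by scatter-adding row i of the output gradient into row `idx[i]`, accumulating wherever indices repeat. The NumPy sketch below shows two ways to compute that scatter-add: a direct `np.add.at`, and a sort-then-segment-sum variant. The second is a common alternative strategy for embedding-style gradients, swapped in here for illustration; it is not what MXNet is known to use. Function names are hypothetical, and the sketch runs on the CPU, so it does not predict GPU timings. If pairing kernels is worth trying, MXNet 1.x's `mx.autograd.Function` lets a custom operator define its forward and backward separately (e.g. a `take` forward with a hand-picked backward); whether that beats the built-in kernels is an open assumption.

```python
import numpy as np

def gather_rows_backward_scatter(grad_out, idx, n):
    # Gradient of out = table[idx] w.r.t. table (hypothetical helper):
    # scatter-add each output-gradient row into its source row.
    # np.add.at is unbuffered, so repeated indices accumulate correctly.
    grad_table = np.zeros((n, grad_out.shape[1]), dtype=grad_out.dtype)
    np.add.at(grad_table, idx, grad_out)
    return grad_table

def gather_rows_backward_sorted(grad_out, idx, n):
    # Alternative (hypothetical helper, assumes idx is non-empty): sort by
    # index so duplicates become contiguous, sum each contiguous segment
    # once, then write every distinct row exactly once.
    order = np.argsort(idx, kind="stable")
    sorted_idx = idx[order]
    # Start positions of each run of equal indices in the sorted order.
    starts = np.flatnonzero(np.r_[True, sorted_idx[1:] != sorted_idx[:-1]])
    seg_sums = np.add.reduceat(grad_out[order], starts, axis=0)
    grad_table = np.zeros((n, grad_out.shape[1]), dtype=grad_out.dtype)
    grad_table[sorted_idx[starts]] = seg_sums
    return grad_table

rng = np.random.default_rng(0)
n, m, k = 10, 3, 50
idx = rng.integers(0, n, size=k)
grad_out = rng.standard_normal((k, m))

g1 = gather_rows_backward_scatter(grad_out, idx, n)
g2 = gather_rows_backward_sorted(grad_out, idx, n)
print(g1.shape, np.allclose(g1, g2))
```

The sorted variant trades an extra sort for writing each table row once instead of accumulating into it repeatedly; which is cheaper depends on the index distribution and the hardware.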