google-research/scenic

Unknown term in `sample_permutation` for sinkhorn

MasterSkepticista opened this issue · 1 comment

Hi, I am trying to use the Sinkhorn matcher for batch sizes >8 (per device) as in the code. It (understandably) fails at #L62, since bs * dim exceeds the available sampling range of 10 * dim.

ValueError: Cannot take a larger sample than population when 'replace=False'

v = jax.random.choice(key, 10 * dim, shape=(bs, dim), replace=False)
v = jnp.sort(v, axis=-1) * 10.

I could replace 10 with bs or higher to get it to work. I have a few questions, though:

  • Is 10 arbitrary?
  • Why is the sorted vector multiplied by 10 in #L63? If this is a specific choice, could you point me to the relevant literature (including for the previous question)?
  • If 10 is arbitrary, can I replace 10 -> bs in both #L62 and #L63 while maintaining algorithmic correctness?
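
For reference, here is a minimal sketch that reproduces the failure and shows one possible workaround: drawing each batch row independently without replacement, so the population only needs to cover dim samples rather than bs * dim. The helper name `sample_rows_without_replacement` and its `scale` parameter are my own placeholders, not part of scenic. Note that this is not equivalent to the original call, which draws values that are distinct across the whole batch, and I have left out the `* 10.` scaling since its purpose is exactly what I am asking about.

```python
import jax
import jax.numpy as jnp


def sample_rows_without_replacement(key, bs, dim, scale=10):
    """Hypothetical helper: `dim` distinct integers per row from range(scale * dim)."""
    keys = jax.random.split(key, bs)  # one independent key per batch row

    def draw_row(k):
        # Distinct within a row only; different rows may share values.
        return jax.random.choice(k, scale * dim, shape=(dim,), replace=False)

    v = jax.vmap(draw_row)(keys)  # shape (bs, dim)
    return jnp.sort(v, axis=-1)


key = jax.random.PRNGKey(0)
bs, dim = 16, 4

# The call from #L62 asks for bs * dim distinct values out of 10 * dim,
# so it raises once bs * dim > 10 * dim.
try:
    jax.random.choice(key, 10 * dim, shape=(bs, dim), replace=False)
    failed = False
except ValueError:
    failed = True

v = sample_rows_without_replacement(key, bs, dim)
print(failed, v.shape)
```

With this variant the batch size no longer limits the sample, but whether per-row sampling keeps the algorithm correct is an assumption on my side.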

Could you please cite the source? This otherwise feels like an AI-generated answer.