sbrunk/storch

torch.randint parameter high is an Int and size is a Seq

Closed this issue · 1 comment

hmf commented

Minor quibble. On inspection, the code I find is:

  def randint(low: Long, high: Int, size: Seq[Int]) =
    // TODO Handle Optional Generators properly
    val generator = new org.bytedeco.pytorch.GeneratorOptional()
    Tensor(
      torchNative.torch_randint(low, high, size.toArray.map(_.toLong), generator)
    )

Shouldn't high be a Long?
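To make the concern concrete: as far as I know, libtorch's randint takes 64-bit bounds (int64_t in the C++ API), so declaring high as Int rules out any bound above Int.MaxValue. A caller holding a Long bound would have to narrow it first. Here is a minimal Scala 3 sketch of what that narrowing does. The values are arbitrary and not taken from storch:

```scala
// A bound that is valid for a 64-bit randint but does not fit in an Int.
val requestedHigh: Long = 5_000_000_000L

// Forcing it into the current `high: Int` parameter means narrowing,
// and Long.toInt keeps only the low 32 bits: it wraps silently, no error.
val narrowed: Int = requestedHigh.toInt

// A simple range check shows the value never fit in the first place.
val fits: Boolean = requestedHigh <= Int.MaxValue

println(s"$requestedHigh -> $narrowed (fits in Int: $fits)")
```

Here narrowed ends up as 705032704, which is 5000000000 minus 2^32. A randint call with that value would quietly sample from the wrong range.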

Question: would it not be better to have size be an array and avoid the copy made by size.toArray.map(_.toLong)?

sbrunk commented

Yes, high should be a Long too, good catch.

As for size, I'd argue it's more important to have a convenient API that accepts any kind of Seq. It might even make sense to extend this to also accept tuples of Ints. The critical path, where performance really counts, is always in the tensor operations themselves.
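One way to get that convenience is a type class. The following is a minimal Scala 3 sketch (script/REPL style) that assumes a hypothetical ToShape type class, which is not part of storch. It converts a Seq[Int], a tuple of Ints, or an existing Array[Long] into the Array[Long] of dimensions the native call needs. randintArgs is also hypothetical: it returns the arguments that would be forwarded to torchNative.torch_randint instead of calling libtorch, so the sketch runs on its own. It takes high as a Long, as agreed above.

```scala
// Hypothetical type class: anything that can describe a tensor shape.
trait ToShape[A]:
  def toShape(a: A): Array[Long]

object ToShape:
  // Any Seq[Int] (List, Vector, ...): one conversion pass into Array[Long].
  given seqShape[S <: Seq[Int]]: ToShape[S] with
    def toShape(a: S): Array[Long] = a.iterator.map(_.toLong).toArray

  // Already an Array[Long]: passed through as-is, no copy
  // (the caller's array is shared, so it should not be mutated afterwards).
  given arrayShape: ToShape[Array[Long]] with
    def toShape(a: Array[Long]): Array[Long] = a

  // Tuples of Ints for the common 2-D and 3-D cases.
  given tuple2Shape: ToShape[(Int, Int)] with
    def toShape(a: (Int, Int)): Array[Long] = Array(a._1.toLong, a._2.toLong)

  given tuple3Shape: ToShape[(Int, Int, Int)] with
    def toShape(a: (Int, Int, Int)): Array[Long] =
      Array(a._1.toLong, a._2.toLong, a._3.toLong)

// Hypothetical stand-in for randint: returns (low, high, dims) instead of
// calling torchNative.torch_randint, so no libtorch is needed to run it.
def randintArgs[S](low: Long, high: Long, size: S)(using shape: ToShape[S]): (Long, Long, Array[Long]) =
  (low, high, shape.toShape(size))

println(randintArgs(0, 10, List(2, 3))._3.mkString("[", ", ", "]"))
println(randintArgs(0, 10, (2, 3, 4))._3.mkString("[", ", ", "]"))
```

The two calls print [2, 3] and [2, 3, 4]. Only the Array[Long] instance avoids the conversion. For the Seq and tuple cases, the copy is a handful of longs, which fits the argument that convenience matters more here than in the tensor operations.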