torch.randint parameter high is an Int and size is a Seq
Closed this issue · 1 comments
hmf commented
Minor quibble. On inspection, the code I find is:
```scala
def randint(low: Long, high: Int, size: Seq[Int]) =
  // TODO Handle Optional Generators properly
  val generator = new org.bytedeco.pytorch.GeneratorOptional()
  Tensor(
    torchNative.torch_randint(low, high, size.toArray.map(_.toLong), generator)
  )
```
Shouldn't `high` be a `Long`?
Question: would it not be better to have `size` be an array and avoid a copy?
sbrunk commented
Yes, `high` should be a `Long` too, good catch.
As for `size`, I'd argue it's more important to have a convenient API that accepts any kind of `Seq`. It might even make sense to extend this to also accept tuples of ints. The critical path where performance really counts is always in tensor operations.
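A minimal sketch of the proposed fix, with `high` widened to `Long` and `size` kept as a `Seq[Int]`. The `torchNative.torch_randint` call is stubbed out here (as `randintShape`, a hypothetical helper name) so the snippet runs without `org.bytedeco.pytorch` on the classpath; only the signature and the `Seq` → `Array[Long]` conversion from the original code are shown:

```scala
// Sketch only: `high` is now Long; `size` stays a Seq[Int] for API convenience.
// The native call is stubbed out; this helper just performs the conversion
// that precedes it in the real binding.
def randintShape(low: Long, high: Long, size: Seq[Int]): Array[Long] =
  // The copy the question mentions happens here: toArray allocates an
  // Array[Int], and map then allocates the final Array[Long].
  size.toArray.map(_.toLong)

@main def demo(): Unit =
  // Any Seq-like collection works: List, Vector, Range, ...
  println(randintShape(0L, 10L, List(2, 3)).mkString(","))
  // A Long high now admits values above Int.MaxValue:
  println(randintShape(0L, Int.MaxValue.toLong + 1, 1 to 2).mkString(","))
```

The copy is unavoidable either way, since the native side needs an `Array[Long]` while callers typically hold `Int` sizes; accepting any `Seq` at least keeps the call sites convenient.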