csarofeen/pytorch

Feature request: iota prim


The prim should be consistent with PyTorch's definition here:

https://github.com/pytorch/pytorch/blob/16387bee4ac6bdaaf419d90425fb82d161e10bb3/torch/_prims/__init__.py#L2364
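For reference, here is a minimal pure-Python sketch of the 1-D semantics as I understand PyTorch's `prims.iota`: element `i` is `start + i * step` for `0 <= i < length`, integer values only, and a zero `step` is rejected. This should roughly match `torch.arange(start, start + length * step, step)`. The function name `iota_reference` is hypothetical. It returns a plain list instead of a tensor, and it leaves out the `dtype`, `device`, and `requires_grad` arguments that the real prim takes.

```python
from typing import List


def iota_reference(length: int, *, start: int, step: int) -> List[int]:
    """Hypothetical reference for a 1-D iota modeled on PyTorch's prims.iota.

    Returns [start, start + step, ..., start + (length - 1) * step].
    dtype/device/requires_grad are deliberately omitted in this sketch.
    """
    if length < 0:
        # Assumption: a negative length is an error (a tensor can't have one).
        raise ValueError("length must be non-negative")
    if step == 0:
        # Zero steps are rejected, mirroring the check in PyTorch's prim.
        raise ValueError("step must be nonzero")
    return [start + i * step for i in range(length)]
```

For example, `iota_reference(5, start=2, step=3)` yields `[2, 5, 8, 11, 14]`, and a negative step counts down.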

Note that PyTorch's iota generalizes JAX's iota, which is itself a simplified version of XLA's iota. An iota primitive that is a superset of both XLA's and PyTorch's behavior would also be fine.
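One way to picture such a superset is the hypothetical `iota_nd` sketch below, written in pure Python on nested lists. It takes an N-D `shape` and an iota `dimension`, as XLA's `Iota` and `jax.lax.broadcasted_iota` do, plus PyTorch's `start` and `step`. Along `dimension` the value at index `i` is `start + i * step`, and that value is broadcast across all other dimensions. With `start=0, step=1` it corresponds to XLA's iota. With a 1-D shape it corresponds to PyTorch's prim. The name, argument order, defaults, and the decision to reject `step == 0` are all assumptions for illustration, not an agreed design.

```python
from typing import Any, Sequence, Tuple


def iota_nd(shape: Sequence[int], dimension: int = 0, *,
            start: int = 0, step: int = 1) -> Any:
    """Hypothetical N-D iota: value at index idx is start + idx[dimension] * step.

    Returns nested Python lists with the given shape (a stand-in for a tensor).
    """
    if len(shape) == 0:
        raise ValueError("shape must have at least one dimension")
    if not 0 <= dimension < len(shape):
        raise ValueError("dimension out of range")
    if any(s < 0 for s in shape):
        raise ValueError("shape entries must be non-negative")
    if step == 0:
        # Assumption: keep PyTorch's nonzero-step requirement.
        raise ValueError("step must be nonzero")

    def build(prefix: Tuple[int, ...]) -> Any:
        # prefix holds the indices chosen so far, one per leading axis.
        axis = len(prefix)
        if axis == len(shape):
            # Leaf: only the index along `dimension` determines the value.
            return start + prefix[dimension] * step
        return [build(prefix + (i,)) for i in range(shape[axis])]

    return build(())
```

For example, `iota_nd((2, 3), 1)` gives `[[0, 1, 2], [0, 1, 2]]` and `iota_nd((2, 3), 0)` gives `[[0, 0, 0], [1, 1, 1]]`. A real primitive would of course produce a tensor and carry dtype and device information.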