Incorrect grid shape in `utils.gridapply` for one-dimensional space
francois-rozet opened this issue · 0 comments
Description
When the domain is one-dimensional, `lampe.utils.gridapply` builds a grid of shape `(bins,)` instead of the expected `(bins, 1)`.
Reproduce
In the following error, `mat1` is `x` and should be of shape `(128, 1)`.
>>> import torch
>>> import lampe
>>> A = torch.randn(1, 3)
>>> f = lambda x: x @ A
>>> domain = torch.zeros(1), torch.ones(1)
>>> lampe.utils.gridapply(f, domain)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/francois/Documents/Git/lampe/lampe/utils.py", line 104, in gridapply
    y = [f(x) for x in grid.split(batch_size)]
  File "/home/francois/Documents/Git/lampe/lampe/utils.py", line 104, in <listcomp>
    y = [f(x) for x in grid.split(batch_size)]
  File "<stdin>", line 1, in <lambda>
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x128 and 1x3)
Expected behavior
The grid should be of shape `(bins, 1)`.
Causes and solution
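The `lampe.utils.gridapply` function uses `torch.cartesian_prod`, which behaves inconsistently when given a single argument: it returns a 1-D tensor of shape `(bins,)` instead of a 2-D tensor of shape `(bins, 1)`. A reshape of the result to `(-1, D)`, where `D` is the number of dimensions of the domain, should be enough to fix the issue.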
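The reshape could look roughly like the sketch below. It is not the actual code of `gridapply`: the helper name `make_grid`, its signature, and the use of `torch.linspace` to place the grid points are assumptions made for illustration. Only the final `reshape` is the proposed fix.

```python
import torch


def make_grid(lower, upper, bins=128):
    """Hypothetical helper (not part of LAMPE): build a (bins**D, D) grid."""
    # One 1-D axis of `bins` points per dimension of the domain.
    # Assumption: points are evenly spaced, endpoints included.
    axes = [
        torch.linspace(lo, hi, bins)
        for lo, hi in zip(lower.tolist(), upper.tolist())
    ]

    grid = torch.cartesian_prod(*axes)

    # With a single axis the result is 1-D, so force the trailing
    # dimension. For D > 1 this reshape leaves the grid unchanged.
    return grid.reshape(-1, len(axes))


# Same setup as in the reproduction: a one-dimensional domain.
A = torch.randn(1, 3)
grid = make_grid(torch.zeros(1), torch.ones(1))
y = grid @ A  # the matrix product is now well defined

print(grid.shape, y.shape)
```

Since the reshape does not change the grid when the domain has two or more dimensions, it can be applied unconditionally, without special-casing the one-dimensional domain.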
Environment
- LAMPE version: 0.6.1
- PyTorch version: 1.12.0
- Python version: 3.9.15
- OS: Ubuntu 22.10