[BUG]: Piecewise not in torch_mappings
tbuckworth opened this issue · 5 comments
What happened?
After fitting a PySR model with "greater" as a binary operator, exporting to torch failed with the following error:
KeyError: 'Function Piecewise was not found in Torch function mappings.Please add it to extra_torch_mappings in the format, e.g., {sympy.sqrt: torch.sqrt}.'
I've seen that in #433 Piecewise was added to the mappings, so I'm surprised to see this error.
I did attempt to fix it myself, but it didn't work out. I tried adding mappings such as:
{sympy.Piecewise: lambda x, y: torch.where(x[1], x[0], y[0])}
but then the same error arose for sympy.functions.elementary.piecewise.ExprCondPair, and then for sympy.logic.boolalg.BooleanTrue. In the end, I added:
extra_torch_mappings = {
    sympy.Piecewise: lambda x, y: torch.where(x[1], x[0], y[0]),
    sympy.functions.elementary.piecewise.ExprCondPair: tuple,
    sympy.logic.boolalg.BooleanTrue: torch.BoolTensor,
    "greater": lambda x, y: torch.where(x > y, 1.0, 0.0),
}
But even this produced the following error:
KeyError: 'Function ITE was not found in Torch function mappings.Please add it to extra_torch_mappings in the format, e.g., {sympy.sqrt: torch.sqrt}.'
Hopefully, I am missing something obvious?
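For reference, here is a minimal sketch of how I am passing these mappings in, assuming the standard PySRRegressor constructor arguments (the operators and the commented-out fit/export calls are placeholders):

import sympy
import torch
from pysr import PySRRegressor

model = PySRRegressor(
    binary_operators=["+", "*", "greater"],
    # The mappings attempted above, supplied at construction time:
    extra_torch_mappings={
        sympy.Piecewise: lambda x, y: torch.where(x[1], x[0], y[0]),
        sympy.functions.elementary.piecewise.ExprCondPair: tuple,
        sympy.logic.boolalg.BooleanTrue: torch.BoolTensor,
    },
)
# model.fit(X, y)           # fit as usual
# module = model.pytorch()  # this export step is what raises the KeyError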
Version
0.18.4
Operating System
Linux
Package Manager
pip
Interface
Script (i.e., python my_script.py)
Relevant log output
No response
Extra Info
No response
I just realised that #433 is a pull request, so I copied the code and used it to add the mappings manually.
However, I'm still getting the error:
KeyError: 'Function ITE was not found in Torch function mappings.Please add it to extra_torch_mappings in the format, e.g., {sympy.sqrt: torch.sqrt}.'
I've added this mapping, which seems to circumvent the error, but I haven't fully tested it yet:
def if_then_else(*conds):
    a, b, c = conds
    return torch.where(a, torch.where(b, True, False), torch.where(c, True, False))

extra_torch_mappings = {sympy.logic.boolalg.ITE: if_then_else}
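As a quick sanity check of that mapping in isolation (a rough sketch with made-up tensors; ITE(a, b, c) should read as "b if a else c"):

import torch

def if_then_else(*conds):
    a, b, c = conds
    return torch.where(a, torch.where(b, True, False), torch.where(c, True, False))

a = torch.tensor([True, False, True])   # condition
b = torch.tensor([True, True, False])   # branch taken where a is True
c = torch.tensor([False, True, True])   # branch taken where a is False
print(if_then_else(a, b, c))            # tensor([ True,  True, False])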
Nice! Yeah that should be added to the GitHub pull request. Feel free to suggest that on the PR via the review system and you will be credited as a coauthor of the PR.
Thanks! I'll add a review comment on the PR.
There was another error with piecewise, when cond is a float (1.), but I fixed it by replacing cond with cond.bool():
output += torch.where(
    cond.bool() & ~already_used, expr, torch.zeros_like(expr)
)
already_used = already_used | cond.bool()
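The root cause is just that torch.where (and the & / ~ operators) expect a boolean condition tensor, while the mapped comparison yields floats. A tiny illustration with made-up values:

import torch

cond = torch.tensor([1.0, 0.0, 1.0])              # float "condition" from the mapped comparison
expr = torch.tensor([10.0, 20.0, 30.0])
already_used = torch.zeros(3, dtype=torch.bool)

# torch.where(cond & ~already_used, ...)          # fails: bitwise ops are not defined for float tensors
output = torch.where(cond.bool() & ~already_used, expr, torch.zeros_like(expr))
already_used = already_used | cond.bool()
print(output)                                     # tensor([10.,  0., 30.])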
Now, as long as I use a single batch dimension, it works, but multiple batch dimensions fail.
I believe this is due to export_torch.py, where _SingleSymPyModule.forward is:
def forward(self, X):
    if self._selection is not None:
        X = X[:, self._selection]
    symbols = {symbol: X[:, i] for i, symbol in enumerate(self.symbols_in)}
    return self._node(symbols)
If X[:, is replaced with X[..., then I believe it will work. This is a separate issue though, I suppose.
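A quick illustration of the difference with an extra batch dimension (the shapes here are arbitrary):

import torch

X = torch.randn(4, 5, 3)     # two batch dimensions, 3 features

print(X[:, 0].shape)         # torch.Size([4, 3]) -- slices the second batch dim, not the feature axis
print(X[..., 0].shape)       # torch.Size([4, 5]) -- always slices the last (feature) axis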
(Just leaving it open until that PR is closed, since there are still some TODO items)