MilesCranmer/PySR

[BUG]: Piecewise not in torch_mappings

tbuckworth opened this issue · 5 comments

What happened?

After fitting a PySR model with "greater" as a binary operator, exporting to torch failed with the following error:

KeyError: 'Function Piecewise was not found in Torch function mappings.Please add it to extra_torch_mappings in the format, e.g., {sympy.sqrt: torch.sqrt}.'

I've seen that in #433 Piecewise was added to the mappings, so I'm surprised to see this error.

I tried to fix it myself, without success. I first tried adding mappings such as:

{sympy.Piecewise: lambda x, y: torch.where(x[1], x[0], y[0])}

but then the same error arises for sympy.functions.elementary.piecewise.ExprCondPair, and then for sympy.logic.boolalg.BooleanTrue.

In the end, I added:

extra_torch_mappings = {
    sympy.Piecewise: lambda x, y: torch.where(x[1], x[0], y[0]),
    sympy.functions.elementary.piecewise.ExprCondPair: tuple,
    sympy.logic.boolalg.BooleanTrue: torch.BoolTensor,
    "greater": lambda x, y: torch.where(x > y, 1.0, 0.0),
}

But even this produced the following error:

KeyError: 'Function ITE was not found in Torch function mappings.Please add it to extra_torch_mappings in the format, e.g., {sympy.sqrt: torch.sqrt}.'

Hopefully, I am missing something obvious?

Version

0.18.4

Operating System

Linux

Package Manager

pip

Interface

Script (i.e., python my_script.py)

I just realised that #433 is a pull request, so I copied the code and used it to add the mappings manually.
However, I'm still getting the error:
KeyError: 'Function ITE was not found in Torch function mappings.Please add it to extra_torch_mappings in the format, e.g., {sympy.sqrt: torch.sqrt}.'

I've added this mapping, which seems to circumvent the error, but I haven't fully tested it yet:

def if_then_else(*conds):
    a, b, c = conds
    return torch.where(a, torch.where(b, True, False), torch.where(c, True, False))

extra_torch_mappings = {sympy.logic.boolalg.ITE: if_then_else}
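For anyone checking this mapping, here is a minimal sketch of the same idea, assuming all three ITE arguments can be coerced to boolean tensors. It is not the code from the PR; it only compares the `torch.where` form against the logical identity ITE(a, b, c) = (a and b) or (not a and c) on a small truth table.

```python
import torch


def if_then_else(a, b, c):
    # Coerce every argument to a boolean tensor so that plain Python
    # bools, floats (0./1.), and tensors are all accepted.
    a, b, c = (torch.as_tensor(t).bool() for t in (a, b, c))
    # Where `a` holds, take `b`; otherwise take `c`.
    return torch.where(a, b, c)


a = torch.tensor([True, True, False, False])
b = torch.tensor([True, False, True, False])
c = torch.tensor([False, False, True, True])

result = if_then_else(a, b, c)
reference = (a & b) | (~a & c)
```

Here `result` and `reference` should agree element-wise, and float inputs such as `1.0` are treated as `True`.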

Nice! Yeah that should be added to the GitHub pull request. Feel free to suggest that on the PR via the review system and you will be credited as a coauthor of the PR.

Thanks! I'll add a review comment on the PR.

There was another error with piecewise, when cond is a float (1.), but I fixed it by replacing cond with cond.bool():

output += torch.where(
    cond.bool() & ~already_used, expr, torch.zeros_like(expr)
)
already_used = already_used | cond.bool()
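To make the effect of the `.bool()` cast easier to see in isolation, here is a self-contained sketch of a first-match-wins piecewise evaluation. The function name `piecewise` and its signature are hypothetical and are not taken from the PR; the only assumption is that each branch arrives as an `(expr, cond)` pair, where `cond` may be a bool, a float such as `1.`, or a tensor.

```python
import torch


def piecewise(*expr_conds):
    # Hypothetical helper: evaluates (expr, cond) pairs in order,
    # and the first pair whose condition holds wins.
    exprs = [torch.as_tensor(e, dtype=torch.float32) for e, _ in expr_conds]
    # Casting to bool is what lets float conditions such as 1. work.
    conds = [torch.as_tensor(c).bool() for _, c in expr_conds]
    shape = torch.broadcast_shapes(*(t.shape for t in exprs + conds))
    output = torch.zeros(shape)
    already_used = torch.zeros(shape, dtype=torch.bool)
    for expr, cond in zip(exprs, conds):
        take = cond & ~already_used           # true here, no earlier match
        output = torch.where(take, expr, output)
        already_used = already_used | cond
    return output


x = torch.tensor([-1.0, 0.5, 2.0])
# First branch uses a float-valued condition; the fallback uses 1.0 ("True").
y = piecewise((2 * x, (x > 0).float()), (torch.tensor(-3.0), 1.0))
```

With these inputs the first branch applies where `x > 0` and the fallback applies elsewhere, so `y` is `[-3.0, 1.0, 4.0]`.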

Now, as long as I use a single batch dimension, it works, but multiple batch dimensions fail.

I believe this is due to export_torch.py, where _SingleSymPyModule.forward is:

            def forward(self, X):
                if self._selection is not None:
                    X = X[:, self._selection]
                symbols = {symbol: X[:, i] for i, symbol in enumerate(self.symbols_in)}
                return self._node(symbols)

If X[:, is replaced with X[..., then I believe it will work. This is a separate issue, though, I suppose.

(Just leaving it open until that PR is closed, since there are still some TODO items)