Add support for multiplying a dataclass by a scalar
ethancaballero opened this issue
ethancaballero commented
For example, support for:

```python
model = model - 0.1 * dmodel
```

inside a function decorated with:

```python
@myia(backend='pytorch', backend_options={'device': device_type}, return_backend=True)
```
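For reference, here is a sketch of the requested semantics in plain Python, outside of `@myia`. The dataclass defines the arithmetic methods itself. `Params` and its fields `w` and `b` are hypothetical stand-ins for a model, and this is not myia code:

```python
from dataclasses import dataclass


@dataclass
class Params:
    # Hypothetical model parameters; any numeric fields would do
    w: float
    b: float

    def __sub__(self, other):
        # Params - Params: subtract field by field
        if not isinstance(other, Params):
            return NotImplemented
        return Params(self.w - other.w, self.b - other.b)

    def __mul__(self, scalar):
        # Params * scalar: scale every field
        if not isinstance(scalar, (int, float)):
            return NotImplemented
        return Params(self.w * scalar, self.b * scalar)

    # scalar * Params: float.__mul__ returns NotImplemented for a Params
    # argument, so Python falls back to this reflected method
    __rmul__ = __mul__


model = Params(w=1.0, b=0.5)
dmodel = Params(w=2.0, b=-1.0)
model = model - 0.1 * dmodel
print(model)
```

In the `0.1 * dmodel` step, the scalar is on the left. That case depends on `__rmul__`, not `__mul__`.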
breuleux commented
Using `myia.hypermap.HyperMap` instead of `myia.composite.Elemwise` (with the appropriate leaf function) would work as a quick fix. However, it would bypass Python's `__add__`/`__radd__` protocol entirely, which may not be the right thing here.
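Below is a pure-Python analogue of the HyperMap approach, which shows the trade-off. It is not myia's actual `HyperMap` API. The `hypermap` helper is hypothetical: it recurses into dataclass fields and applies a leaf function to the scalars, broadcasting scalars against dataclasses. It never calls the dataclass's own arithmetic methods, which is the protocol bypass mentioned above.

```python
import dataclasses
import operator


def _is_dc_instance(x):
    # True for dataclass instances, False for dataclass types themselves
    return dataclasses.is_dataclass(x) and not isinstance(x, type)


def hypermap(fn, *args):
    """Hypothetical helper: apply fn leaf-wise through dataclass fields."""
    dcs = [a for a in args if _is_dc_instance(a)]
    if not dcs:
        # Only leaves left: apply the leaf function directly
        return fn(*args)
    template = dcs[0]
    new_fields = {}
    for f in dataclasses.fields(template):
        # Take the matching field from each dataclass; broadcast scalars
        sub = [getattr(a, f.name) if _is_dc_instance(a) else a for a in args]
        new_fields[f.name] = hypermap(fn, *sub)
    return dataclasses.replace(template, **new_fields)


@dataclasses.dataclass
class Params:
    # Hypothetical model; it defines no arithmetic methods at all
    w: float
    b: float


model = Params(1.0, 0.5)
dmodel = Params(2.0, -1.0)
step = hypermap(operator.mul, 0.1, dmodel)
model = hypermap(operator.sub, model, step)
print(model)
```

This works even though `Params` has no `__mul__` or `__sub__`. A dataclass that overrides those methods would therefore have its overrides silently ignored.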
Alternatively, we could implement Python's protocol, which is essentially this:
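```python
def add(x, y):
    # Try the left operand's __add__ first
    if hasattr(x, '__add__'):
        z = x.__add__(y)
    else:
        z = NotImplemented
    # Fall back to the right operand's reflected __radd__
    if z is NotImplemented and hasattr(y, '__radd__'):
        z = y.__radd__(x)
    # (CPython raises TypeError if z is still NotImplemented here)
    return z
```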
We'd just need to add support for `NotImplemented`, `hasattr`, and some way to check for `NotImplemented`.
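The scalar case in the original request depends on exactly this fallback. When `float.__mul__` receives a dataclass, it returns `NotImplemented`, so Python then tries the reflected method. The sketch below mirrors the `add` function above for `*` and runs in plain CPython. The `mul` helper and the `Vec` dataclass are hypothetical. The final `TypeError` matches what CPython does when both sides decline.

```python
from dataclasses import dataclass


def mul(x, y):
    # Simplified binary-operator protocol for *, mirroring add() above
    z = x.__mul__(y) if hasattr(x, '__mul__') else NotImplemented
    if z is NotImplemented and hasattr(y, '__rmul__'):
        z = y.__rmul__(x)
    if z is NotImplemented:
        # Both operands declined: CPython raises TypeError here
        raise TypeError(
            f"unsupported operand types: {type(x).__name__}, {type(y).__name__}"
        )
    return z


@dataclass
class Vec:
    x: float
    y: float

    def __rmul__(self, scalar):
        # Reflected multiplication: handles scalar * Vec
        if isinstance(scalar, (int, float)):
            return Vec(scalar * self.x, scalar * self.y)
        return NotImplemented


print((2.0).__mul__(Vec(1.0, 2.0)))  # NotImplemented
print(mul(2.0, Vec(1.0, 2.0)))       # Vec(x=2.0, y=4.0)
```

`Vec` defines no `__mul__`, so `mul(Vec(1.0, 2.0), 2.0)` raises `TypeError`. The `hasattr` check on the left operand fails, and then `float.__rmul__` also returns `NotImplemented`.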