mila-iqia/myia

Add support for multiplying a dataclass by a scalar

ethancaballero opened this issue · 1 comment

For example, support for the following, where model and dmodel are dataclass instances:

model = model - 0.1 * dmodel

inside a function decorated with myia:

@myia(backend='pytorch', backend_options={'device': device_type}, return_backend=True)
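For reference, this is what the requested update looks like in plain Python (outside myia), where it works once the dataclass defines the arithmetic dunder methods. This is a minimal sketch: the `Model` class and its `W`/`b` float fields are hypothetical stand-ins for a real model holding tensors.

```python
from dataclasses import dataclass, fields


@dataclass
class Model:
    # Hypothetical parameters; a real model would hold tensors.
    W: float
    b: float

    def __sub__(self, other):
        # Only Model - Model is handled; anything else defers to Python.
        if not isinstance(other, Model):
            return NotImplemented
        return Model(*(getattr(self, f.name) - getattr(other, f.name)
                       for f in fields(self)))

    def __mul__(self, scalar):
        # Scale every field by a plain number.
        if not isinstance(scalar, (int, float)):
            return NotImplemented
        return Model(*(getattr(self, f.name) * scalar for f in fields(self)))

    # float.__mul__ returns NotImplemented for a Model, so `0.1 * dmodel`
    # falls back to Model.__rmul__, which simply reuses __mul__.
    __rmul__ = __mul__


model = Model(W=1.0, b=2.0)
dmodel = Model(W=10.0, b=20.0)
model = model - 0.1 * dmodel
print(model)
```

The point is that `0.1 * dmodel` only works through the reflected `__rmul__`, which is why the operator protocol matters for this issue.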

Using myia.hypermap.HyperMap instead of myia.composite.Elemwise (with the appropriate leaf function) would work as a quick fix, but it would bypass Python's operator protocol (__add__/__radd__, __mul__/__rmul__, and so on) entirely, which may not be the right thing here.
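To make the trade-off concrete, here is a plain-Python sketch of the HyperMap-style approach. `hyper_map` is a hypothetical helper, not myia's actual HyperMap API: it recurses into dataclass fields, broadcasts plain scalars, and applies a leaf function such as `operator.mul`. The dataclass defines no arithmetic methods at all, which shows how this route never consults `__mul__`/`__rmul__`.

```python
import dataclasses
import operator
from dataclasses import dataclass


def hyper_map(fn, *args):
    """Hypothetical HyperMap stand-in: recurse into dataclasses,
    broadcast scalars, and apply fn at the leaves."""
    first = next((a for a in args if dataclasses.is_dataclass(a)), None)
    if first is None:
        # Every argument is a leaf: apply the leaf function directly.
        return fn(*args)
    new_fields = {}
    for f in dataclasses.fields(first):
        # Dataclass arguments contribute their field; scalars are broadcast.
        sub = [getattr(a, f.name) if dataclasses.is_dataclass(a) else a
               for a in args]
        new_fields[f.name] = hyper_map(fn, *sub)
    return type(first)(**new_fields)


@dataclass
class Model:
    # Hypothetical model: note that it defines no __mul__/__rmul__/__sub__.
    W: float
    b: float


model = Model(W=1.0, b=2.0)
dmodel = Model(W=10.0, b=20.0)
step = hyper_map(operator.mul, 0.1, dmodel)
model = hyper_map(operator.sub, model, step)
print(model)
```

This works even though `0.1 * dmodel` would raise `TypeError` in plain Python for this class, which is exactly why it sidesteps any user-defined operator behaviour.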

Alternatively, we could implement Python's binary-operator protocol, which for + is essentially this:

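def add(x, y):
    # Try the left operand's __add__ first.
    if hasattr(x, '__add__'):
        z = x.__add__(y)
    else:
        z = NotImplemented
    # Fall back to the right operand's reflected __radd__.
    if z is NotImplemented and hasattr(y, '__radd__'):
        z = y.__radd__(x)
    # (Real Python raises TypeError if z is still NotImplemented here.)
    return z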

We'd just need to add support for the NotImplemented constant, for hasattr, and for an identity check (is) against NotImplemented.
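The runnable sketch below generalizes the add function above to any operator given by name and traces the 0.1 * dmodel case: float.__mul__ returns NotImplemented for a non-number, so the dataclass's reflected __rmul__ is called. `binary_op` and `Grad` are hypothetical names. Like the snippet above, it ignores Python's rule that a right operand whose type is a subclass of the left operand's type is tried first.

```python
from dataclasses import dataclass


def binary_op(x, y, name, rname):
    # Look up dunder methods on the type, as Python itself does.
    z = NotImplemented
    if hasattr(type(x), name):
        z = getattr(type(x), name)(x, y)
    # Fall back to the reflected method on the right operand.
    if z is NotImplemented and hasattr(type(y), rname):
        z = getattr(type(y), rname)(y, x)
    if z is NotImplemented:
        raise TypeError(f"unsupported operand types for {name}: "
                        f"{type(x).__name__!r} and {type(y).__name__!r}")
    return z


@dataclass
class Grad:
    # Hypothetical gradient container with a single field.
    w: float

    def __rmul__(self, scalar):
        if not isinstance(scalar, (int, float)):
            return NotImplemented
        return Grad(self.w * scalar)


print((0.1).__mul__(Grad(10.0)))                          # NotImplemented
print(binary_op(0.1, Grad(10.0), '__mul__', '__rmul__'))  # Grad(w=1.0)
```

Raising TypeError at the end mirrors what Python does when both sides decline; supporting that inside myia would be a separate question from the three features listed above.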