hidet-org/hidet

[Feature] Does hidet support training?

Closed this issue · 1 comment

Does hidet support training? I saw that hidet is now a backend inside torch dynamo, so does this mean we can also reuse AOT autograd with the hidet backend to do training? If training is supported, how does that work, and if not, why doesn't it work yet?

Hi @BruceDai003,

With AOT autograd, it should work in principle.

For example, after registering the operators torch.ops.aten.add and torch.ops.aten.cos (#223), we can run the following example:

import torch
from functorch.compile import aot_function


def fn(a, b, c, d):
    x = a + b + c + d
    return x.cos().cos()


# Run the function eagerly first to get reference outputs and gradients
a, b, c, d = [torch.randn(2, 4, requires_grad=True) for _ in range(4)]
ref = fn(a, b, c, d)
loss = ref.sum()
loss.backward()


# The compiler_fn is called once for the forward graph and once for the backward graph,
# after AOT autograd has extracted them. Here we ask hidet to print each input graph and
# compile it with the hidet backend; the return value of this function must be a callable.
def compiler_fn(fx_module: torch.fx.GraphModule, _):
    import hidet
    hidet.torch.dynamo_config.print_input_graph()
    return torch.compile(fx_module, backend='hidet')


# Pass on the compiler_fn to the aot_function API
aot_print_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)

# Run the aot_print_fn once to trigger the compilation and print the graphs
cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]
cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs
res = aot_print_fn(cloned_a, cloned_b, cloned_c, cloned_d)
res.sum().backward()
assert torch.allclose(ref, res)
assert torch.allclose(cloned_a.grad, a.grad)
assert torch.allclose(cloned_b.grad, b.grad)
assert torch.allclose(cloned_c.grad, c.grad)
assert torch.allclose(cloned_d.grad, d.grad)

with the following output:

/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py:1251: UserWarning: Your compiler for AOTAutograd is returning a a function that doesn't take boxed arguments. Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.
  warnings.warn(
opcode         name           target            args                                           kwargs
-------------  -------------  ----------------  ---------------------------------------------  --------
placeholder    primals_1      primals_1         ()                                             {}
placeholder    primals_2      primals_2         ()                                             {}
placeholder    primals_3      primals_3         ()                                             {}
placeholder    primals_4      primals_4         ()                                             {}
call_function  add_tensor     aten.add.Tensor   (primals_1, primals_2)                         {}
call_function  add_tensor_1   aten.add.Tensor   (add_tensor, primals_3)                        {}
call_function  add_tensor_2   aten.add.Tensor   (add_tensor_1, primals_4)                      {}
call_function  cos_default    aten.cos.default  (add_tensor_2,)                                {}
call_function  cos_default_1  aten.cos.default  (cos_default,)                                 {}
output         output         output            ((cos_default_1, add_tensor_2, cos_default),)  {}
/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py:1251: UserWarning: Your compiler for AOTAutograd is returning a a function that doesn't take boxed arguments. Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.
  warnings.warn(
opcode         name           target            args                         kwargs
-------------  -------------  ----------------  ---------------------------  --------
placeholder    add_2          add_2             ()                           {}
placeholder    cos            cos               ()                           {}
placeholder    tangents_1     tangents_1        ()                           {}
call_function  sin_default    aten.sin.default  (cos,)                       {}
call_function  neg_default    aten.neg.default  (sin_default,)               {}
call_function  mul_tensor     aten.mul.Tensor   (tangents_1, neg_default)    {}
call_function  sin_default_1  aten.sin.default  (add_2,)                     {}
call_function  neg_default_1  aten.neg.default  (sin_default_1,)             {}
call_function  mul_tensor_1   aten.mul.Tensor   (mul_tensor, neg_default_1)  {}
output         output         output            ((mul_tensor_1,),)           {}

Process finished with exit code 0
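
The UserWarning above comes from AOT autograd itself: it expects each compiler to return a callable that takes its arguments as a single list ("boxed" arguments). As the warning suggests, one way to silence it is to wrap the compiled callable with functorch.compile.make_boxed_func. A minimal sketch of compiler_fn adjusted this way (same setup as the example above) could look like:

import torch
from functorch.compile import make_boxed_func


def compiler_fn(fx_module: torch.fx.GraphModule, _):
    import hidet
    hidet.torch.dynamo_config.print_input_graph()
    compiled = torch.compile(fx_module, backend='hidet')
    # Wrapping with make_boxed_func lets AOT autograd pass the inputs as a list,
    # which avoids the "doesn't take boxed arguments" warning shown above.
    return make_boxed_func(compiled)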

However, our current focus is still on inference. Thus, we may lack some backward operators, and there might be some unexpected problems when using hidet as the compiler for training.
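
If a backward operator turns out to be missing (the backward graph above uses aten.sin, aten.neg, and aten.mul, for instance), it can in principle be mapped to a hidet operator the same way aten.add and aten.cos were registered in #223. The following is only a rough sketch: it assumes the register_function decorator from hidet's torch frontend and that hidet.ops.sin is the matching elementwise operator; exact import paths may differ between hidet versions.

import torch
from hidet import Tensor, ops
from hidet.graph.frontend.torch.registry import register_function


# Hypothetical registration: map torch.ops.aten.sin onto hidet's elementwise sin op,
# so the backward graph produced by AOT autograd can be lowered by the hidet backend.
@register_function(torch.ops.aten.sin)
def aten_sin(x: Tensor) -> Tensor:
    return ops.sin(x)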