[Feature] Does hidet support training?
Closed this issue · 1 comment
BruceDai003 commented
Does hidet support training? I saw that it is now a backend inside torch dynamo, so does this mean we can also reuse AOT autograd with the hidet backend to do training? If so, how does that work, and if not, why isn't it working yet?
yaoyaoding commented
Hi @BruceDai003,
With AOT autograd, it should work in principle. For example, after registering the operators torch.ops.aten.add and torch.ops.aten.cos (#223), we can run the following example:
import torch
from functorch.compile import aot_function

def fn(a, b, c, d):
    x = a + b + c + d
    return x.cos().cos()

# Run the function eagerly first to get reference outputs and gradients.
a, b, c, d = [torch.randn(2, 4, requires_grad=True) for _ in range(4)]
ref = fn(a, b, c, d)
loss = ref.sum()
loss.backward()

# The compiler_fn is called after the forward and backward graphs are extracted.
# Here, we just print the input graph inside compiler_fn. The return value of
# this function must be a callable.
def compiler_fn(fx_module: torch.fx.GraphModule, _):
    import hidet
    hidet.torch.dynamo_config.print_input_graph()
    return torch.compile(fx_module, backend='hidet')

# Pass compiler_fn to the aot_function API.
aot_print_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)

# Run aot_print_fn once to trigger the compilation and print the graphs.
cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]
cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs
res = aot_print_fn(cloned_a, cloned_b, cloned_c, cloned_d)
res.sum().backward()

# The compiled function should agree with eager execution on both the
# outputs and the gradients.
assert torch.allclose(ref, res)
assert torch.allclose(cloned_a.grad, a.grad)
assert torch.allclose(cloned_b.grad, b.grad)
assert torch.allclose(cloned_c.grad, c.grad)
assert torch.allclose(cloned_d.grad, d.grad)
with the output:
/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py:1251: UserWarning: Your compiler for AOTAutograd is returning a a function that doesn't take boxed arguments. Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.
warnings.warn(
opcode name target args kwargs
------------- ------------- ---------------- --------------------------------------------- --------
placeholder primals_1 primals_1 () {}
placeholder primals_2 primals_2 () {}
placeholder primals_3 primals_3 () {}
placeholder primals_4 primals_4 () {}
call_function add_tensor aten.add.Tensor (primals_1, primals_2) {}
call_function add_tensor_1 aten.add.Tensor (add_tensor, primals_3) {}
call_function add_tensor_2 aten.add.Tensor (add_tensor_1, primals_4) {}
call_function cos_default aten.cos.default (add_tensor_2,) {}
call_function cos_default_1 aten.cos.default (cos_default,) {}
output output output ((cos_default_1, add_tensor_2, cos_default),) {}
/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py:1251: UserWarning: Your compiler for AOTAutograd is returning a a function that doesn't take boxed arguments. Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.
warnings.warn(
opcode name target args kwargs
------------- ------------- ---------------- --------------------------- --------
placeholder add_2 add_2 () {}
placeholder cos cos () {}
placeholder tangents_1 tangents_1 () {}
call_function sin_default aten.sin.default (cos,) {}
call_function neg_default aten.neg.default (sin_default,) {}
call_function mul_tensor aten.mul.Tensor (tangents_1, neg_default) {}
call_function sin_default_1 aten.sin.default (add_2,) {}
call_function neg_default_1 aten.neg.default (sin_default_1,) {}
call_function mul_tensor_1 aten.mul.Tensor (mul_tensor, neg_default_1) {}
output output output ((mul_tensor_1,),) {}
Process finished with exit code 0
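For reference, the second printed graph is the extracted backward graph: it implements the chain rule for this example, d/ds cos(cos(s)) = (-sin(cos(s))) · (-sin(s)), applied to the incoming gradient tangents_1.

The UserWarning above can be silenced by wrapping the compiled callable with functorch.compile.make_boxed_func, as the warning message itself suggests. A minimal variant of compiler_fn (a sketch; the hidet part is unchanged):

from functorch.compile import make_boxed_func

def compiler_fn(fx_module: torch.fx.GraphModule, _):
    import hidet
    hidet.torch.dynamo_config.print_input_graph()
    compiled = torch.compile(fx_module, backend='hidet')
    # Adapt the callable to the boxed (list-of-arguments) calling
    # convention that AOT autograd expects from its compilers.
    return make_boxed_func(compiled)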
However, our current focus is still on inference, so we may lack some backward operators, and there might be unexpected problems when using hidet as the compiler for training.
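If you do experiment with training and hit a missing backward operator, one possible workaround (a sketch using only public PyTorch APIs, not an official hidet feature) is a compiler that falls back to running the FX graph eagerly whenever hidet fails to compile it. Note that torch.compile is lazy, so the sketch calls the compiled graph once on the example inputs to surface compilation errors immediately:

import torch
from functorch.compile import aot_function, make_boxed_func

def hidet_or_eager(fx_module: torch.fx.GraphModule, example_inputs):
    compiled = torch.compile(fx_module, backend='hidet')
    try:
        # torch.compile compiles on the first call; run once here so an
        # unsupported operator raises now rather than during training.
        compiled(*example_inputs)
    except Exception:
        # Fall back to interpreting the FX graph eagerly.
        compiled = fx_module
    return make_boxed_func(compiled)

# Hypothetical usage with fn from the example above:
aot_fn = aot_function(fn, fw_compiler=hidet_or_eager, bw_compiler=hidet_or_eager)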