JIT-compiling the extension results in a non-functional Python module.
mrTsjolder opened this issue · 0 comments
When submitting a bug report, please include the following information (where relevant):
- OS: Arch Linux
- PyTorch version: 2.3.0
- How you installed PyTorch (conda, pip, source): conda
- Python version: 3.12
- CUDA/cuDNN version: 12.1 (runtime) / 8.9.2
- GPU models and configuration: 1070 Ti
- GCC version (if compiling from source): /
When using torch.utils.cpp_extension.load to JIT-compile the extension, the returned module does not provide access to the exported functions.
Code to reproduce (in an interactive Python session with working directory extension-cpp):
>>> import torch
>>> from torch.utils.cpp_extension import load
>>> extension_cpp = load(name="extension_cpp", sources=["extension_cpp/csrc/muladd.cpp"])
>>> extension_cpp.mymuladd
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: module 'extension_cpp' has no attribute 'mymuladd'
Note that after loading, the following does work:
>>> torch.ops.extension_cpp.mymuladd
<OpOverloadPacket(op='extension_cpp.mymuladd')>
I.e., load does seem to have the side effect of loading the PyTorch op (similar to setting is_python_module=False), but does not provide the functions in the Python module.
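For comparison, the is_python_module=False route can be sketched as follows. This is a minimal sketch, not a confirmed fix: load_ops_only is a hypothetical helper name, and it assumes the source file registers its ops under the extension_cpp namespace, as the session above suggests.

```python
def load_ops_only(source="extension_cpp/csrc/muladd.cpp"):
    """JIT-compile `source` and return the torch.ops namespace it registers.

    Hypothetical helper for illustration only.
    """
    # Imported lazily so that defining this helper does not require PyTorch.
    import torch
    from torch.utils.cpp_extension import load

    # With is_python_module=False the compiled library is loaded for its
    # operator registrations; no Python module object is returned.
    load(name="extension_cpp", sources=[source], is_python_module=False)

    # Assumption: the ops live in the "extension_cpp" namespace.
    return torch.ops.extension_cpp
```

Calling ops = load_ops_only() from the extension-cpp working directory should then give access to ops.mymuladd, in the same way as torch.ops.extension_cpp.mymuladd above.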
I would expect to be able to use extension_cpp in the same way as torch.ops.extension_cpp.
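Until the module itself exposes the functions, a caller could fall back to the op namespace on the Python side. The helper below is a hypothetical sketch (resolve_op is not part of PyTorch or of the extension); it only relies on plain attribute lookup, so it works with any two objects.

```python
def resolve_op(module, ops_namespace, name):
    """Return `name` from the loaded module if bound there, else from the op namespace.

    Hypothetical helper for illustration only.
    """
    bound = getattr(module, name, None)
    if bound is not None:
        return bound
    # Raises AttributeError if the op is not registered in the namespace either.
    return getattr(ops_namespace, name)
```

In the session above this would be used as resolve_op(extension_cpp, torch.ops.extension_cpp, "mymuladd"), which returns the registered op as long as the module does not bind that name.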
I suspect that this issue is due to the empty PYBIND11_MODULE block.
This seems to be confirmed by the fact that the following modification in muladd.cpp makes extension_cpp.mymuladd available:
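PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Bind each name to its own implementation; mymul_cpu and myadd_out_cpu
  // are assumed to be the function names defined in muladd.cpp.
  m.def("mymuladd", &mymuladd_cpu, "custom pytorch operation 'muladd'");
  m.def("mymul", &mymul_cpu, "custom pytorch operation 'mul'");
  m.def("myadd_out", &myadd_out_cpu, "custom pytorch operation 'add_out'");
}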
However, I am not sure if this is just a workaround or the actual bug-fix.