Noble-Lab/casanovo

Export casanovo to torchscript/onnx

LLautenbacher opened this issue · 1 comment

Hi,

I want to export casanovo to TorchScript or ONNX to make it accessible via Koina.
When I follow the Lightning documentation for doing this (using method="trace"), I get an UnsupportedNodeError. I'm not familiar with Lightning or PyTorch. Can you help with creating a TorchScript/ONNX export of your model?

Here is the full traceback:
---------------------------------------------------------------------------
UnsupportedNodeError                      Traceback (most recent call last)
Cell In[5], [line 1](vscode-notebook-cell:?execution_count=5&line=1)
----> [1](vscode-notebook-cell:?execution_count=5&line=1) runner.model.to_torchscript("model.pt", method="trace", example_inputs=inp)

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/lightning/pytorch/core/module.py:1479, in LightningModule.to_torchscript(self, file_path, method, example_inputs, **kwargs)
1477 example_inputs = self._apply_batch_transfer_handler(example_inputs)
1478 with _jit_is_scripting():
-> 1479 torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
1480 else:
1481 raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was: {method}")

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/_trace.py:820, in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_kwarg_inputs, _store_inputs)
818 else:
819 raise RuntimeError("example_kwarg_inputs should be a dict")
--> 820 return trace_module(
821 func,
822 {"forward": example_inputs},
823 None,
824 check_trace,
825 wrap_check_inputs(check_inputs),
826 check_tolerance,
827 strict,
828 _force_outplace,
829 _module_class,
830 example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
831 _store_inputs=_store_inputs,
832 )
833 if (
834 hasattr(func, "__self__")
835 and isinstance(func.__self__, torch.nn.Module)
836 and func.__name__ == "forward"
837 ):
838 if example_inputs is None:

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/_trace.py:1053, in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_inputs_is_kwarg, _store_inputs)
1050 torch.jit._trace._trace_module_map = trace_module_map
1051 register_submods(mod, "__module")
-> 1053 module = make_module(mod, _module_class, _compilation_unit)
1055 for method_name, example_inputs in inputs.items():
1056 if method_name == "forward":
1057 # "forward" is a special case because we need to trace
1058 # Module.__call__, which sets up some extra tracing, but uses
1059 # argument names of the real Module.forward method.

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/_trace.py:624, in make_module(mod, _module_class, _compilation_unit)
622 elif torch._jit_internal.module_has_exports(mod):
623 infer_methods_stubs_fn = torch.jit._recursive.make_stubs_from_exported_methods
--> 624 return torch.jit._recursive.create_script_module(
625 mod, infer_methods_stubs_fn, share_types=False, is_tracing=True
626 )
627 else:
628 if _module_class is None:

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/_recursive.py:558, in create_script_module(nn_module, stubs_fn, share_types, is_tracing)
556 if not is_tracing:
557 AttributeTypeIsSupportedChecker().check(nn_module)
--> 558 return create_script_module_impl(nn_module, concrete_type, stubs_fn)

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/_recursive.py:631, in create_script_module_impl(nn_module, concrete_type, stubs_fn)
628 script_module._concrete_type = concrete_type
630 # Actually create the ScriptModule, initializing it with the function we just defined
--> 631 script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
633 # Compile methods if necessary
634 if concrete_type not in concrete_type_store.methods_compiled:

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/_script.py:647, in RecursiveScriptModule._construct(cpp_module, init_fn)
633 """
634 Construct a RecursiveScriptModule that's ready for use.
635
(...)
644 init_fn: Lambda that initializes the RecursiveScriptModule passed to it.
645 """
646 script_module = RecursiveScriptModule(cpp_module)
--> 647 init_fn(script_module)
649 # Finalize the ScriptModule: replace the nn.Module state with our
650 # custom implementations and flip the _initializing bit.
651 RecursiveScriptModule._finalize_scriptmodule(script_module)

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/_recursive.py:607, in create_script_module_impl.<locals>.init_fn(script_module)
604 scripted = orig_value
605 else:
606 # always reuse the provided stubs_fn to infer the methods to compile
--> 607 scripted = create_script_module_impl(
608 orig_value, sub_concrete_type, stubs_fn
609 )
611 cpp_module.setattr(name, scripted)
612 script_module._modules[name] = scripted

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/_recursive.py:635, in create_script_module_impl(nn_module, concrete_type, stubs_fn)
633 # Compile methods if necessary
634 if concrete_type not in concrete_type_store.methods_compiled:
--> 635 create_methods_and_properties_from_stubs(
636 concrete_type, method_stubs, property_stubs
637 )
638 # Create hooks after methods to ensure no name collisions between hooks and methods.
639 # If done before, hooks can overshadow methods that aren't exported.
640 create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs)

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/_recursive.py:467, in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
464 property_defs = [p.def_ for p in property_stubs]
465 property_rcbs = [p.resolution_callback for p in property_stubs]
--> 467 concrete_type._create_methods_and_properties(
468 property_defs, property_rcbs, method_defs, method_rcbs, method_defaults
469 )

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/_recursive.py:1036, in compile_unbound_method(concrete_type, fn)
1034 if _jit_internal.is_ignored_fn(fn):
1035 return None
1036 stub = make_stub(fn, fn.__name__)
1037 with torch._jit_internal._disable_emit_hooks():
1038 # We don't want to call the hooks here since the graph that is calling
1039 # this function is not yet complete
1040 create_methods_and_properties_from_stubs(concrete_type, (stub,), ())

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/_recursive.py:71, in make_stub(func, name)
69 def make_stub(func, name):
70 rcb = _jit_internal.createResolutionCallbackFromClosure(func)
---> 71 ast = get_jit_def(func, name, self_name="RecursiveScriptModule")
72 return ScriptMethodStub(rcb, ast, func)

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/frontend.py:372, in get_jit_def(fn, def_name, self_name, is_classmethod)
369 qualname = get_qualified_name(fn)
370 pdt_arg_types = type_trace_db.get_args_types(qualname)
--> 372 return build_def(
373 parsed_def.ctx,
374 fn_def,
375 type_line,
376 def_name,
377 self_name=self_name,
378 pdt_arg_types=pdt_arg_types,
379 )

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/frontend.py:433, in build_def(ctx, py_def, type_line, def_name, self_name, pdt_arg_types)
430 type_comment_decl = torch._C.parse_type_comment(type_line)
431 decl = torch._C.merge_type_from_type_comment(decl, type_comment_decl, is_method)
--> 433 return Def(Ident(r, def_name), decl, build_stmts(ctx, body))

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/frontend.py:195, in build_stmts(ctx, stmts)
194 def build_stmts(ctx, stmts):
--> 195 stmts = [build_stmt(ctx, s) for s in stmts]
196 return list(filter(None, stmts))

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/frontend.py:195, in <listcomp>(.0)
194 def build_stmts(ctx, stmts):
--> 195 stmts = [build_stmt(ctx, s) for s in stmts]
196 return list(filter(None, stmts))

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/frontend.py:406, in Builder.__call__(self, ctx, node)
404 if method is None:
405 raise UnsupportedNodeError(ctx, node)
--> 406 return method(ctx, node)

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/frontend.py:773, in StmtBuilder.build_For(ctx, stmt)
766 if stmt.orelse:
767 raise NotSupportedError(r, "else branches of for loops aren't supported")
769 return For(
770 r,
771 [build_expr(ctx, stmt.target)],
772 [build_expr(ctx, stmt.iter)],
--> 773 build_stmts(ctx, stmt.body),
774 )

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/frontend.py:195, in build_stmts(ctx, stmts)
194 def build_stmts(ctx, stmts):
--> 195 stmts = [build_stmt(ctx, s) for s in stmts]
196 return list(filter(None, stmts))

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/frontend.py:195, in <listcomp>(.0)
194 def build_stmts(ctx, stmts):
--> 195 stmts = [build_stmt(ctx, s) for s in stmts]
196 return list(filter(None, stmts))

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/frontend.py:406, in Builder.__call__(self, ctx, node)
404 if method is None:
405 raise UnsupportedNodeError(ctx, node)
--> 406 return method(ctx, node)

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/frontend.py:676, in StmtBuilder.build_Expr(ctx, stmt)
674 return None
675 else:
--> 676 return ExprStmt(build_expr(ctx, value))

File ~/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/jit/frontend.py:405, in Builder.__call__(self, ctx, node)
403 method = getattr(self, "build_" + node.__class__.__name__, None)
404 if method is None:
--> 405 raise UnsupportedNodeError(ctx, node)
406 return method(ctx, node)

UnsupportedNodeError: Yield aren't supported:
File "/cmnfs/home/llautenbacher/miniconda3/envs/casanovo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2230
"""
for name, param in self.named_parameters(recurse=recurse):
yield param
~ <--- HERE

We are very interested in helping to make this happen, but unfortunately, we have zero familiarity with ONNX.

Separately, we've found that `torch.compile` doesn't work with Casanovo (though I don't know the details there). Perhaps these issues are related.

If there is anything specific we can help with, please let us know.