RedisAI/redisai-py

Model exporting through python client

hhsecond opened this issue · 2 comments

Utilities for exporting TF and PyTorch models.

  • For exporting a TF model
    • The utility function should accept a session object and the output variable names as a list of strings
    • Fetching output tensor names is difficult for non-TF users; our function should inspect the last few nodes in the graph and suggest likely outputs. We should also have a utility for fetching placeholder (input) names.
  • For exporting a PyTorch model
    • Exporting with tracing is easy to support, but we should raise clear error messages when tracing fails
    • Exporting through scripting is up to the user, since we have no control over it
    • So the utility function should accept a traced model, a model that is traceable, or a scripted model, plus an output file path
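The PyTorch dispatch above could be sketched as follows. The function name `export_torch_graph` and the duck-typed `.save` check are assumptions for illustration (a real version would check `isinstance(model, torch.jit.ScriptModule)`); the point is accepting either an already-serializable model or an eager model plus an example input, with a clear error otherwise:

```python
def export_torch_graph(model, path, example_input=None):
    """Sketch: export a PyTorch model to `path` for RedisAI.

    Accepts a ScriptModule (already traced or scripted) or an eager
    model plus an example input to trace with. Names and the duck-typed
    check are assumptions, not the final API.
    """
    # Traced/scripted ScriptModules can serialize themselves via .save
    if hasattr(model, "save") and callable(model.save):
        model.save(path)
        return
    # An eager model must be traceable; give a clear error otherwise
    if example_input is None:
        raise ValueError(
            "eager models need an example_input so they can be traced "
            "with torch.jit.trace before export"
        )
    import torch  # assumed available only when tracing is actually needed
    traced = torch.jit.trace(model, example_input)
    traced.save(path)
```

Raising early with an explicit message covers the "proper error messages" point: the user learns immediately that tracing needs an example input, instead of hitting an opaque failure inside torch.jit.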

@lantiga What do you think?

I have two approaches for fixing this.

  1. Write a Model class that exposes Model.export_graph() and Model.import_graph() as class methods.
Model.export_graph(frozen_graph, output_names, path) # if TF
Model.export_graph(pytorch_model, path) # if PyTorch traced graph or ScriptModule

...

model = Model.import_graph(path)
r.setmodel(model)  # model will have all the information required for setting the model
  2. A much simpler approach: a utility module with def export_graph() and def import_graph()
from redisai import graph_utils as g
g.export_graph(frozen_graph, output_names, path) # if TF
g.export_graph(pytorch_model, path) # if PyTorch traced graph or ScriptModule

...

model = g.import_graph(path)  # model is just the binary data from graph.pb / graph.pt
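A minimal sketch of the second approach, using the function names from the snippet above. The TF branch duck-types on `SerializeToString` (the method TF GraphDef protobufs expose) so the file I/O can be shown without importing TensorFlow; this is an illustrative assumption, not the final implementation:

```python
def export_graph(graph, path, output_names=None):
    """Sketch: write a serialized graph to `path`.

    `graph` may be raw bytes or any protobuf-like object exposing
    SerializeToString (e.g. a frozen TF GraphDef). `output_names` is
    accepted to mirror the TF signature; this sketch does not use it.
    """
    if hasattr(graph, "SerializeToString"):
        data = graph.SerializeToString()
    elif isinstance(graph, (bytes, bytearray)):
        data = bytes(graph)
    else:
        raise TypeError("graph must be bytes or a serializable protobuf")
    with open(path, "wb") as f:
        f.write(data)


def import_graph(path):
    """Sketch: return the raw bytes of graph.pb / graph.pt,
    ready to hand to the client when setting the model."""
    with open(path, "rb") as f:
        return f.read()
```

The class-based approach in option 1 would be a thin wrapper over the same two operations, with the Model instance additionally carrying metadata (backend, output names) alongside the bytes.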

We currently use tools like Netron to find the input and output placeholder names for TF models. I could also write a class method that fetches all the node names.
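Such a helper might look like the sketch below, assuming a GraphDef-like object with a `node` list where each node has `name` and `op` attributes (as TF's GraphDef protobuf does). The "last few nodes" heuristic for output candidates comes from the suggestion earlier in this issue; it is not a guaranteed way to find outputs:

```python
def list_node_names(graph_def, tail=3):
    """Sketch: enumerate node names in a GraphDef-like object and
    suggest likely inputs (Placeholder ops) and outputs (last nodes).

    The helper name and signature are assumptions for illustration.
    """
    names = [n.name for n in graph_def.node]
    # Placeholder ops are the usual feed points in a frozen TF graph
    inputs = [n.name for n in graph_def.node if n.op == "Placeholder"]
    # Heuristic: output tensors tend to sit at the end of the node list
    outputs = names[-tail:]
    return names, inputs, outputs
```

This would let non-TF users get name suggestions directly from the client instead of loading the graph into an external viewer.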

Any thoughts? @lantiga @mnunberg

Closing this for now since this seems to have really low priority. Feel free to reopen if somebody wants to take it up.