How to save graph_dict and load it again next time
yeswzc opened this issue · 2 comments
Hi, is there a way to save graph_dict in Python so that it can be loaded again in a later Python session? I tried pickle with pickle.HIGHEST_PROTOCOL, but the saved file can't be read back.
Thank you!
Hi, I have a solution, though it may not be the best one.
```python
import copy

import torchsnapshot
from tensordict import TensorDict

# --- save ---
# Work on a copy so the original graph_dict keeps its sparse tensors
tmp = copy.deepcopy(graph_dict)
# Store the sparse adjacency matrices as dense tensors for saving
tmp['adj_norm'] = tmp['adj_norm'].to_dense()
tmp['adj_label'] = tmp['adj_label'].to_dense()
d = TensorDict(tmp, [])
state = {'state': d}
snapshot = torchsnapshot.Snapshot.take(app_state=state, path="snapshot")

# --- restore ---
snapshot = torchsnapshot.Snapshot(path="snapshot")
graph_dict_r = TensorDict({}, [])
state_target = {"state": graph_dict_r}
snapshot.restore(app_state=state_target)
# Sanity check (only possible in the session that still holds d)
assert (graph_dict_r == d).all()

# --- convert back to SEDR input ---
graph_dict_r = graph_dict_r.to_dict()
graph_dict_r['adj_norm'] = graph_dict_r['adj_norm'].to_sparse()
graph_dict_r['adj_label'] = graph_dict_r['adj_label'].to_sparse()
```
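The approach above relies on the sparse to dense to sparse round trip being lossless. The small sketch below checks that on a hypothetical 3x3 adjacency matrix (the indices and values are made up for illustration, not taken from SEDR), using only core PyTorch calls.

```python
import torch

# Hypothetical 3x3 sparse adjacency matrix (values made up for illustration)
indices = torch.tensor([[0, 1, 2], [1, 2, 0]])
adj = torch.sparse_coo_tensor(indices, torch.tensor([0.5, 0.5, 1.0]), (3, 3))

# Save side: densify, as done before wrapping the dict in a TensorDict
dense = adj.to_dense()

# Restore side: convert back to sparse COO, as in the last lines of the answer
restored = dense.to_sparse()

print(restored.is_sparse)  # True
print(torch.equal(restored.to_dense(), adj.to_dense()))  # True
print(restored.coalesce().values().numel())  # 3 stored non-zeros
```

Two caveats: `to_sparse()` drops any explicitly stored zeros, so only the values (not the exact storage) are guaranteed to match; and the dense copy of an N x N float32 matrix needs N * N * 4 bytes, e.g. about 400 MB for 10,000 spots.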
Thank you! Works well!
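For the original pickle question, one possible alternative (an untested sketch, not something confirmed in this thread) is `torch.save`/`torch.load`, which serialize sparse COO tensors directly, so no dense conversion is needed. The example dict below is hypothetical: the `adj_norm`/`adj_label` keys follow the thread, while the `norm_value` float entry is an assumption about what graph_dict may also contain. The helper names `save_graph_dict`/`load_graph_dict` are illustrative only.

```python
import os
import tempfile

import torch


def save_graph_dict(graph_dict, path):
    """Hypothetical helper: serialize a dict of (sparse or dense) tensors and numbers."""
    torch.save(graph_dict, path)


def load_graph_dict(path):
    """Hypothetical helper: load a dict written by save_graph_dict.

    weights_only=False allows arbitrary pickled objects, so only load files you
    created yourself. The keyword requires PyTorch >= 1.13.
    """
    return torch.load(path, weights_only=False)


# Hypothetical stand-in for graph_dict: two sparse COO matrices plus a float
indices = torch.tensor([[0, 1, 2], [1, 2, 0]])
example = {
    "adj_norm": torch.sparse_coo_tensor(indices, torch.tensor([0.5, 0.5, 1.0]), (3, 3)),
    "adj_label": torch.sparse_coo_tensor(indices, torch.ones(3), (3, 3)),
    "norm_value": 0.75,  # assumed extra entry, value made up
}

path = os.path.join(tempfile.mkdtemp(), "graph_dict.pt")
save_graph_dict(example, path)
restored = load_graph_dict(path)

print(restored["adj_norm"].is_sparse)  # True
print(torch.equal(restored["adj_norm"].to_dense(), example["adj_norm"].to_dense()))  # True
print(restored["norm_value"])  # 0.75
```

The tensors come back in sparse layout, so the result could be passed on without the `to_sparse()` step; whether this works with every graph_dict SEDR produces has not been verified here.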