JinmiaoChenLab/SEDR

How to save graph_dict and load it again later

yeswzc opened this issue · 2 comments

Hi, is there a way to save graph_dict in Python so that it can be loaded again in a later session? I tried pickle with pickle.HIGHEST_PROTOCOL, but the saved file can't be read back.

Thank you!

Hi, I have a solution, but maybe not the best one.

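```python
import copy

import torchsnapshot
from tensordict import TensorDict

# Work on a copy so the original graph_dict keeps its sparse tensors
tmp = copy.deepcopy(graph_dict)
```

Save:

```python
# Convert the sparse adjacency matrices to dense tensors before wrapping them
tmp['adj_norm'] = tmp['adj_norm'].to_dense()
tmp['adj_label'] = tmp['adj_label'].to_dense()
d = TensorDict(tmp, [])  # batch_size=[]: no shared batch dimension
state = {'state': d}
snapshot = torchsnapshot.Snapshot.take(app_state=state, path="snapshot")
```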
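This approach stores dense copies of the adjacency matrices and turns them back into sparse tensors after loading. The sketch below uses a made-up 3 × 3 matrix to check that this sparse → dense → sparse round trip keeps the non-zero entries. It also estimates the cost of going dense: an n × n float32 matrix takes 4n² bytes (float32 is an assumption; check the dtype in your own graph_dict), so 4,000 spots would need 64,000,000 bytes (about 64 MB) per matrix. `dense_bytes` is a hypothetical helper, not part of SEDR.

```python
import torch


def dense_bytes(n_spots: int, itemsize: int = 4) -> int:
    """Bytes for one dense n_spots x n_spots matrix (itemsize=4 assumes float32)."""
    return n_spots * n_spots * itemsize


# Hypothetical 3 x 3 sparse COO adjacency with three non-zero entries
indices = torch.tensor([[0, 1, 2], [1, 2, 0]])
values = torch.tensor([0.5, 0.25, 1.0])
adj = torch.sparse_coo_tensor(indices, values, (3, 3)).coalesce()

# Same round trip as the save and convert steps: sparse -> dense -> sparse
restored = adj.to_dense().to_sparse().coalesce()

print(torch.equal(restored.indices(), adj.indices()))  # True
print(torch.equal(restored.values(), adj.values()))    # True
print(dense_bytes(4000))                                # 64000000
```

`to_sparse()` keeps only non-zero entries, so any zeros stored explicitly in the original sparse tensors would not survive the round trip.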

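Restore (e.g. in a later session):

```python
import torchsnapshot
from tensordict import TensorDict

snapshot = torchsnapshot.Snapshot(path="snapshot")
# Empty TensorDict that restore() fills with the saved entries
graph_dict_r = TensorDict({}, [])
state_target = {"state": graph_dict_r}
snapshot.restore(app_state=state_target)
# Optional check: `d` only exists in the session that took the snapshot
assert (graph_dict_r == d).all()
```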

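Convert back to the SEDR input format:

```python
# Back to a plain dict, with the adjacency matrices sparse again as in the original graph_dict
graph_dict_r = graph_dict_r.to_dict()
graph_dict_r['adj_norm'] = graph_dict_r['adj_norm'].to_sparse()
graph_dict_r['adj_label'] = graph_dict_r['adj_label'].to_sparse()
```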
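If you would rather avoid the extra torchsnapshot and tensordict dependencies, one possible alternative (not tested in this thread) is plain `torch.save`, with each sparse tensor split into its indices, values and size so that the file holds only ordinary tensors, lists and numbers. This is a minimal sketch that assumes every value in graph_dict is either a sparse COO tensor, a dense tensor, or a plain Python number. The helpers `save_graph_dict`/`load_graph_dict`, the key `scale` and the file name are hypothetical. If a value is a NumPy scalar, converting it with `float()` before saving may be needed, because recent PyTorch versions load with `weights_only=True` by default.

```python
import os
import tempfile

import torch


def save_graph_dict(graph_dict, path):
    """Save a dict, storing each sparse COO tensor as indices/values/size."""
    out = {}
    for key, value in graph_dict.items():
        if isinstance(value, torch.Tensor) and value.is_sparse:
            value = value.coalesce()  # indices() requires a coalesced tensor
            out[key] = {
                "sparse_indices": value.indices(),
                "sparse_values": value.values(),
                "sparse_size": list(value.size()),
            }
        else:
            out[key] = value
    torch.save(out, path)


def load_graph_dict(path):
    """Inverse of save_graph_dict: rebuild the sparse COO tensors."""
    raw = torch.load(path)
    graph_dict = {}
    for key, value in raw.items():
        if isinstance(value, dict) and "sparse_indices" in value:
            graph_dict[key] = torch.sparse_coo_tensor(
                value["sparse_indices"], value["sparse_values"], value["sparse_size"]
            ).coalesce()
        else:
            graph_dict[key] = value
    return graph_dict


# Usage with a tiny hypothetical graph_dict
adj = torch.sparse_coo_tensor(
    torch.tensor([[0, 1], [1, 0]]), torch.tensor([1.0, 1.0]), (2, 2)
)
example = {"adj_norm": adj, "adj_label": adj, "scale": 0.5}
path = os.path.join(tempfile.mkdtemp(), "graph_dict.pt")
save_graph_dict(example, path)
loaded = load_graph_dict(path)
```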

Thank you! Works well!