vislearn/FrEIA

[BUG] Nondeterministic

ju-w opened this issue

Hello,

Problem

Consider the example below. Building the same graph several times does not always produce the same module order, so the parameter names in the state_dict change between runs. This is especially problematic when saving and loading a state_dict: the names may no longer match, and loading fails.

The underlying cause is the topological_order function, whose result depends on an unstable iteration order over a dictionary:

def topological_order(all_nodes: List[AbstractNode], in_nodes: List[InputNode],
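A possible direction for a fix, shown as a minimal sketch rather than a patch to FrEIA: `ToyNode`, `stable_topological_order` and `build_issue_graph` are hypothetical names, not FrEIA's `AbstractNode` or `topological_order`. It applies Kahn's algorithm and, whenever several nodes are ready, always picks the one listed first in `all_nodes`, so the result depends only on how the graph was specified. As general Python background: since Python 3.7, dicts iterate in insertion order, whereas a set of objects using the default identity-based hash can iterate in a different order from run to run, because the hash is derived from the object's identity.

```python
import heapq
from typing import Dict, List, Sequence


class ToyNode:
    """Hypothetical minimal node: a name plus the nodes it reads from."""

    def __init__(self, name: str, inputs: Sequence["ToyNode"] = ()):
        self.name = name
        self.inputs = list(inputs)


def stable_topological_order(all_nodes: List[ToyNode]) -> List[ToyNode]:
    """Kahn's algorithm; among ready nodes, the one listed first in
    `all_nodes` always goes next. Assumes every input is in `all_nodes`."""
    # Used only for lookups, never iterated, so identity hashing cannot
    # influence the resulting order.
    index: Dict[ToyNode, int] = {node: i for i, node in enumerate(all_nodes)}
    pending = [0] * len(all_nodes)  # distinct inputs not yet emitted
    consumers: List[List[int]] = [[] for _ in all_nodes]
    for i, node in enumerate(all_nodes):
        # dict.fromkeys drops duplicate inputs while keeping their order
        for parent in dict.fromkeys(node.inputs):
            pending[i] += 1
            consumers[index[parent]].append(i)

    # Min-heap of list positions: the smallest ready position is popped first
    ready = [i for i in range(len(all_nodes)) if pending[i] == 0]
    heapq.heapify(ready)
    order: List[ToyNode] = []
    while ready:
        i = heapq.heappop(ready)
        order.append(all_nodes[i])
        for child in consumers[i]:
            pending[child] -= 1
            if pending[child] == 0:
                heapq.heappush(ready, child)
    if len(order) != len(all_nodes):
        raise ValueError("graph contains a cycle")
    return order


def build_issue_graph() -> List[ToyNode]:
    """Same topology and list order as `split + nodes` in the example."""
    inp = ToyNode("input")
    split = ToyNode("split", [inp])
    flat_b = ToyNode("flat_b", [split])  # flattens split.out1
    flat_a = ToyNode("flat_a", [split])  # flattens split.out0
    perm = ToyNode("perm", [flat_a])
    concat = ToyNode("concat", [perm, flat_b])
    out = ToyNode("output", [concat])
    return [flat_b, inp, split, flat_a, perm, concat, out]


if __name__ == "__main__":
    # Fresh node objects each time; the printed order stays the same
    for _ in range(3):
        print([n.name for n in stable_topological_order(build_issue_graph())])
```

Each run prints `['input', 'split', 'flat_b', 'flat_a', 'perm', 'concat', 'output']`. Breaking ties by list position (instead of by whatever order a set yields) keeps the module indices, and therefore the state_dict keys, reproducible across runs.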

Example
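```python
import FrEIA.framework as Ff
import FrEIA.modules as Fm


def do():
    nodes = [Ff.InputNode(3, 64, 64)]
    nodes.append(Ff.Node(nodes[-1], Fm.Split, {'section_sizes': [2, 1]}))
    # Branch on the second split output (out1): flatten only
    split = [Ff.Node(nodes[-1].out1, Fm.Flatten, {})]
    # Branch on the first split output (out0): flatten, then permute
    nodes.append(Ff.Node(nodes[-1], Fm.Flatten, {}))
    nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom))
    # Merge both branches and close the graph
    nodes.append(Ff.Node([nodes[-1].out0] + [split[0].out0], Fm.Concat, {'dim': 0}))
    nodes.append(Ff.OutputNode(nodes[-1]))
    # The node list is always passed in the same order
    inn = Ff.GraphINN(split + nodes, verbose=1)
    for key in inn.state_dict().keys():
        if "perm_inv" in key:
            print(key)


# Build the identical graph several times and compare the printed keys
for _ in range(5):
    do()
```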
```text
...
GraphINN(
  (module_list): ModuleList(
    (0): Split()
    (1-2): 2 x Flatten()
    (3): PermuteRandom()
    (4): Concat()
  )
)
module_list.3.perm_inv

GraphINN(
  (module_list): ModuleList(
    (0): Split()
    (1): Flatten()
    (2): PermuteRandom()
    (3): Flatten()
    (4): Concat()
  )
)
module_list.2.perm_inv
...
```
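To make the loading failure concrete, here is a small pure-Python sketch (no FrEIA or PyTorch needed). It derives `module_list.<index>.<name>` keys from the two module orders printed above, the way `nn.ModuleList` names its children by position, and compares them the way a strict `load_state_dict` does: keys the model expects but the checkpoint lacks are reported as missing, and keys in the checkpoint the model does not know are unexpected. The helper names and the `PARAM_NAMES` table are hypothetical; the table lists only `perm_inv`, the key the example filters for, while the real modules may register more.

```python
from typing import Dict, List, Set, Tuple

# Module orders as printed by the two GraphINN runs above
RUN_A = ["Split", "Flatten", "Flatten", "PermuteRandom", "Concat"]
RUN_B = ["Split", "Flatten", "PermuteRandom", "Flatten", "Concat"]

# Assumption: only the parameter shown in the issue output
PARAM_NAMES: Dict[str, List[str]] = {"PermuteRandom": ["perm_inv"]}


def state_dict_keys(module_order: List[str]) -> Set[str]:
    """Position-based keys: module_list.<index>.<param>."""
    return {
        f"module_list.{i}.{param}"
        for i, module in enumerate(module_order)
        for param in PARAM_NAMES.get(module, [])
    }


def key_mismatch(saved: Set[str], expected: Set[str]) -> Tuple[Set[str], Set[str]]:
    """Return (missing, unexpected), the two sets a strict load reports."""
    return expected - saved, saved - expected


# Checkpoint saved from run A, loaded into a model built as in run B
missing, unexpected = key_mismatch(state_dict_keys(RUN_A), state_dict_keys(RUN_B))
print("missing:", sorted(missing))        # ['module_list.2.perm_inv']
print("unexpected:", sorted(unexpected))  # ['module_list.3.perm_inv']
```

Both graphs contain the same modules, yet a state_dict saved from the first ordering cannot be loaded strictly into a model built with the second, because only the position in `module_list` differs.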