fidler-lab/polyrnn-pp

I have a problem, can you help me? (ValueError: graph_def is invalid at node 'GatherTree' when restoring PolyRNN++)

zhouchanggeng opened this issue · 1 comment


ValueError Traceback (most recent call last)
in <module>()
1 #Initializing and restoring PolyRNN++
----> 2 model = PolygonModel(PolyRNN_metagraph, polyGraph)
3 model.register_eval_fn(lambda input_: evaluator.do_test(evalSess, input_))
4 polySess = tf.Session(config=tf.ConfigProto(
5 allow_soft_placement=True

~\polyrnn-pp\src\PolygonModel.py in __init__(self, meta_graph_path, graph)
30 self.saver = None
31 self.eval_pred_fn = None
---> 32 self._restore_graph(meta_graph_path)
33
34 def _restore_graph(self, meta_graph_path):

~\polyrnn-pp\src\PolygonModel.py in _restore_graph(self, meta_graph_path)
34 def _restore_graph(self, meta_graph_path):
35 with self.graph.as_default():
---> 36 self.saver = tf.train.import_meta_graph(meta_graph_path, clear_devices=True)
37
38 def _prediction(self):

D:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\training\saver.py in import_meta_graph(meta_graph_or_file, clear_devices, import_scope, **kwargs)
1907 clear_devices=clear_devices,
1908 import_scope=import_scope,
-> 1909 **kwargs)
1910 if meta_graph_def.HasField("saver_def"):
1911 return Saver(saver_def=meta_graph_def.saver_def, name=import_scope)

D:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\framework\meta_graph.py in import_scoped_meta_graph(meta_graph_or_file, clear_devices, graph, import_scope, input_map, unbound_inputs_col_name, restore_collections_predicate)
735 importer.import_graph_def(
736 input_graph_def, name=(import_scope or ""), input_map=input_map,
--> 737 producer_op_list=producer_op_list)
738
739 # Restores all the other collections.

D:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\util\deprecation.py in new_func(*args, **kwargs)
430 'in a future version' if date is None else ('after %s' % date),
431 instructions)
--> 432 return func(*args, **kwargs)
433 return tf_decorator.make_decorator(func, new_func, 'deprecated',
434 _add_deprecated_arg_notice_to_docstring(

D:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\framework\importer.py in import_graph_def(graph_def, input_map, return_elements, name, op_dict, producer_op_list)
654 'Input types mismatch (expected %r but got %r)'
655 % (', '.join(dtypes.as_dtype(x).name for x in input_types),
--> 656 ', '.join(x.name for x in op._input_types))))
657 # pylint: enable=protected-access
658

ValueError: graph_def is invalid at node 'GatherTree': Input types mismatch (expected 'int32, int32, int32, int32' but got 'int32, int32, int32').

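Please check that you are using TensorFlow 1.3. The error says the `GatherTree` node in the saved graph has three inputs while your installed TensorFlow expects four, which suggests a TensorFlow version mismatch.
Also check issues #2 and #3.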

Thanks