Incompatibility of build_graph function with TensorFlow 2.0
mattdornfeld opened this issue · 1 comment
mattdornfeld commented
I'm running the example below in TF 2.0 and get the error AttributeError: module 'tensorflow_core._api.v2.train' has no attribute 'export_meta_graph'. It seems the build_graph function is incompatible with the TF 2.0 API. The same example works fine with TF 1.15.
In [2]: import tensorflow as tf
...: from tensorflow import keras
...: from tensorflow.keras import layers
...: from sparkflow.graph_utils import build_graph
...:
...: tf.compat.v1.disable_eager_execution()
...:
...: output_dim = 64
...: model = keras.Sequential()
...: model.add(layers.Dense(output_dim, kernel_initializer='uniform', input_shape=(10,)))
...: model.add(layers.Activation('softmax'))
...:
...: loss_fn = keras.losses.SparseCategoricalCrossentropy()
...: model.compile(loss=loss_fn, optimizer='adam')
...:
...: y_true = tf.compat.v1.placeholder(dtype=tf.float32, shape=(None, output_dim))
...: loss = model.loss.fn(y_true, model.output)
...: mg = build_graph(lambda : loss)
...:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-2-33c7a76593c9> in <module>
16 y_true = tf.compat.v1.placeholder(dtype=tf.float32, shape=(None, output_dim))
17 loss = model.loss.fn(y_true, model.output)
---> 18 mg = build_graph(lambda : loss)
/usr/local/lib/python3.7/site-packages/sparkflow/graph_utils.py in build_graph(func)
12 with first_graph.as_default() as g:
13 v = func()
---> 14 mg = json_format.MessageToJson(tf.train.export_meta_graph())
15 return mg
16
AttributeError: module 'tensorflow_core._api.v2.train' has no attribute 'export_meta_graph'
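A possible workaround, sketched here rather than taken from sparkflow, is to look the function up at tf.train.export_meta_graph first and fall back to tf.compat.v1.train.export_meta_graph, which is where TF 2.x keeps the graph-mode version. The names resolve_attr and build_graph_compat are hypothetical, and the first_graph = tf.Graph() setup is an assumption, because the traceback only shows lines 12 to 15 of graph_utils.py.

```python
def resolve_attr(root, *dotted_paths):
    """Return the first attribute found under `root` among `dotted_paths`.

    Lets the same code run on TF 1.15 (tf.train.export_meta_graph) and
    TF 2.x (tf.compat.v1.train.export_meta_graph).
    """
    for path in dotted_paths:
        obj = root
        try:
            for name in path.split("."):
                obj = getattr(obj, name)
        except AttributeError:
            continue  # this path does not exist in this version; try the next
        return obj
    raise AttributeError(f"none of {dotted_paths} found on {root!r}")


def build_graph_compat(func):
    """Hypothetical drop-in for sparkflow.graph_utils.build_graph (untested with sparkflow)."""
    # Imported lazily so resolve_attr can be used without TensorFlow installed.
    import tensorflow as tf
    from google.protobuf import json_format

    export_meta_graph = resolve_attr(
        tf, "train.export_meta_graph", "compat.v1.train.export_meta_graph"
    )
    first_graph = tf.Graph()  # assumed: the traceback does not show this line
    with first_graph.as_default():
        func()  # ops created by func() are added to first_graph
        return json_format.MessageToJson(export_meta_graph())
```

Because func() runs inside first_graph.as_default(), only ops created during that call end up in the exported graph. The model therefore likely needs to be built inside func rather than before it, as in the snippet above, where loss already lives in the outer default graph.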
M0315G commented
Yes, I experienced the same issue when training my CNN classifier. The reason is that starting with TF 2.x, TensorFlow executes eagerly by default and no longer relies as heavily on graphs (DAGs) run inside a session. The only solution I could find was to downgrade to TF 1.x and use the API.