manicman1999/StyleGAN2-Tensorflow-2.0

conv2d_mod/Conv2D NCHW not implemented


generated_images = self.GAN.GM.predict(n1 + [n2], batch_size = BATCH_SIZE)

  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training.py", line 909, in predict
    use_multiprocessing=use_multiprocessing)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_arrays.py", line 722, in predict
    callbacks=callbacks)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_arrays.py", line 393, in model_iteration
    batch_outs = f(ins_batch)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/backend.py", line 3740, in call
    outputs = self._graph_fn(*converted_inputs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py", line 1081, in call
    return self._call_impl(args, kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py", line 1121, in _call_impl
    return self._call_flat(args, self.captured_inputs, cancellation_manager)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py", line 1224, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py", line 511, in call
    ctx=ctx)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/execute.py", line 67, in quick_execute
    six.raise_from(core._status_to_exception(e.code, message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.UnimplementedError: The Conv2D op currently only supports the NHWC tensor format on the CPU. The op was given the format: NCHW
  [[node model_1/conv2d_mod/Conv2D (defined at /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py:1751) ]] [Op:__inference_keras_scratch_graph_11413]

Function call stack:
keras_scratch_graph

It seems Conv2D does not accept the NCHW data format here. I tried forcing it to run on the GPU (with tf.device('/gpu:1'): ...), but it did not work.
I also tried different TF versions (2.0, 2.3), and even the Docker image for TF 2.0; all of them hit the same issue.

Does anyone know how to get around this?
Thanks

It happens because the model is running on the CPU, where Conv2D only supports NHWC. Try batch_size = 1, and in conv_mod.py:

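# add this: transpose the activations from NCHW to NHWC
x = tf.transpose(x, [0, 2, 3, 1])

# change data_format from "NCHW" to "NHWC"
x = tf.nn.conv2d(x, w, strides=self.strides, padding="SAME", data_format="NHWC")

# add this: transpose the result back from NHWC to NCHW
x = tf.transpose(x, [0, 3, 1, 2])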
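A minimal sketch of what these three edits do, written with NumPy instead of TensorFlow so it runs without TF or a GPU. The helper names (`call_nhwc_op_on_nchw`, `conv1x1_nhwc`) are hypothetical and not part of conv_mod.py; `np.transpose` with the same permutation lists plays the role of `tf.transpose`, and a 1x1 convolution stands in for the NHWC `tf.nn.conv2d` call.

```python
import numpy as np

# Same permutations as the tf.transpose calls added to conv_mod.py
NCHW_TO_NHWC = [0, 2, 3, 1]  # [N, C, H, W] -> [N, H, W, C]
NHWC_TO_NCHW = [0, 3, 1, 2]  # [N, H, W, C] -> [N, C, H, W], the inverse


def call_nhwc_op_on_nchw(x_nchw, nhwc_op):
    """Hypothetical helper: run an op that only accepts NHWC on NCHW data, return NCHW."""
    x_nhwc = np.transpose(x_nchw, NCHW_TO_NHWC)
    y_nhwc = nhwc_op(x_nhwc)
    return np.transpose(y_nhwc, NHWC_TO_NCHW)


def conv1x1_nhwc(x_nhwc, w):
    """Hypothetical stand-in for the NHWC conv: a 1x1 convolution.

    w uses the conv2d filter layout [1, 1, C_in, C_out], so the op is a
    matmul over the last (channel) axis of the NHWC input.
    """
    return x_nhwc @ w[0, 0]


x = np.arange(2 * 3 * 4 * 5, dtype=float).reshape(2, 3, 4, 5)  # N=2, C=3, H=4, W=5
w = np.ones((1, 1, 3, 8))                                       # C_in=3, C_out=8
y = call_nhwc_op_on_nchw(x, lambda t: conv1x1_nhwc(t, w))
print(y.shape)  # (2, 8, 4, 5)
```

The rest of the layer keeps seeing NCHW tensors, so nothing outside the conv call has to change. The price is two extra transposes per convolution; on a GPU, where NCHW Conv2D is supported, the original code path avoids them.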

Thanks Anthony, your solution works.
I thought the weights would also need their in_channels axis transposed to match the activation data format, but it turns out they don't.
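Why the weights can stay as they are: `tf.nn.conv2d` takes its filter as `[filter_height, filter_width, in_channels, out_channels]` regardless of `data_format`, which only describes the input and output tensors. The NumPy sketch below (hypothetical helper names; stride 1 and `SAME` padding only, as simplifying assumptions) computes one convolution directly on NCHW data and once through the transpose-to-NHWC path, using the identical filter array, and checks that the results agree.

```python
import numpy as np


def conv2d_nhwc(x, w):
    """Naive stride-1, SAME-padded cross-correlation (as in TF).

    x: [N, H, W, C_in], w: [kh, kw, C_in, C_out] -> [N, H, W, C_out]
    """
    kh, kw = w.shape[:2]
    ph, pw = (kh - 1) // 2, (kw - 1) // 2  # SAME: extra padding goes bottom/right
    xp = np.pad(x, ((0, 0), (ph, kh - 1 - ph), (pw, kw - 1 - pw), (0, 0)))
    n, h, wd, _ = x.shape
    out = np.zeros((n, h, wd, w.shape[3]))
    for i in range(kh):
        for j in range(kw):
            # shifted patch [N, H, W, C_in] @ w[i, j] [C_in, C_out]
            out += xp[:, i:i + h, j:j + wd, :] @ w[i, j]
    return out


def conv2d_nchw(x, w):
    """Same convolution computed directly on NCHW input, with the SAME filter layout.

    x: [N, C_in, H, W], w: [kh, kw, C_in, C_out] -> [N, C_out, H, W]
    """
    kh, kw = w.shape[:2]
    ph, pw = (kh - 1) // 2, (kw - 1) // 2
    xp = np.pad(x, ((0, 0), (0, 0), (ph, kh - 1 - ph), (pw, kw - 1 - pw)))
    n, _, h, wd = x.shape
    out = np.zeros((n, w.shape[3], h, wd))
    for i in range(kh):
        for j in range(kw):
            # contract the channel axis (1) of the patch with axis 0 of w[i, j]
            out += np.einsum("nchw,cf->nfhw", xp[:, :, i:i + h, j:j + wd], w[i, j])
    return out


rng = np.random.default_rng(0)
x_nchw = rng.standard_normal((2, 3, 5, 6))  # N=2, C_in=3, H=5, W=6
w = rng.standard_normal((3, 3, 3, 4))       # [kh, kw, C_in, C_out], never transposed

via_nhwc = np.transpose(conv2d_nhwc(np.transpose(x_nchw, [0, 2, 3, 1]), w), [0, 3, 1, 2])
direct = conv2d_nchw(x_nchw, w)
print(np.allclose(via_nhwc, direct))  # True
```

Only the activations move between layouts; the filter is indexed by (kh, kw, C_in, C_out) in both functions, which is why conv_mod.py needs no change to `w`.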