ValueError when channels > 3
tdrobbins opened this issue · 1 comment
tdrobbins commented
Trying to train on images with more than three channels raises a ValueError. I fixed the bug with a really simple patch to unet.utils.to_rgb() and can submit a pull request if you like.
Here's the full error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-8-220be538e83e> in <module>
----> 1 trainer.fit(model4, data)
~/.local/share/virtualenvs/GitHub-mrs5Yhkq/lib/python3.7/site-packages/unet/trainer.py in fit(self, model, train_dataset, validation_dataset, test_dataset, epochs, batch_size, **fit_kwargs)
94 epochs=epochs,
95 callbacks=callbacks,
---> 96 **fit_kwargs)
97
98 self.evaluate(model, test_dataset, prediction_shape)
~/.local/share/virtualenvs/GitHub-mrs5Yhkq/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
64 def _method_wrapper(self, *args, **kwargs):
65 if not self._in_multi_worker_mode(): # pylint: disable=protected-access
---> 66 return method(self, *args, **kwargs)
67
68 # Running inside `run_distribute_coordinator` already.
~/.local/share/virtualenvs/GitHub-mrs5Yhkq/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
874 epoch_logs.update(val_logs)
875
--> 876 callbacks.on_epoch_end(epoch, epoch_logs)
877 if self.stop_training:
878 break
~/.local/share/virtualenvs/GitHub-mrs5Yhkq/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py in on_epoch_end(self, epoch, logs)
363 logs = self._process_logs(logs)
364 for callback in self.callbacks:
--> 365 callback.on_epoch_end(epoch, logs)
366
367 def on_train_batch_begin(self, batch, logs=None):
~/.local/share/virtualenvs/GitHub-mrs5Yhkq/lib/python3.7/site-packages/unet/callbacks.py in on_epoch_end(self, epoch, logs)
32 self._log_histogramms(epoch, predictions)
33
---> 34 self._log_image_summaries(epoch, predictions)
35
36 self.file_writer.flush()
~/.local/share/virtualenvs/GitHub-mrs5Yhkq/lib/python3.7/site-packages/unet/callbacks.py in _log_image_summaries(self, epoch, predictions)
50 utils.to_rgb(cropped_labels[..., :1].numpy()),
51 utils.to_rgb(mask)),
---> 52 axis=2)
53
54 with self.file_writer.as_default():
<__array_function__ internals> in concatenate(*args, **kwargs)
ValueError: all the input array dimensions for the concatenation axis must match exactly, but along dimension 3, the array at index 0 has size 4 and the array at index 1 has size 3
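The failure can be reproduced with NumPy alone. np.concatenate joins arrays along one axis (here axis=2, the image width) and requires every other axis, including the channel axis 3, to match. The shapes below assume the [batch, nx, ny, channels] layout implied by the traceback, with the first array (presumably the converted input image) still carrying 4 channels while the converted label and mask have 3. The array names are illustrative, not taken from unet's source.

```python
import numpy as np

# Illustrative arrays in [batch, nx, ny, channels] layout.
image = np.zeros((1, 256, 256, 4), dtype=np.float32)  # 4-channel input left unchanged
label = np.zeros((1, 256, 256, 3), dtype=np.float32)  # label after conversion to RGB
mask = np.zeros((1, 256, 256, 3), dtype=np.float32)   # mask after conversion to RGB

# Joining side by side along ny fails because the channel axis differs (4 vs 3).
raised = False
try:
    np.concatenate((image, label, mask), axis=2)
except ValueError as err:
    raised = True
    print("ValueError:", err)

# With 3 channels everywhere, only the concatenation axis may differ in size.
panel = np.concatenate((image[..., :3], label, mask), axis=2)
print(panel.shape)  # (1, 256, 768, 3)
```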
And a quick script to reproduce the error:
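import unet
import tensorflow as tf
# 100 random 256x256 images with 4 channels
X = tf.random.normal((100, 256, 256, 4))
# Random binary labels, one-hot encoded to shape (100, 256, 256, 2)
Y_flat = tf.random.categorical(tf.math.log([[0.5, 0.5]]), 100 * 256 * 256)
Y = tf.reshape(Y_flat, (100, 256, 256))
Y_onehot = tf.one_hot(Y, 2)
data = tf.data.Dataset.from_tensor_slices((X, Y_onehot))
# train_data/test_data are not used below; fitting on the full dataset is enough to trigger the error
train_data = data.take(75)
test_data = data.skip(75)
model4 = unet.build_model(256, 256, channels=4, padding="same")
unet.finalize_model(model4, loss=tf.keras.losses.categorical_crossentropy)
trainer = unet.Trainer()
# Raises the ValueError in on_epoch_end, when the image-summary callback calls to_rgb
trainer.fit(model4, data)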
jakeret commented
Hi, thanks for reporting this. Yes, I somehow always had 1 or 3 channels in mind when writing to_rgb.
Would be great if you could send me a PR.
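A minimal sketch of what such a patch could look like, assuming the fix simply keeps the first three channels when there are more than three (the actual PR may handle this differently). The function below is a hypothetical stand-in for unet.utils.to_rgb, not the library's code; zero-padding 2-channel input is likewise an assumption.

```python
import numpy as np


def to_rgb(img: np.ndarray) -> np.ndarray:
    """Hypothetical channel-safe to_rgb: maps [..., channels] to [..., 3] in [0, 1]."""
    # Work on a float copy and replace NaN/inf so min/max are well defined.
    img = np.nan_to_num(img.astype(np.float32))

    channels = img.shape[-1]
    if channels == 1:
        # Grayscale: repeat the single channel three times.
        img = np.repeat(img, 3, axis=-1)
    elif channels == 2:
        # Assumption: fill the missing third channel with zeros.
        img = np.concatenate([img, np.zeros_like(img[..., :1])], axis=-1)
    elif channels > 3:
        # Assumption: keep only the first three channels for display.
        img = img[..., :3]

    # Rescale to [0, 1]; a constant image stays all zeros.
    img -= img.min()
    max_val = img.max()
    if max_val != 0:
        img /= max_val
    return img


if __name__ == "__main__":
    batch = np.random.rand(2, 8, 8, 4)
    print(to_rgb(batch).shape)  # (2, 8, 8, 3)
```

Since to_rgb only feeds the logged image summaries, dropping the extra channels affects the preview, not the data the model trains on.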