
ValueError when channels > 3

tdrobbins opened this issue · 1 comment

Training on images with more than three channels raises a ValueError when the image-summary callback runs at the end of an epoch. I fixed it locally with a small patch to unet.utils.to_rgb() and can submit a pull request if you like.

Here's the full error:

ValueError                                Traceback (most recent call last)
<ipython-input-8-220be538e83e> in <module>
----> 1 trainer.fit(model4, data)

~/.local/share/virtualenvs/GitHub-mrs5Yhkq/lib/python3.7/site-packages/unet/trainer.py in fit(self, model, train_dataset, validation_dataset, test_dataset, epochs, batch_size, **fit_kwargs)
     94                             epochs=epochs,
     95                             callbacks=callbacks,
---> 96                             **fit_kwargs)
     98         self.evaluate(model, test_dataset, prediction_shape)

~/.local/share/virtualenvs/GitHub-mrs5Yhkq/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
     64   def _method_wrapper(self, *args, **kwargs):
     65     if not self._in_multi_worker_mode():  # pylint: disable=protected-access
---> 66       return method(self, *args, **kwargs)
     68     # Running inside `run_distribute_coordinator` already.

~/.local/share/virtualenvs/GitHub-mrs5Yhkq/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
    874           epoch_logs.update(val_logs)
--> 876         callbacks.on_epoch_end(epoch, epoch_logs)
    877         if self.stop_training:
    878           break

~/.local/share/virtualenvs/GitHub-mrs5Yhkq/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py in on_epoch_end(self, epoch, logs)
    363     logs = self._process_logs(logs)
    364     for callback in self.callbacks:
--> 365       callback.on_epoch_end(epoch, logs)
    367   def on_train_batch_begin(self, batch, logs=None):

~/.local/share/virtualenvs/GitHub-mrs5Yhkq/lib/python3.7/site-packages/unet/callbacks.py in on_epoch_end(self, epoch, logs)
     32         self._log_histogramms(epoch, predictions)
---> 34         self._log_image_summaries(epoch, predictions)
     36         self.file_writer.flush()

~/.local/share/virtualenvs/GitHub-mrs5Yhkq/lib/python3.7/site-packages/unet/callbacks.py in _log_image_summaries(self, epoch, predictions)
     50                                  utils.to_rgb(cropped_labels[..., :1].numpy()),
     51                                  utils.to_rgb(mask)),
---> 52                                 axis=2)
     54         with self.file_writer.as_default():

<__array_function__ internals> in concatenate(*args, **kwargs)

ValueError: all the input array dimensions for the concatenation axis must match exactly, but along dimension 3, the array at index 0 has size 4 and the array at index 1 has size 3
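
The last frame can be isolated with NumPy alone. This is a minimal sketch, assuming the callback ends up joining a 4-channel image batch with a label batch that has already been expanded to 3 channels; the array names and the small shapes are made up for illustration.

```python
import numpy as np

# Stand-ins (hypothetical) for the arrays joined in the callback:
# a 4-channel image batch and a label batch already expanded to 3 channels.
images = np.zeros((2, 8, 8, 4), dtype=np.float32)
labels_rgb = np.zeros((2, 8, 8, 3), dtype=np.float32)

try:
    # Joining along the width axis requires every other axis to match,
    # including the channel axis (dimension 3).
    np.concatenate((images, labels_rgb), axis=2)
    error_message = None
except ValueError as exc:
    error_message = str(exc)

print(error_message)

# With matching channel counts the same call succeeds.
side_by_side = np.concatenate((images[..., :3], labels_rgb), axis=2)
print(side_by_side.shape)
```

So the concatenation itself is fine; the problem is that one of its inputs still has 4 channels when it gets there.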

And a quick script to reproduce the error:

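import tensorflow as tf
import unet

# 100 random images of 256x256 pixels with 4 channels
X = tf.random.normal((100, 256, 256, 4))
# Random binary labels, one per pixel, one-hot encoded into 2 classes
Y_flat = tf.random.categorical(tf.math.log([[0.5, 0.5]]), 100 * 256 * 256)
Y = tf.reshape(Y_flat, (100, 256, 256))
Y_onehot = tf.one_hot(Y, 2)

data = tf.data.Dataset.from_tensor_slices((X, Y_onehot))
# These splits are not used below; the error already occurs with the full dataset
train_data = data.take(75)
test_data = data.skip(75)

model4 = unet.build_model(256, 256, channels=4, padding="same")
trainer = unet.Trainer()

# Fails in the image-summary callback at the end of an epoch
trainer.fit(model4, data)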

Hi, thanks for reporting this. Yes, I somehow always had only 1 or 3 channels in mind when writing to_rgb.
It would be great if you could send me a PR.
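
For anyone hitting this before a fix lands, here is a rough sketch of how a to_rgb-style helper could cope with any channel count. It is not the patch mentioned above and not the library's code. The name to_rgb_sketch is hypothetical, and the choices it makes (keep the first three channels, repeat the first channel for 2-channel input, rescale to [0, 1]) are assumptions.

```python
import numpy as np


def to_rgb_sketch(img: np.ndarray) -> np.ndarray:
    """Hypothetical helper: return a float32 copy with exactly 3 channels, scaled to [0, 1]."""
    img = np.atleast_3d(img).astype(np.float32)  # astype returns a copy
    channels = img.shape[-1]
    if channels == 1:
        # Repeat the single channel three times along the last axis.
        img = np.tile(img, 3)
    elif channels == 2:
        # Assumption: fill the third channel by repeating the first one.
        img = np.concatenate((img, img[..., :1]), axis=-1)
    elif channels > 3:
        # Assumption: keep only the first three channels for display.
        img = img[..., :3].copy()
    img[np.isnan(img)] = 0
    img -= img.min()
    peak = img.max()
    if peak != 0:
        img /= peak
    return img


# A 4-channel batch and a 1-channel batch both come out with 3 channels,
# so they can be placed side by side along the width axis.
images = np.random.rand(2, 8, 8, 4)
labels = np.random.rand(2, 8, 8, 1)
panel = np.concatenate((to_rgb_sketch(images), to_rgb_sketch(labels)), axis=2)
print(panel.shape)
```

Dropping the extra channels is only one option for visualisation. Which channels are worth showing depends on the data.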