This is a TensorFlow implementation of the "V-Net" architecture used for 3D medical imaging segmentation. This code only implements the TensorFlow graph; it must be used within a training program.
This is an example of a network this code implements.
The function `v_net(tf_input, input_channels, output_channels, n_channels)` has the following arguments:
- `tf_input`: a rank 5 tensor with shape `[batch_size, X, Y, Z, input_channels]`, where `X, Y, Z` are the spatial dimensions of the images and `input_channels` is the number of channels the images have;
- `input_channels`: the number of channels of the input images;
- `output_channels`: the number of desired output channels. `v_net()` will return a tensor with the same shape as `tf_input` but with a different number of channels, i.e. `[batch_size, X, Y, Z, output_channels]`;
- `n_channels`: the number of channels used internally in the network. In the original paper this number was 16. This number doubles at every level of the contracting path. See the image for a better understanding of this number.
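As a rough illustration of how `n_channels` grows along the contracting path, the sketch below lists the channel count per level. The helper `channels_per_level` and its `n_levels` argument are hypothetical names introduced here, and the default of five levels is an assumption for illustration, not a parameter of `v_net()`.

```python
def channels_per_level(n_channels, n_levels=5):
    """Return the channel count at each level of the contracting path.

    n_levels is a hypothetical parameter used only for this illustration;
    the default of 5 is an assumption.
    """
    # The channel count doubles at every level, starting from n_channels.
    return [n_channels * 2 ** level for level in range(n_levels)]


print(channels_per_level(16))  # [16, 32, 64, 128, 256]
```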
Apart from the number of input channels (`input_channels`), none of the dimensions of `tf_input` need to be known. This allows reading examples from queues and even training the network with examples of different sizes.
In a binary segmentation problem you could use `output_channels=1` with a sigmoid loss, and in a three-class problem you could use `output_channels=3` with a softmax loss.
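To make the difference concrete, the plain-Python sketch below turns the logits of a single voxel into probabilities: one logit through a sigmoid for the binary case, three logits through a softmax for the three-class case. The function names are chosen for this illustration and are not part of this code; a training program would normally use the equivalent TensorFlow operations on whole tensors.

```python
import math


def sigmoid(logit):
    """Foreground probability from a single logit (output_channels=1)."""
    # Two branches so that math.exp is never called with a large positive value.
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    e = math.exp(logit)
    return e / (1.0 + e)


def softmax(logits):
    """Class probabilities from one logit per class (e.g. output_channels=3)."""
    # Subtracting the maximum keeps the exponentials numerically stable.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


print(sigmoid(0.0))  # 0.5
print(softmax([2.0, 0.5, -1.0]))  # three probabilities that sum to 1
```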

    import tensorflow as tf
    from VNet import v_net

    input_channels = 6
    output_channels = 1

    tf_input = tf.placeholder(dtype=tf.float32, shape=(10, 190, 190, 20, input_channels))
    # Keyword arguments follow the names in the signature given above.
    logits = v_net(tf_input, input_channels, output_channels=output_channels, n_channels=16)

`logits` will have shape `[10, 190, 190, 20, 1]`; it can then be flattened and used in the sigmoid cross entropy function.
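For intuition about this step, here is a minimal plain-Python sketch of flattening nested values and computing a mean sigmoid cross entropy. It uses the numerically stable form `max(x, 0) - x * z + log(1 + exp(-|x|))`, where `x` is a logit and `z` its 0/1 label. The helpers `flatten` and `sigmoid_cross_entropy` are hypothetical names for this illustration; in a real training program the TensorFlow function `tf.nn.sigmoid_cross_entropy_with_logits` would be the natural choice.

```python
import math


def flatten(nested):
    """Recursively flatten nested lists into one flat list of numbers."""
    if not isinstance(nested, list):
        return [nested]
    flat = []
    for item in nested:
        flat.extend(flatten(item))
    return flat


def sigmoid_cross_entropy(logits, labels):
    """Mean sigmoid cross entropy over flat lists of logits and 0/1 labels."""
    total = 0.0
    for x, z in zip(logits, labels):
        # Stable form: max(x, 0) - x * z + log(1 + exp(-|x|))
        total += max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


# Tiny made-up "volume" of logits and labels, nested like [X][Y][Z].
logits = [[[2.0, -1.0]], [[0.0, 3.0]]]
labels = [[[1.0, 0.0]], [[1.0, 1.0]]]
print(sigmoid_cross_entropy(flatten(logits), flatten(labels)))
```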
There are two slightly different implementations in the code.
`VNetOriginal.py` implements the network as it is in the original paper, with three small differences:
- The input can have more than one channel. If the input has more than one channel, then one more convolution is added in level one to increase the number of input channels to match `n_channels`. If the input has only one channel, then it is broadcast in the first skip connection (repeated `n_channels` times).
- `n_channels` does not need to be 16.
- The output does not need to have two channels like in the original architecture.

If you set `input_channels=1`, `n_channels=16` and `output_channels=2`, the function `v_net()` implements the original architecture.
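The single-channel case can be pictured per voxel as follows. This is only a plain-Python sketch of the idea: the helper names are hypothetical, the numbers are made up, and the actual code operates on whole tensors rather than on one voxel at a time.

```python
def broadcast_channel(value, n_channels):
    """Repeat a single-channel voxel value n_channels times."""
    return [value] * n_channels


def residual_sum(conv_output, skip):
    """Element-wise sum of a convolution output and a skip connection."""
    if len(conv_output) != len(skip):
        raise ValueError("channel counts must match for an element-wise sum")
    return [c + s for c, s in zip(conv_output, skip)]


n_channels = 4  # small value for readability; the original paper used 16
skip = broadcast_channel(0.5, n_channels)
conv_output = [0.1, -0.2, 0.3, 0.0]  # made-up stand-in for a convolution output
print(residual_sum(conv_output, skip))
```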
`VNet.py` is an updated version of the architecture with the following improvements/fixes:
- ReLU non-linearities replaced with PReLU (Parametric ReLU).
- The "V-Net" paper implemented the element-wise sum of the skip connection after the non-linearity. However, according to the original residual network paper, this should be done before the last non-linearity of the convolution block. This is fixed in this implementation.
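The two points above can be illustrated on scalars. The sketch below defines PReLU and contrasts the two orderings of the skip-connection sum. The function names are hypothetical, `conv_out` stands in for the output of the convolutions, and `alpha` is fixed here for simplicity although in PReLU it is a learned parameter.

```python
def prelu(x, alpha):
    """Parametric ReLU: identity for x >= 0, slope alpha for x < 0."""
    return x if x >= 0 else alpha * x


def sum_after_nonlinearity(x, conv_out, alpha):
    """Ordering attributed above to the "V-Net" paper: activate, then add the skip."""
    return prelu(conv_out, alpha) + x


def sum_before_nonlinearity(x, conv_out, alpha):
    """Ordering attributed above to the residual network paper: add the skip, then activate."""
    return prelu(conv_out + x, alpha)


x, conv_out, alpha = 1.0, -3.0, 0.25  # made-up values
print(sum_after_nonlinearity(x, conv_out, alpha))   # 0.25
print(sum_before_nonlinearity(x, conv_out, alpha))  # -0.5
```

The two orderings only differ when the sign of `conv_out` and of `conv_out + x` leads the non-linearity down different branches, or when a negative slope is applied to a different value, as in this example.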