HasnainRaz/FC-DenseNet-TensorFlow

How to load the pretrained model?

UpCoder opened this issue · 4 comments

Hi, I am interested in how to load pretrained model parameters for the encoder part. Can you give me some suggestions? Thanks!

Hey,
You can nest the encoder in a new variable scope, use tf.get_collection to get all variables in that scope as a list, and pass this list as the var_list argument when constructing tf.train.Saver. Calling restore on that saver then loads only the encoder variables. Essentially something like this:

def model(self, x, training):
        """
        Defines the complete graph model for the Tiramisu based on the provided
        parameters.
        Args:
            x: Tensor, input image to segment.
            training: Bool Tensor, indicating whether training or not.

        Returns:
            x: Tensor, raw unscaled logits of predicted segmentation.
        """
        concats = []
        with tf.variable_scope('encoder'):
            x = tf.layers.conv2d(x,
                                filters=48,
                                kernel_size=[3, 3],
                                strides=[1, 1],
                                padding='SAME',
                                dilation_rate=[1, 1],
                                activation=None,
                                kernel_initializer=tf.contrib.layers.xavier_initializer(),
                                name='first_conv3x3')
            print("First Convolution Out: ", x.get_shape())
            for block_nb in range(0, self.nb_blocks):
                dense = self.dense_block(x, training, block_nb, 'down_dense_block_' + str(block_nb))

                if block_nb != self.nb_blocks - 1:
                    x = tf.concat([x, dense], axis=3, name='down_concat_' + str(block_nb))
                    concats.append(x)
                    x = self.transition_down(x, training, x.get_shape()[-1], 'trans_down_' + str(block_nb))
                    print("Downsample Out:", x.get_shape())

            x = dense
            print("Bottleneck Block: ", dense.get_shape())

        # ..... decoder continues downwards

Then when you want to restore the encoder only:

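import tensorflow as tf  # TensorFlow 1.x API

encoder_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='encoder')
saver = tf.train.Saver(var_list=encoder_vars)  # var_list goes to the constructor, not to restore()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # initialize everything, including the decoder
    saver.restore(sess, ckpt_name)  # then overwrite only the encoder variables from the checkpoint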
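
One thing to watch when picking the scope name: tf.get_collection filters with re.match on each variable's name, so scope='encoder' is a prefix match and would also pick up variables from a scope such as 'encoder_aux'. Below is a minimal plain-Python sketch of that filtering rule, with no TensorFlow needed. Note that select_by_scope is a hypothetical helper written for illustration (not a TensorFlow function), and the variable names in the list are made up, apart from first_conv3x3, which comes from the snippet above.

```python
import re


def select_by_scope(names, scope):
    """Mimic the name filtering of tf.get_collection(..., scope=scope).

    Hypothetical helper: keeps every name that re.match accepts for the
    given scope string, which for a plain string means a prefix match.
    """
    return [name for name in names if re.match(scope, name)]


# Made-up variable names for illustration only.
names = [
    "encoder/first_conv3x3/kernel:0",
    "encoder/first_conv3x3/bias:0",
    "encoder_aux/conv/kernel:0",
    "decoder/last_conv/kernel:0",
]

# Prefix match: also selects the 'encoder_aux' variable.
loose = select_by_scope(names, "encoder")

# A trailing slash restricts the match to the 'encoder' scope itself.
strict = select_by_scope(names, "encoder/")

print(loose)
print(strict)
```

If the graph has other scopes whose names start with 'encoder', passing scope='encoder/' to tf.get_collection should keep them out of var_list.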

Thank you very much. I would also like to know where the pretrained model can be downloaded. Could you provide a download link for the pretrained DenseNet parameters? Thanks again.

I currently do not have a pretrained model file, so you will probably have to train the model yourself first.

OK, Thank you!