ThomasDelteil/HandwrittenTextRecognition_MXNet

Question: is it possible to hybridize lstm_ocr_ctc?

Closed this issue · 4 comments

Hi, first of all, amazing work with HandwrittenTextRecognition_MXNet.
It's also my first time working with MXNet, so please be patient.

I've successfully trained my own word OCR net, but when I try to hybridize the net, training does not converge to a usable solution.

I think the problem is the one raised in [Question on the shape of feature map of OCR_LSTM_CTC](https://github.com//issues/9): the transformation it requires forces us to use split, which returns an NDArray.

Is there a way to hybridize the net and train with it? Or perhaps to hybridize an already trained net?

Hey @lmmaia

Thanks, and yes, it is indeed possible. When we first worked on it, LSTMs were not hybridizable. That's not the case anymore 👍. Here is the hybridizable version of the network; in my case training is 30% faster.

If you want to make training even faster, use multiple workers for data loading (I train an epoch in 10s with this + hybridization):

train_data = gluon.data.DataLoader(train_ds.transform(augment_transform), batch_size, shuffle=True, last_batch="discard", num_workers=8)
test_data = gluon.data.DataLoader(test_ds.transform(transform), batch_size, shuffle=False, last_batch="discard", num_workers=8)

Model:

class EncoderLayer(gluon.HybridBlock):
    def __init__(self, hidden_states=200, lstm_layers=1, **kwargs):
        super(EncoderLayer, self).__init__(**kwargs)
        with self.name_scope():
            self.lstm = mx.gluon.rnn.LSTM(hidden_states, lstm_layers, bidirectional=True)
            
    def hybrid_forward(self, F, x):
        x = x.transpose((0, 3, 1, 2))
        x = x.flatten()
        # max_seq_len is a global defined elsewhere in the training script
        x = x.split(num_outputs=max_seq_len, axis=1)
        x = F.concat(*[elem.expand_dims(axis=0) for elem in x], dim=0) # (SEQ_LEN, N, CHANNELS)
        x = self.lstm(x)
        x = x.transpose((1, 0, 2)) #(N, SEQ_LEN, HIDDEN_UNITS)
        return x

class Network(gluon.HybridBlock):
    def __init__(self, num_downsamples=2, resnet_layer_id=4, lstm_hidden_states=200, lstm_layers=1, **kwargs):
        super(Network, self).__init__(**kwargs)
        self.p_dropout = 0.5
        self.num_downsamples = num_downsamples
        self.body = self.get_body(resnet_layer_id=resnet_layer_id)

        self.encoders = gluon.nn.HybridSequential()
        
        for _ in range(self.num_downsamples):
            encoder = self.get_encoder(lstm_hidden_states=lstm_hidden_states, lstm_layers=lstm_layers)
            self.encoders.add(encoder)
        self.decoder = self.get_decoder()
        self.downsampler = self.get_down_sampler(64)

    def get_down_sampler(self, num_filters):
        '''
        Creates a two-stacked Conv-BatchNorm-Relu and then a pooling layer to
        downsample the image features by half.
        '''
        out = gluon.nn.HybridSequential()
        for _ in range(2):
            out.add(gluon.nn.Conv2D(num_filters, 3, strides=1, padding=1))
            out.add(gluon.nn.BatchNorm(in_channels=num_filters))
            out.add(gluon.nn.Activation('relu'))
        out.add(gluon.nn.MaxPool2D(2))
        out.collect_params().initialize(mx.init.Normal(), ctx=ctx)
        out.hybridize()
        return out

    def get_body(self, resnet_layer_id):
        '''
        Create the feature extraction network of the SSD based on resnet34.
        The first layer of the res-net is converted into grayscale by averaging the weights of the 3 channels
        of the original resnet.

        Returns
        -------
        network: gluon.nn.HybridSequential
            The body network for feature extraction based on resnet
        '''
        
        pretrained = resnet34_v1(pretrained=True, ctx=ctx)
        pretrained_2 = resnet34_v1(pretrained=True, ctx=mx.cpu(0))
        first_weights = pretrained_2.features[0].weight.data().mean(axis=1).expand_dims(axis=1)
        # First weights could be replaced with individual channels.
        
        body = gluon.nn.HybridSequential()
        with body.name_scope():
            first_layer = gluon.nn.Conv2D(channels=64, kernel_size=(7, 7), padding=(3, 3), strides=(2, 2), in_channels=1, use_bias=False)
            first_layer.initialize(mx.init.Normal(), ctx=ctx)
            first_layer.weight.set_data(first_weights)
            body.add(first_layer)
            body.add(*pretrained.features[1:-resnet_layer_id])
        return body

    def get_encoder(self, lstm_hidden_states, lstm_layers):
        encoder = gluon.nn.HybridSequential()
        encoder.add(EncoderLayer(hidden_states=lstm_hidden_states, lstm_layers=lstm_layers))
        encoder.add(gluon.nn.Dropout(self.p_dropout))
        encoder.collect_params().initialize(mx.init.Xavier(), ctx=ctx)
        return encoder
    
    def get_decoder(self):
        alphabet_size = len(string.ascii_letters+string.digits+string.punctuation+' ') + 1
        decoder = mx.gluon.nn.Dense(units=alphabet_size, flatten=False)
        decoder.collect_params().initialize(mx.init.Xavier(), ctx=ctx)
        return decoder

    def hybrid_forward(self, F, x):
        features = self.body(x)
        hidden_states = []
        hs = self.encoders[0](features)
        hidden_states.append(hs)
        for i in range(self.num_downsamples - 1):
            features = self.downsampler(features)
            hs = self.encoders[i+1](features)
            hidden_states.append(hs)
        hs = F.concat(*hidden_states, dim=2)
        output = self.decoder(hs)
        return output

net = Network(num_downsamples=num_downsamples, resnet_layer_id=resnet_layer_id , lstm_hidden_states=lstm_hidden_states, lstm_layers=lstm_layers)
net.hybridize(static_alloc=True, static_shape=True)
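
To make the reshaping inside EncoderLayer easier to follow, here is a minimal sketch of the same steps (transpose, flatten, split, stack) written in NumPy rather than MXNet, so it runs without an MXNet install. The feature-map shape and the function names are hypothetical, chosen only for illustration. It also shows that, under these assumptions, the split/concat pair gives the same result as a single reshape followed by a transpose. Whether the reshape form behaves the same way once hybridized in MXNet is not tested here.

```python
import numpy as np


def to_sequence_split(x, seq_len):
    """Mirror the EncoderLayer steps: transpose, flatten, split, stack."""
    n = x.shape[0]
    x = x.transpose(0, 3, 1, 2)             # (N, W, C, H)
    x = x.reshape(n, -1)                    # (N, W*C*H), like flatten()
    chunks = np.split(x, seq_len, axis=1)   # seq_len arrays of (N, FEATURES)
    return np.stack(chunks, axis=0)         # (SEQ_LEN, N, FEATURES)


def to_sequence_reshape(x, seq_len):
    """Same result with one reshape and one transpose, no split."""
    n = x.shape[0]
    x = x.transpose(0, 3, 1, 2).reshape(n, seq_len, -1)  # (N, SEQ_LEN, FEATURES)
    return x.transpose(1, 0, 2)                          # (SEQ_LEN, N, FEATURES)


# Hypothetical feature map: batch 2, 4 channels, height 3, width 5.
features = np.arange(2 * 4 * 3 * 5, dtype=np.float32).reshape(2, 4, 3, 5)
a = to_sequence_split(features, seq_len=5)
b = to_sequence_reshape(features, seq_len=5)
print(a.shape, np.array_equal(a, b))
```

When seq_len equals the feature-map width, each time step holds the flattened (CHANNELS, HEIGHT) column at that horizontal position, which is what the LSTM then consumes.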

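As a side note on the decoder's output size, the snippet below works through the alphabet_size expression from get_decoder using only the standard library. The reading of the extra `+ 1` unit as the CTC blank label is an assumption based on the model being trained with CTC; the thread does not state it.

```python
import string

# Character set used in get_decoder above.
alphabet = string.ascii_letters + string.digits + string.punctuation + ' '

# One extra output unit, assumed here to be reserved for the CTC blank label.
alphabet_size = len(alphabet) + 1

print(len(string.ascii_letters), len(string.digits), len(string.punctuation))
print(len(alphabet), alphabet_size)
```
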
> although when trying to hybridize the net I've not been able to converge into a usable solution

Did you mean that the training doesn't converge well? Yes, I am noticing this as well. Very strange.

Thanks for the prompt response.

Yes, I tried training with my hybridized version before, and the training didn't converge well.

I'll try my dataset with the version you just shared and retrain to see if it performs better 👍.

So, I was able to train and get good results with the suggested classes. I made some changes just to fit the project I had before, but it's amazing ;)
I'm now training an epoch with a batch size of 128 in 7s on a single GPU.