
Fast parallel RNN-Transducer.

Primary language: C++. License: MIT.

mxnet-transducer

A fast, parallel implementation of the RNN Transducer loss (Graves 2013 joint network) for MXNet, running on both CPU and GPU.

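A GPU version is now also available for the Graves 2012 add network, in which the joint output is the sum of the transcription and prediction network outputs.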
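To illustrate the difference between the two joint forms, here is a NumPy sketch. The function names and parameter shapes are hypothetical and are not this repository's API. In the Graves 2012 add network, the (T, U+1, V) logits are the broadcast sum of the transcription outputs f (T, V) and the prediction outputs g (U+1, V). The Graves 2013 joint network instead passes the summed hidden projections through a tanh layer followed by an output projection.

```python
import numpy as np


def add_joint(f, g):
    """Graves 2012 'add' joint: logits[t, u] = f[t] + g[u].

    f: (T, V) transcription-network outputs.
    g: (U+1, V) prediction-network outputs.
    Returns (T, U+1, V) logits via broadcasting.
    """
    return f[:, None, :] + g[None, :, :]


def mlp_joint(f, g, w_out, b_out):
    """Graves 2013-style joint network (sketch).

    f: (T, H) and g: (U+1, H) are already-projected hidden features.
    w_out: (H, V) and b_out: (V,) are hypothetical output-layer parameters.
    """
    hidden = np.tanh(f[:, None, :] + g[None, :, :])  # (T, U+1, H)
    return hidden @ w_out + b_out                     # (T, U+1, V)
```

One practical consequence of the add form is that, in principle, every entry of the (T, U+1, V) tensor can be recomputed from f and g, which together hold only (T + U + 1) * V values.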

Install and Test

First, get MXNet and the code:

git clone --recursive https://github.com/apache/incubator-mxnet
git clone https://github.com/HawkAaron/mxnet-transducer

Copy the rnnt* source files into MXNet's contrib operator directory:

cp -r mxnet-transducer/rnnt* incubator-mxnet/src/operator/contrib/

Then build MXNet from source, following its installation instructions:

https://mxnet.incubator.apache.org/install/index.html

Finally, add the Python API to /path/to/mxnet_root/mxnet/gluon/loss.py:

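class RNNTLoss(Loss):
    """RNN-Transducer loss built on the contrib RNNTLoss operator.

    pred: (B, T, U+1, V) if batch_first, else (T, U+1, B, V).
    label: (B, U) padded label indices; pred_lengths, label_lengths: (B,).
    """
    def __init__(self, batch_first=True, blank_label=0, weight=None, **kwargs):
        # In the time-major layout (T, U+1, B, V) the batch axis is 2.
        batch_axis = 0 if batch_first else 2
        super(RNNTLoss, self).__init__(weight, batch_axis, **kwargs)
        self.batch_first = batch_first
        self.blank_label = blank_label

    def hybrid_forward(self, F, pred, label, pred_lengths, label_lengths,
                       sample_weight=None):
        if not self.batch_first:
            # (T, U+1, B, V) -> (B, T, U+1, V), the layout passed to the operator
            pred = F.transpose(pred, (2, 0, 1, 3))

        # The operator takes int32 labels and lengths; for NDArray inputs,
        # copy=False skips the copy when the input is already int32.
        loss = F.contrib.RNNTLoss(pred, label.astype('int32', False),
                                  pred_lengths.astype('int32', False),
                                  label_lengths.astype('int32', False),
                                  blank_label=self.blank_label)
        # Apply the weight given to __init__ (and an optional per-sample
        # weight), as the built-in CTCLoss does.
        return _apply_weighting(F, loss, self._weight, sample_weight)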
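The loss expects padded batches plus per-utterance lengths. As a hedged illustration of the input layout implied by the code above, the following NumPy sketch builds dummy inputs. make_rnnt_batch is a hypothetical helper; the (B, T, U+1, V) shape assumes U is the longest label sequence in the batch, and padding labels with 0 assumes the operator ignores positions past label_lengths.

```python
import numpy as np


def make_rnnt_batch(label_seqs, frame_counts, vocab_size, time_major=False, seed=0):
    """Build dummy RNNTLoss-style inputs (hypothetical helper, NumPy only).

    Returns pred, label, pred_lengths, label_lengths, where pred has shape
    (B, T_max, U_max+1, V), or (T_max, U_max+1, B, V) if time_major.
    """
    rng = np.random.default_rng(seed)
    B = len(label_seqs)
    T_max = max(frame_counts)
    U_max = max(len(s) for s in label_seqs)
    pred = rng.standard_normal((B, T_max, U_max + 1, vocab_size)).astype(np.float32)

    # Right-pad labels; the pad value 0 is assumed to be ignored by the loss.
    label = np.zeros((B, U_max), dtype=np.int32)
    for i, seq in enumerate(label_seqs):
        label[i, :len(seq)] = seq

    pred_lengths = np.array(frame_counts, dtype=np.int32)
    label_lengths = np.array([len(s) for s in label_seqs], dtype=np.int32)

    if time_major:
        # (B, T, U+1, V) -> (T, U+1, B, V); transpose (2, 0, 1, 3) undoes this.
        pred = pred.transpose(1, 2, 0, 3)
    return pred, label, pred_lengths, label_lengths
```

Assuming MXNet was built with the rnnt* sources as described above, each array could then be wrapped with mx.nd.array and passed to RNNTLoss(batch_first=not time_major).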

From the repo root, run the test with:

python test/test.py 10 300 100 50 --mx
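For an independent check on small inputs, the per-utterance loss can be computed with a forward (alpha) recursion over the T by (U+1) lattice. The following is a slow NumPy reference sketch, not the repository's C++/CUDA kernel. It assumes one utterance whose joint outputs have shape (T, U+1, V) and applies log-softmax itself; whether the MXNet operator expects raw activations or log-probabilities is an assumption here. It also includes a brute-force sum over all alignments to confirm the recursion. Both function names are hypothetical.

```python
import itertools

import numpy as np


def log_softmax(x, axis=-1):
    """Numerically stable log-softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))


def rnnt_loss_ref(acts, labels, blank=0):
    """RNN-T negative log-likelihood for one utterance via the alpha recursion.

    acts: (T, U+1, V) unnormalized joint outputs (assumed layout).
    labels: length-U sequence of non-blank label indices.
    """
    log_probs = log_softmax(acts)
    T, U1, _ = log_probs.shape
    U = len(labels)
    assert U1 == U + 1, "second axis must be len(labels) + 1"
    alpha = np.full((T, U1), -np.inf)
    alpha[0, 0] = 0.0
    for t in range(T):
        for u in range(U1):
            if t == 0 and u == 0:
                continue
            terms = []
            if t > 0:  # arrive by emitting blank at (t-1, u)
                terms.append(alpha[t - 1, u] + log_probs[t - 1, u, blank])
            if u > 0:  # arrive by emitting label u at (t, u-1)
                terms.append(alpha[t, u - 1] + log_probs[t, u - 1, labels[u - 1]])
            alpha[t, u] = np.logaddexp.reduce(terms)
    # Every alignment ends with a final blank emitted at (T-1, U).
    return -(alpha[T - 1, U] + log_probs[T - 1, U, blank])


def rnnt_loss_bruteforce(acts, labels, blank=0):
    """Same quantity by enumerating every alignment (tiny inputs only)."""
    log_probs = log_softmax(acts)
    T, U = acts.shape[0], len(labels)
    path_scores = []
    # An alignment interleaves T-1 blanks with U labels, then ends in a blank.
    for label_steps in itertools.combinations(range(T - 1 + U), U):
        t = u = 0
        score = 0.0
        for step in range(T - 1 + U):
            if step in label_steps:
                score += log_probs[t, u, labels[u]]
                u += 1
            else:
                score += log_probs[t, u, blank]
                t += 1
        score += log_probs[T - 1, U, blank]
        path_scores.append(score)
    return -np.logaddexp.reduce(path_scores)
```

For example, with all-zero activations, T = 2, a single label, and V = 3, there are two alignments, each with probability (1/3)^3, so the loss is log(27/2).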

Reference