
Handwritten Character Recognition - an unofficial implementation of the paper

TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models


This is an unofficial implementation of TrOCR based on the Hugging Face transformers library and the TrOCR paper. There is also a repository by the authors of the paper (link). The code in this repository is merely a thin wrapper to quickly get started with training and deploying this model for character recognition tasks.

 

Results:

(Figure: example predictions)

After training on a dataset of 2000 samples for 8 epochs, we reached an accuracy of 96.5%. Neither the training nor the validation dataset was completely clean; with cleaner data, even higher accuracies would have been possible.

 

Architecture:

(Figure: TrOCR architecture. Taken from the original paper.)

TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models. Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei. Preprint, 2021.

 

1. Setup

Clone the repository and make sure conda or miniconda is installed. Then change into the directory of the cloned repository and run

conda env create -n trocr --file environment.yml
conda activate trocr

This should install all necessary libraries.

Training without GPU:

It is highly recommended to use a CUDA GPU, but everything also works on the CPU. In that case, create the environment from environment-cpu.yml instead:
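conda env create -n trocr --file environment-cpu.yml
conda activate trocr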

In case the process terminates with the message "Killed", reduce the batch size (see src/configs/constants.py) so the data fits into working memory.

 

2. Using the repository

There are three modes: inference, validation, and training. All three can start either with a local model in the expected path (see src/configs/paths.py) or with the pretrained model from Hugging Face. Inference and validation use the local model by default; training starts from the Hugging Face model by default.

 

Inference (Prediction):

python -m src predict <image_files>  # predict image files using the trained local model
python -m src predict data/img1.png data/img2.png  # image files can be listed explicitly
python -m src predict data/*  # shell expansion works as well
python -m src predict data/* --no-local-model  # use the pretrained Hugging Face model

Validation:

python -m src validate  # uses the trained local model
python -m src validate --no-local-model  # loads the pretrained model from Hugging Face

Training:

python -m src train  # starts from the pretrained Hugging Face model
python -m src train --local-model  # starts from the trained local model

 

For validation and training, input images should be placed in the directories train and val, and the labels should be in gt/labels.csv. Each row of the CSV should consist of the image name followed by the label, for example img1.png,a (in quotes, if necessary).
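For example, gt/labels.csv could look like this:

img1.png,a
img2.png,b
"img3.png","c"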

It is also straightforward to read labels from somewhere else. For that, just add the necessary code to load_filepaths_and_labels in src/dataset.py, for example as sketched below.
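As a rough sketch (the actual signature of load_filepaths_and_labels in this repository may differ), a version that reads tab-separated labels could look like this:

import os

def load_filepaths_and_labels(data_dir: str, label_file: str):
    # hypothetical example: one "<image name>\t<label>" pair per line
    filepaths, labels = [], []
    with open(label_file, encoding="utf-8") as f:
        for line in f:
            name, label = line.rstrip("\n").split("\t")
            filepaths.append(os.path.join(data_dir, name))
            labels.append(label)
    return filepaths, labels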

To move a random subsample of the training data into the validation directory, the following command can be used:

find train -type f | shuf -n <num of val samples> | xargs -I '{}' mv {} val

 

3. Integrating into other projects

If you want to use the predictions as part of a bigger project, you can use the interface provided by the TrocrPredictor class in src/main.py. For this to work, make sure to run all code as Python modules (python -m ...).

See the following code example:

from PIL import Image
from trocr.src.main import TrocrPredictor

# load images
image_names = ["data/img1.png", "data/img2.png"]
images = [Image.open(img_name) for img_name in image_names]

# directly predict on Pillow Images or on file names
model = TrocrPredictor()
predictions = model.predict_images(images)
predictions = model.predict_for_file_names(image_names)

# print results
for i, file_name in enumerate(image_names):
    print(f'Prediction for {file_name}: {predictions[i]}')

 

4. Adapting the Code

In general, it should be easy to adapt the code for other input formats or use cases.

  • Learning Rate, Batch size, Train Epoch Count, Logging, Word Len: src/configs/constants.py
  • Input Paths, Model Checkpoint Path: src/configs/paths.py
  • Different label format: src/dataset.py : load_filepaths_and_labels

The word len constant is very important: to enable batch training, all labels need to be padded to the same length. Some experimentation might be needed here; for us, padding to a length of 8 worked well.
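For illustration (a sketch, not the exact code in this repository; the checkpoint name is just an example), padding labels to a fixed length with the Hugging Face processor could look like this:

from transformers import TrOCRProcessor

WORD_LEN = 8  # plays the role of the word len constant in src/configs/constants.py

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")

# pad (and truncate) every label to WORD_LEN tokens so batches have a uniform shape
label_ids = processor.tokenizer(
    ["a", "b"], padding="max_length", max_length=WORD_LEN, truncation=True
).input_ids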

If you want to change specifics of the model, you can supply a TrOCRConfig object through the transformers interface. See https://huggingface.co/docs/transformers/model_doc/trocr#transformers.TrOCRConfig for more details.
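As a minimal sketch (the parameter values are only examples, not recommendations), a decoder with a custom configuration could be built like this:

from transformers import TrOCRConfig, TrOCRForCausalLM

# example values only; see the TrOCRConfig documentation for all options
config = TrOCRConfig(d_model=512, decoder_layers=6, decoder_attention_heads=8)
model = TrOCRForCausalLM(config)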

 

5. Contact

If the setup does not work, please let me know in a GitHub issue! Sometimes sub-dependencies update and become incompatible with other dependencies, so the dependency list has to be updated.

Feel free to submit issues with questions about the implementation as well.

For questions about the paper or the architecture, please get in touch with the authors.