PyTorch implementation and comparison of speech-to-text (STT) models.
References:
- Deep Speech: Scaling up end-to-end speech recognition
- Wav2Letter (WIP)
- Jasper: An End-to-End Convolutional Neural Acoustic Model (WIP)
Example training run on LibriSpeech:
python main.py \
--num-workers 0 \
--batch-size 32 \
--train-data-urls train-clean-100 train-clean-360 \
--num-epochs 15 \
--window-stride 20 \
--optimizer adam \
--learning-rate 3e-4 \
--log-steps 100 \
--checkpoint test
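`--train-data-urls` takes one or more LibriSpeech subset names. The following is a minimal sketch of how such a command line can be parsed with `argparse`. The parser is hypothetical: the real `main.py` may define its flags, defaults and types differently, and only the flags used in the commands of this README are covered.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser mirroring the flags in the example commands;
    # defaults and types are assumptions, not taken from main.py.
    parser = argparse.ArgumentParser(description="Train an STT model (sketch)")
    parser.add_argument("--num-workers", type=int, default=0)
    parser.add_argument("--batch-size", type=int, default=32)
    # nargs="+" collects every following value until the next flag.
    parser.add_argument("--train-data-urls", nargs="+", default=["train-clean-100"])
    parser.add_argument("--val-data-urls", nargs="+", default=["dev-clean"])
    parser.add_argument("--num-epochs", type=int, default=1)
    parser.add_argument("--window-stride", type=int, default=20)  # unit assumed
    parser.add_argument("--optimizer", type=str, default="adam")
    parser.add_argument("--learning-rate", type=float, default=3e-4)
    parser.add_argument("--log-steps", type=int, default=100)
    parser.add_argument("--checkpoint", type=str, default=None)
    parser.add_argument("--use-tpu", type=int, default=0)
    parser.add_argument("--world-size", type=int, default=1)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(
        "--num-workers 0 --batch-size 32 --train-data-urls train-clean-100 "
        "train-clean-360 --num-epochs 15 --window-stride 20 --optimizer adam "
        "--learning-rate 3e-4 --log-steps 100 --checkpoint test".split()
    )
    # Dashes in flag names become underscores in attribute names.
    print(args.train_data_urls)  # ['train-clean-100', 'train-clean-360']
```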
Results:
- trained on train-clean-100 and train-clean-360
- WER on dev-clean (9 epochs): 0.33
- pre-trained weights: https://github.com/discort/stt_models/releases/tag/0.1
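WER (word error rate) is the word-level edit distance between hypothesis and reference, counting substitutions, deletions and insertions, divided by the number of reference words. The corpus-level variant sums edits and reference words over all utterances. Below is a minimal pure-Python sketch of that definition. It is not necessarily how the repository computes its 0.33; for example, it may average per-utterance WER instead.

```python
from typing import Iterable, Sequence, Tuple


def word_edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Levenshtein distance over words (substitution, deletion, insertion cost 1)."""
    # prev[j] holds the distance between the previous ref prefix and hyp[:j].
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            cur[j] = min(
                prev[j] + 1,             # delete reference word r
                cur[j - 1] + 1,          # insert hypothesis word h
                prev[j - 1] + (r != h),  # substitute (cost 0 on a match)
            )
        prev = cur
    return prev[-1]


def corpus_wer(pairs: Iterable[Tuple[str, str]]) -> float:
    """Corpus-level WER: total word edits divided by total reference words."""
    edits = words = 0
    for ref, hyp in pairs:
        ref_words, hyp_words = ref.split(), hyp.split()
        edits += word_edit_distance(ref_words, hyp_words)
        words += len(ref_words)
    return edits / words if words else 0.0


if __name__ == "__main__":
    # One substitution (b -> x) and one deletion (d) over 7 reference words.
    print(corpus_wer([("the cat sat", "the cat sat"), ("a b c d", "a x c")]))
    # 0.2857142857142857
```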
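Notes on TPU training:
- TPU needs to recompile the RNN graph for each training example (presumably because utterance lengths vary and XLA compiles a new graph for every new input shape)
- As of March 2021, CTCLoss is not supported on XLA (source, and here)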
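XLA compiles one graph per distinct input shape, so variable-length utterances can trigger a compilation for nearly every batch. One common mitigation, which the repository is not stated to use, is to pad each batch's time dimension up to a fixed bucket size so that only a handful of shapes ever reach the TPU. The sketch below counts distinct padded lengths with and without bucketing. The helper names, the bucket size of 200 frames and the random lengths are illustrative assumptions.

```python
import math
import random
from typing import List


def padded_length(max_len: int, bucket: int) -> int:
    """Round a batch's longest sequence up to the next multiple of `bucket`."""
    return bucket * math.ceil(max_len / bucket)


def distinct_shapes(lengths: List[int], batch_size: int, bucket: int = 1) -> int:
    """Number of distinct padded time dimensions across consecutive batches.

    bucket=1 means "pad to the batch maximum", i.e. no bucketing.
    Each distinct value would correspond to one XLA compilation.
    """
    shapes = set()
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start:start + batch_size]
        shapes.add(padded_length(max(batch), bucket))
    return len(shapes)


if __name__ == "__main__":
    random.seed(0)
    # Hypothetical spectrogram lengths in frames (not LibriSpeech statistics).
    lengths = [random.randint(100, 1600) for _ in range(10_000)]
    print("no bucketing:", distinct_shapes(lengths, batch_size=32))
    print("bucket=200:  ", distinct_shapes(lengths, batch_size=32, bucket=200))
```

Larger buckets mean fewer compilations but more wasted computation on padding frames.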
Create a Google Cloud project (the SDK archive below is the macOS x86_64 build; download the one for your platform):
wget https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-sdk-332.0.0-darwin-x86_64.tar.gz
tar -xf google-cloud-sdk-332.0.0-darwin-x86_64.tar.gz
./google-cloud-sdk/install.sh
gcloud projects create <PROJECT_ID>
gcloud projects list                 # confirm the project exists
gcloud projects delete <PROJECT_ID>  # only when you no longer need it
Turn on the Cloud TPU API for that project.
gcloud config set compute/zone <your-zone-here> --project <PROJECT_ID>
Set up a Compute Engine instance
Open Cloud Shell and run the following commands there: https://console.cloud.google.com/?cloudshell=true
export PROJECT_ID=<project-id>
gcloud config set project ${PROJECT_ID}
gcloud compute instances create deepspeech-xla \
--zone=europe-west4-a \
--machine-type=n1-highmem-16 \
--image-family=torch-xla \
--image-project=ml-images \
--boot-disk-size=300GB \
--scopes=https://www.googleapis.com/auth/cloud-platform
gcloud compute ssh deepspeech-xla --zone=europe-west4-a
Create a Cloud TPU (v3-8, PyTorch 1.8) in the same zone:
gcloud compute tpus create deepspeech-xla \
--zone=europe-west4-a \
--network=default \
--version=pytorch-1.8 \
--accelerator-type=v3-8
Look up the TPU's internal IP address (shown in the NETWORK_ENDPOINTS column):
gcloud compute tpus list --zone=europe-west4-a
On the VM, activate the PyTorch/XLA environment and point it at the TPU:
conda activate torch-xla-1.8
export TPU_IP_ADDRESS=<ip-address>
export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470"
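`XRT_TPU_CONFIG` tells the PyTorch/XLA 1.8 runtime where the TPU worker lives, in the form `tpu_worker;0;<ip>:8470`. If you prefer to set it from Python instead of the shell, below is a minimal sketch. The helper name, the IP validation and the placeholder fallback address are illustrative assumptions, not part of the repository. To be safe, run it before importing `torch_xla`, since the runtime presumably reads the variable when it initialises.

```python
import ipaddress
import os

XRT_PORT = 8470  # port used in the export command above


def xrt_tpu_config(tpu_ip: str, port: int = XRT_PORT) -> str:
    """Build the XRT_TPU_CONFIG value for a single TPU worker (task index 0)."""
    ipaddress.ip_address(tpu_ip)  # raises ValueError on a malformed address
    return f"tpu_worker;0;{tpu_ip}:{port}"


if __name__ == "__main__":
    # Reads the same TPU_IP_ADDRESS variable exported above;
    # "10.0.0.2" is only a placeholder fallback for illustration.
    ip = os.environ.get("TPU_IP_ADDRESS", "10.0.0.2")
    os.environ["XRT_TPU_CONFIG"] = xrt_tpu_config(ip)
    print(os.environ["XRT_TPU_CONFIG"])
```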
Clone the repository, install the dependencies and start training:
git clone https://github.com/discort/stt_models.git
cd stt_models
pip install -r requirements.txt
python main.py \
--use-tpu 1 \
--world-size 1 \
--num-workers 0 \
--batch-size 128 \
--train-data-urls train-clean-100 train-clean-360 train-other-500 \
--val-data-urls dev-clean \
--num-epochs 1
Clean up when done (delete the TPU and the VM to stop billing):
gcloud compute tpus list --zone=europe-west4-a
gcloud compute instances list
gcloud compute tpus delete deepspeech-xla --zone=europe-west4-a
gcloud compute instances delete deepspeech-xla --zone=europe-west4-a