/bertax_training

Training scripts for BERTax


BERTax training utilities

This repository contains utilities for pre-training and fine-tuning BERTax models, as well as various utility functions and scripts used in the development of BERTax.

In addition to training BERTax on genomic DNA sequences as described below, this repository also includes development scripts for training on gene sequences.

Training new BERTax models

Data preparation

To train a new BERTax model, the user must provide one of the following three data structures.

multi-fastas with TaxIDs

For training any model, you can use multi-fasta files containing the sequences of your classes of interest, one file per class:

[class_1].fa
[class_2].fa
...
[class_n].fa

Each fasta header consists of the TaxID of the species the sequence originates from, followed by a running index.

Example fasta file (you can find our pretraining fasta files here):

>380669 0
TCGAGATACCAGATGGAAATCCTCCAGAGGTATTATCGGAA
>264076 1
GCAGACGAGTTCACCACTGCTGCAGGAAAAGAT
>218387 2
AACTATGCATAGGGCCTTTGCCGGCACTAT
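The header layout can be read back with a few lines of plain Python. This is a minimal sketch, not part of the repository: `read_taxid_fasta` is a hypothetical helper, and it assumes every header holds a TaxID followed by the index, separated by whitespace, as in the example above.

```python
import io
from typing import Iterator, List, Optional, TextIO, Tuple


def read_taxid_fasta(handle: TextIO) -> Iterator[Tuple[int, int, str]]:
    """Yield (taxid, index, sequence) for every record of a multi-fasta file
    whose headers look like '>TAXID INDEX' (assumed layout)."""
    taxid: Optional[int] = None
    index: Optional[int] = None
    chunks: List[str] = []
    for raw in handle:
        line = raw.strip()
        if not line:
            continue  # ignore blank lines
        if line.startswith(">"):
            if taxid is not None:  # emit the previous record
                yield taxid, index, "".join(chunks)
            taxid_str, index_str = line[1:].split()[:2]
            taxid, index = int(taxid_str), int(index_str)
            chunks = []
        else:
            chunks.append(line)  # a sequence may span several lines
    if taxid is not None:  # emit the last record
        yield taxid, index, "".join(chunks)


example = """>380669 0
TCGAGATACCAGATGGAAATCCTCCAGAGGTATTATCGGAA
>264076 1
GCAGACGAGTTCACCACTGCTGCAGGAAAAGAT
>218387 2
AACTATGCATAGGGCCTTTGCCGGCACTAT
"""

records = list(read_taxid_fasta(io.StringIO(example)))
for taxid, index, seq in records:
    print(taxid, index, len(seq))
```

With Biopython (already a dependency), `SeqIO.parse(handle, "fasta")` yields records whose `description` holds the same header text, which could be split in the same way.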

After generating these files, you can convert them into the training formats described below with the following scripts:

preprocessing/fasta2fragments.py / preprocessing/fragments2fasta.py
convert between multi-fasta and json training files
preprocessing/genome_db.py, preprocessing/genome_mince.py
For splitting whole genomes into smaller fragments
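To illustrate what splitting a genome into fragments means, the hypothetical helper below cuts a sequence into consecutive, non-overlapping pieces of a fixed length. It is only a simplified sketch under that assumption; the fragment length and sampling strategy actually used by genome_mince.py may differ, so check the script itself.

```python
from typing import List


def split_into_fragments(genome: str, fragment_len: int, min_len: int = 1) -> List[str]:
    """Cut a genome into consecutive, non-overlapping fragments of fragment_len
    nucleotides. Hypothetical illustration only, not the genome_mince.py logic."""
    if fragment_len <= 0:
        raise ValueError("fragment_len must be positive")
    fragments = []
    for start in range(0, len(genome), fragment_len):
        fragment = genome[start:start + fragment_len]
        if len(fragment) >= min_len:  # optionally drop a short trailing piece
            fragments.append(fragment)
    return fragments


print(split_into_fragments("ACGTACGTAC", fragment_len=4))  # ['ACGT', 'ACGT', 'AC']
print(split_into_fragments("ACGTACGTAC", fragment_len=4, min_len=4))  # ['ACGT', 'ACGT']
```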

fragments directories

For training the normal, genomic DNA-based models, a fixed directory structure with one json file and one txt file per class is required:

[class_1]_fragments.json
[class_1]_species_picked.txt
[class_2]_fragments.json
[class_2]_species_picked.txt
...
[class_n]_fragments.json
[class_n]_species_picked.txt

The json files must consist of a simple list of sequences. Example json file:

["ACGTACGTACGATCGA", "TACACTTTTTA", ..., "ATACTATCTATCTA"]

Each txt file is an ordered list of the corresponding TaxIDs, one per line: the n-th TaxID gives the taxonomic origin of the n-th sequence in the json file with the same prefix.

Example txt file:

380669
264076
218387
11569
204873
346884
565995
11318
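The pairing between the two files can be sketched as follows. `write_class_fragments` and `read_class_fragments` are hypothetical helpers, not functions of this repository; they assume a plain JSON list of strings and one TaxID per line in the txt file, as described above, and the read step checks that both files have the same length.

```python
import json
import tempfile
from pathlib import Path
from typing import Iterable, List, Tuple


def write_class_fragments(out_dir: Path, class_name: str,
                          records: Iterable[Tuple[int, str]]) -> None:
    """Write [class]_fragments.json and [class]_species_picked.txt so that
    line i of the txt file holds the TaxID of sequence i in the json list."""
    records = list(records)
    out_dir.mkdir(parents=True, exist_ok=True)
    with open(out_dir / f"{class_name}_fragments.json", "w") as f:
        json.dump([seq for _, seq in records], f)
    with open(out_dir / f"{class_name}_species_picked.txt", "w") as f:
        f.writelines(f"{taxid}\n" for taxid, _ in records)


def read_class_fragments(out_dir: Path, class_name: str) -> List[Tuple[int, str]]:
    """Load both files again and pair each sequence with its TaxID."""
    with open(out_dir / f"{class_name}_fragments.json") as f:
        sequences = json.load(f)
    with open(out_dir / f"{class_name}_species_picked.txt") as f:
        taxids = [int(line) for line in f if line.strip()]
    if len(sequences) != len(taxids):
        raise ValueError(f"{class_name}: {len(sequences)} sequences but {len(taxids)} TaxIDs")
    return list(zip(taxids, sequences))


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    write_class_fragments(root, "class_1", [(11569, "ACGT"), (11318, "TTGA")])
    pairs = read_class_fragments(root, "class_1")
print(pairs)
```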

gene model training directories

The gene models were used in an early stage of BERTax development, where a different directory structure was required:

Each sequence is stored in its own fasta file inside its class directory. Additionally, a json file containing all file names and their associated classes can speed up preprocessing significantly.

[class_1]/
  [sequence_1.fa]
  [sequence_2.fa]
  ...
  [sequence_n.fa]
[class_2]/
  ...
.../
[class_n]/
  ...
  [sequence_l.fa]
{files.json}

The json file contains a list of two lists of equal length: the first contains the paths to the fasta files, the second the associated classes:

[["class_1/sequence1.fa", "class_1/sequence2.fa", ..., "class_n/sequence_l.fa"],
 ["class_1", "class_1", ..., "class_n"]]
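A files.json cache in this layout could be generated with a short script like the one below. It is a sketch under assumptions: `build_files_json` is a hypothetical helper, the fasta files are assumed to end in `.fa`, and paths are stored relative to the root directory as in the example above. Whether bert_pretrain.py resolves them relative to --root_fa_dir should be checked in the script.

```python
import json
import tempfile
from pathlib import Path
from typing import List


def build_files_json(root: Path, suffix: str = ".fa") -> List[List[str]]:
    """Collect [[file paths], [classes]] for a root with one sub-directory per class."""
    paths: List[str] = []
    classes: List[str] = []
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        for fasta in sorted(class_dir.glob(f"*{suffix}")):
            paths.append(fasta.relative_to(root).as_posix())  # e.g. "class_1/sequence_1.fa"
            classes.append(class_dir.name)  # the directory name is the class label
    return [paths, classes]


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for rel in ["class_1/sequence_1.fa", "class_1/sequence_2.fa", "class_2/sequence_1.fa"]:
        (root / rel).parent.mkdir(parents=True, exist_ok=True)
        (root / rel).write_text(">seq\nACGT\n")
    files = build_files_json(root)
    (root / "files.json").write_text(json.dumps(files))
print(files)
```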

Training process

The normal, genomic DNA-based model can be pre-trained with models/bert_nc.py and fine-tuned with models/bert_nc_finetune.py.

For example, the BERTax model was pre-trained with:

python -m models.bert_nc fragments_root_dir --batch_size 32 --head_num 5 \
       --transformer_num 12 --embed_dim 250 --feed_forward_dim 1024 --dropout_rate 0.05 \
       --name bert_nc_C2 --epochs 10

and fine-tuned with:

python -m models.bert_nc_finetune bert_nc_C2.h5 fragments_root_dir --multi_tax \
       --epochs 15 --batch_size 24 --save_name _small_trainingset_filtered_fix_classes_selection \
       --store_predictions --nr_seqs 1000000000
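When pre-training and fine-tuning are chained, the two commands above can also be assembled from Python. In this sketch `build_command` and `run` are hypothetical helpers; the options are exactly those of the example commands, and it assumes, as the example suggests, that pre-training with --name bert_nc_C2 writes bert_nc_C2.h5 to the working directory.

```python
import subprocess
import sys
from typing import Dict, List, Sequence, Union

OptionValue = Union[str, int, float, bool, None]


def build_command(module: str, positional: Sequence[str],
                  options: Dict[str, OptionValue]) -> List[str]:
    """Build a 'python -m <module> ...' argument list (sys.executable reuses
    the current interpreter). True becomes a bare flag; False/None are skipped."""
    cmd = [sys.executable, "-m", module, *positional]
    for key, value in options.items():
        if value is True:
            cmd.append(f"--{key}")
        elif value is not False and value is not None:
            cmd.extend([f"--{key}", str(value)])
    return cmd


def run(commands: Sequence[List[str]]) -> None:
    """Run the commands one after another; check=True aborts on the first failure."""
    for cmd in commands:
        subprocess.run(cmd, check=True)


name = "bert_nc_C2"
pretrain = build_command("models.bert_nc", ["fragments_root_dir"], {
    "batch_size": 32, "head_num": 5, "transformer_num": 12, "embed_dim": 250,
    "feed_forward_dim": 1024, "dropout_rate": 0.05, "name": name, "epochs": 10,
})
# Assumption: pre-training with --name bert_nc_C2 produces bert_nc_C2.h5.
finetune = build_command("models.bert_nc_finetune", [f"{name}.h5", "fragments_root_dir"], {
    "multi_tax": True, "epochs": 15, "batch_size": 24,
    "save_name": "_small_trainingset_filtered_fix_classes_selection",
    "store_predictions": True, "nr_seqs": 1000000000,
})
```

Calling `run([pretrain, finetune])` from the repository root would execute both steps in order, stopping if pre-training fails.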

The development gene models can be pre-trained with models/bert_pretrain.py:

python -m models.bert_pretrain bert_gene_C2 --epochs 10 --batch_size 32 --seq_len 502 \
	 --head_num 5 --embed_dim 250 --feed_forward_dim 1024 --dropout_rate 0.05 \
	 --root_fa_dir sequences --from_cache sequences/files.json

and fine-tuned with models/bert_finetune.py:

python -m models.bert_finetune bert_gene_C2_trained.h5 --epochs 4 \
	 --root_fa_dir sequences --from_cache sequences/files.json

All training scripts can be called with the --help flag to list the parameters they accept.

Using BERT models

It is recommended to use fine-tuned models with the BERTax tool, passing them via its --custom_model_file parameter.

However, a much more minimal script for predicting the sequences of a multi-fasta file with a trained model is also available in this repository:

python -m utils.test_bert finetuned_bert.h5 --fasta sequences.fa

Benchmarking

If the user needs a predefined training and test set, for example for benchmarking different approaches, one can be created with:

python -m preprocessing.make_dataset single_sequences_json_folder/ out_folder/ --unbalanced

This creates the files test.tsv, train.tsv and classes.pkl, which can be used by bert_nc_finetune:

python -m models.bert_nc_finetune bert_nc_trained.h5 make_dataset_out_folder/ --unbalanced --use_defined_train_test_set

If fasta files are needed, e.g., for competing methods, train.tsv and test.tsv can be converted to fasta via:

python -m preprocessing.dataset2fasta make_dataset_out_folder/

Additional scripts

preprocessing/fasta2fragments.py / preprocessing/fragments2fasta.py
convert between multi-fasta and json training files
preprocessing/genome_db.py, preprocessing/genome_mince.py
scripts used to generate genomic fragments for training

Dependencies

  • tensorflow >= 2
  • keras
  • numpy
  • tqdm
  • scikit-learn
  • keras-bert
  • biopython