ktrain
News and Announcements
- 2019-12-10:
- ktrain v0.7.x is released and now uses TensorFlow Keras (i.e., tf.keras) instead of stand-alone Keras. If you're using custom Keras models with ktrain, you must change all keras references to tensorflow.keras. That is, don't import Keras like this: from keras.layers import Dense. Do this instead: from tensorflow.keras.layers import Dense. If you mix calls to tf.keras with stand-alone Keras, you will experience problems. Supported versions of TensorFlow include 1.14 and 2.0.
- 2019-11-12:
- ktrain v0.6.x is released and includes pre-canned support for learning from unlabeled or partially labeled text data.
- Coming Soon:
- better support for custom data formats and models
- ability to train HuggingFace Transformer models within ktrain
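Regarding the v0.7.x note on imports: the sketch below is a hypothetical helper (to_tf_keras is not part of ktrain) that rewrites stand-alone Keras import lines in a source string to their tf.keras equivalents using regular expressions. It only handles lines that begin with from keras or import keras (optionally followed by as ...); other forms, such as import keras.layers, are deliberately left unchanged and should be fixed by hand.

```python
import re

# Hypothetical migration helper, not part of ktrain.
# "from keras..." -> "from tensorflow.keras..."
_FROM_KERAS = re.compile(r"^([ \t]*)from\s+keras(?=[.\s])", re.MULTILINE)
# "import keras" / "import keras as k" -> "from tensorflow import keras [as k]"
_IMPORT_KERAS = re.compile(r"^([ \t]*)import\s+keras(?=\s|$)", re.MULTILINE)


def to_tf_keras(source: str) -> str:
    """Return source with stand-alone Keras imports pointed at tf.keras."""
    source = _FROM_KERAS.sub(r"\1from tensorflow.keras", source)
    source = _IMPORT_KERAS.sub(r"\1from tensorflow import keras", source)
    return source


if __name__ == "__main__":
    print(to_tf_keras("from keras.layers import Dense\nimport keras\n"))
```

Lines that already import from tensorflow.keras, and packages such as keras_contrib, are not matched, so running the helper twice is harmless.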
Overview
ktrain is a lightweight wrapper for the deep learning library Keras (and other libraries) to help build, train, and deploy neural networks. With only a few lines of code, ktrain allows you to easily and quickly:
- estimate an optimal learning rate for your model given your data using a Learning Rate Finder
- utilize learning rate schedules such as the triangular policy, the 1cycle policy, and SGDR to effectively minimize loss and improve generalization
- employ fast and easy-to-use pre-canned models for text, vision, and graph data:
  - text data:
    - Text Classification: BERT, NBSVM, fastText, GRUs with pretrained word vectors, and other models [example notebook]
    - Sequence Labeling: Bidirectional LSTM-CRF with optional pretrained word embeddings [example notebook]
    - Unsupervised Topic Modeling with LDA [example notebook]
    - Document Similarity with One-Class Learning: given some documents of interest, find and score new documents that are semantically similar to them using One-Class Text Classification [example notebook]
    - Document Recommendation Engine: given text from a sample document, recommend documents that are semantically similar to it from a larger corpus [example notebook]
  - vision data:
    - image classification (e.g., ResNet, Wide ResNet, Inception) [example notebook]
  - graph data:
    - graph node classification with graph neural networks (e.g., GraphSAGE) [example notebook]
- perform multilingual text classification (e.g., Chinese Sentiment Analysis with BERT, Arabic Sentiment Analysis with NBSVM)
- load and preprocess text and image data from a variety of formats
- inspect data points that were misclassified and provide explanations to help improve your model
- leverage a simple prediction API for saving and deploying both models and data-preprocessing steps to make predictions on new raw data
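The Learning Rate Finder mentioned above briefly trains the model while increasing the learning rate, typically exponentially, and records the loss at each step; lr_plot then lets you choose a rate by eye. To make that choice concrete, the sketch below implements one common heuristic (an assumption here, not necessarily what ktrain does): smooth the recorded losses and pick the rate at which the smoothed loss falls most steeply with respect to log(lr). The function names are hypothetical, and the code needs no deep learning framework.

```python
import math
from typing import List, Sequence


def exponential_lrs(start_lr: float, end_lr: float, num_steps: int) -> List[float]:
    """Learning rates growing geometrically from start_lr to end_lr (inclusive)."""
    ratio = (end_lr / start_lr) ** (1.0 / (num_steps - 1))
    return [start_lr * ratio ** i for i in range(num_steps)]


def smooth_losses(losses: Sequence[float], beta: float = 0.98) -> List[float]:
    """Exponential moving average of the losses, with bias correction."""
    avg, smoothed = 0.0, []
    for i, loss in enumerate(losses, start=1):
        avg = beta * avg + (1 - beta) * loss
        smoothed.append(avg / (1 - beta ** i))
    return smoothed


def steepest_lr(lrs: Sequence[float], losses: Sequence[float], beta: float = 0.98) -> float:
    """LR at which the smoothed loss decreases fastest per unit of log(lr)."""
    smoothed = smooth_losses(losses, beta)
    best_i, best_slope = 1, math.inf
    for i in range(1, len(smoothed)):
        slope = (smoothed[i] - smoothed[i - 1]) / (math.log(lrs[i]) - math.log(lrs[i - 1]))
        if slope < best_slope:  # most negative slope = steepest descent
            best_i, best_slope = i, slope
    return lrs[best_i]
```

Setting beta=0 disables smoothing; values closer to 1 smooth more but also make the smoothed curve lag behind the raw losses.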
Tutorials
Please see the following tutorial notebooks for a guide on how to use ktrain on your projects:
- Tutorial 1: Introduction
- Tutorial 2: Tuning Learning Rates
- Tutorial 3: Image Classification
- Tutorial 4: Text Classification
- Tutorial 5: Learning from Unlabeled Text Data
- Tutorial 6: Text Sequence Tagging for Named Entity Recognition
- Tutorial 7: Graph Node Classification with Graph Neural Networks
- Tutorial A1: Additional tricks, which covers topics such as previewing data augmentation schemes, inspecting intermediate output of Keras models for debugging, setting global weight decay, and use of built-in and custom callbacks.
- Tutorial A2: Explaining Predictions and Misclassifications
Some blog tutorials about ktrain are shown below:
ktrain: A Lightweight Wrapper for Keras to Help Train Neural Networks
Examples
Tasks such as text classification and image classification can be accomplished easily with only a few lines of code.
Example: Text Classification of IMDb Movie Reviews Using BERT
import ktrain
from ktrain import text as txt
# load data
(x_train, y_train), (x_test, y_test), preproc = txt.texts_from_folder('data/aclImdb', maxlen=500,
                                                                     preprocess_mode='bert',
                                                                     train_test_names=['train', 'test'],
                                                                     classes=['pos', 'neg'])
# load model
model = txt.text_classifier('bert', (x_train, y_train), preproc=preproc)
# wrap model and data in ktrain.Learner object
learner = ktrain.get_learner(model,
                             train_data=(x_train, y_train),
                             val_data=(x_test, y_test),
                             batch_size=6)
# find good learning rate
learner.lr_find()             # briefly simulate training to find good learning rate
learner.lr_plot()             # visually identify best learning rate
# train using 1cycle learning rate schedule for 3 epochs
learner.fit_onecycle(2e-5, 3)
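In the example above, fit_onecycle(2e-5, 3) trains for 3 epochs, with 2e-5 serving as the peak learning rate. To illustrate the general shape of a 1cycle schedule, here is a framework-free sketch of one common three-phase variant: a linear ramp from max_lr / div_factor up to max_lr, a symmetric ramp back down, and a short final phase that anneals further. The function name, phase lengths, div_factor, and final_div are assumptions for illustration, not necessarily what ktrain uses internally.

```python
def one_cycle_lr(step: int, total_steps: int, max_lr: float,
                 div_factor: float = 10.0, final_div: float = 100.0,
                 pct_up: float = 0.45) -> float:
    """Learning rate at a 0-based step of an assumed three-phase 1cycle schedule."""
    base_lr = max_lr / div_factor
    up = int(total_steps * pct_up)  # steps spent ramping up
    down_end = 2 * up               # ramp back down over the same number of steps
    if step < up:                   # phase 1: base_lr -> max_lr
        return base_lr + (max_lr - base_lr) * step / up
    if step < down_end:             # phase 2: max_lr -> base_lr
        return max_lr - (max_lr - base_lr) * (step - up) / up
    # phase 3: anneal from base_lr toward base_lr / final_div
    tail = max(total_steps - down_end, 1)
    return base_lr - (base_lr - base_lr / final_div) * (step - down_end) / tail


if __name__ == "__main__":
    schedule = [one_cycle_lr(s, 100, 2e-5) for s in range(100)]
    print(f"start={schedule[0]:.2e} peak={max(schedule):.2e} end={schedule[-1]:.2e}")
```

The warm-up lets the model move to a good region before the large rates in the middle of the cycle, and the final low-rate phase lets it settle.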
Example: Classifying Images of Dogs and Cats Using a Pretrained ResNet50 model
import ktrain
from ktrain import vision as vis
# load data
(train_data, val_data, preproc) = vis.images_from_folder(
                                      datadir='data/dogscats',
                                      data_aug=vis.get_data_aug(horizontal_flip=True),
                                      train_test_names=['train', 'valid'],
                                      target_size=(224,224), color_mode='rgb')
# load model
model = vis.image_classifier('pretrained_resnet50', train_data, val_data, freeze_layers=80)
# wrap model and data in ktrain.Learner object
learner = ktrain.get_learner(model=model, train_data=train_data, val_data=val_data,
                             workers=8, use_multiprocessing=False, batch_size=64)
# find good learning rate
learner.lr_find()             # briefly simulate training to find good learning rate
learner.lr_plot()             # visually identify best learning rate
# train using triangular policy with ModelCheckpoint and implicit ReduceLROnPlateau and EarlyStopping
learner.autofit(1e-4, checkpoint_folder='/tmp/saved_weights')
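The autofit call above trains with the triangular learning rate policy and, as its comment notes, implicitly applies ReduceLROnPlateau and EarlyStopping. The sketch below illustrates both ideas in plain Python: triangular_lr is the triangular cyclical learning rate formula from Leslie Smith's work on cyclical learning rates, and PlateauTracker is a hypothetical, simplified stand-in for the plateau logic. The patience values and reduction factor are assumptions, not ktrain's defaults.

```python
import math


def triangular_lr(iteration: int, step_size: int, base_lr: float, max_lr: float) -> float:
    """Rise linearly from base_lr to max_lr over step_size iterations,
    fall back over the next step_size iterations, then repeat."""
    cycle = math.floor(1 + iteration / (2 * step_size))
    x = abs(iteration / step_size - 2 * cycle + 1)
    return base_lr + (max_lr - base_lr) * max(0.0, 1 - x)


class PlateauTracker:
    """Simplified, hypothetical stand-in for ReduceLROnPlateau + EarlyStopping."""

    def __init__(self, max_lr: float, reduce_patience: int = 2,
                 stop_patience: int = 5, factor: float = 0.5):
        self.max_lr = max_lr
        self.reduce_patience = reduce_patience  # stale epochs before shrinking max_lr
        self.stop_patience = stop_patience      # stale epochs before stopping
        self.factor = factor
        self.best = math.inf
        self.wait = 0                           # epochs since the last improvement

    def update(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return False when training should stop."""
        if val_loss < self.best:
            self.best, self.wait = val_loss, 0
            return True
        self.wait += 1
        if self.wait % self.reduce_patience == 0:
            self.max_lr *= self.factor
        return self.wait < self.stop_patience
```

In a training loop, each epoch would use triangular_lr with the tracker's current max_lr, then call update with the epoch's validation loss and stop once it returns False.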
Example: Sequence Labeling for Named Entity Recognition using a randomly initialized Bidirectional LSTM CRF model
import ktrain
from ktrain import text as txt
# load data
(trn, val, preproc) = txt.entities_from_txt('data/ner_dataset.csv',
                                            sentence_column='Sentence #',
                                            word_column='Word',
                                            tag_column='Tag',
                                            data_format='gmb')
# load model
model = txt.sequence_tagger('bilstm-crf', preproc)
# wrap model and data in ktrain.Learner object
learner = ktrain.get_learner(model, train_data=trn, val_data=val)
# conventional training for 1 epoch using a learning rate of 0.001 (Keras default for Adam optimizer)
learner.fit(1e-3, 1)
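The data_format='gmb' option expects a CSV file with one token per row. To illustrate that layout, the sketch below assumes (as in the commonly distributed GMB-derived ner_dataset.csv) that the sentence column is filled only on the first token of each sentence and left empty on the rest. group_gmb_lines is a hypothetical helper, not a ktrain function, that regroups such rows into a (words, tags) pair per sentence; its default column names mirror the arguments passed to entities_from_txt above.

```python
import csv
from typing import Iterable, List, Tuple

Sentence = Tuple[List[str], List[str]]  # (words, tags)


def group_gmb_lines(lines: Iterable[str], sentence_column: str = "Sentence #",
                    word_column: str = "Word", tag_column: str = "Tag") -> List[Sentence]:
    """Group GMB-style CSV lines (header first) into per-sentence (words, tags) pairs."""
    sentences: List[Sentence] = []
    for row in csv.DictReader(lines):
        if row[sentence_column].strip():  # non-empty marker: a new sentence starts here
            sentences.append(([], []))
        if not sentences:                 # skip stray tokens before the first marker
            continue
        words, tags = sentences[-1]
        words.append(row[word_column])
        tags.append(row[tag_column])
    return sentences
```

An open file object can be passed directly as lines (opened with newline=""); depending on the file, you may also need to pass an explicit encoding to open.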
Example: Node Classification on Cora Citation Graph using a GraphSAGE model
import ktrain
from ktrain import graph as gr
# load data with supervision ratio of 10%
(trn, val, preproc) = gr.graph_nodes_from_csv(
                          'cora.content',  # node attributes/labels
                          'cora.cites',    # edge list
                          sample_size=20,
                          holdout_pct=None,
                          holdout_for_inductive=False,
                          train_pct=0.1, sep='\t')
# load model
model = gr.graph_node_classifier('graphsage', trn)
# wrap model and data in ktrain.Learner object
learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=64)
# find good learning rate
learner.lr_find(max_epochs=100)  # briefly simulate training to find good learning rate
learner.lr_plot()                # visually identify best learning rate
# train using triangular policy with ModelCheckpoint and implicit ReduceLROnPlateau and EarlyStopping
learner.autofit(0.01, checkpoint_folder='/tmp/saved_weights')
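To make the inputs to graph_nodes_from_csv more concrete, the sketch below parses the two Cora files, assuming the usual layout: each line of cora.content holds a node ID, its binary word attributes, and its class label, and each line of cora.cites holds a pair of node IDs. It then selects a train_pct fraction of nodes for training (mirroring the 10% supervision ratio above) and draws a fixed-size neighbor sample in the spirit of GraphSAGE's sample_size. parse_cora, split_nodes, and sample_neighbors are hypothetical helpers; treating citations as undirected edges and sampling uniformly with replacement are assumptions, not a description of what ktrain or stellargraph do internally.

```python
import random
from typing import Dict, Iterable, List, Set, Tuple


def parse_cora(content_lines: Iterable[str], cites_lines: Iterable[str], sep: str = "\t"):
    """Return (features, labels, adjacency) dicts keyed by node ID."""
    features: Dict[str, List[int]] = {}
    labels: Dict[str, str] = {}
    for line in content_lines:
        if not line.strip():
            continue
        parts = line.rstrip("\n").split(sep)
        node, attrs, label = parts[0], parts[1:-1], parts[-1]
        features[node] = [int(a) for a in attrs]
        labels[node] = label
    adjacency: Dict[str, Set[str]] = {node: set() for node in features}
    for line in cites_lines:
        if not line.strip():
            continue
        cited, citing = line.rstrip("\n").split(sep)
        if cited in adjacency and citing in adjacency:  # ignore edges to unknown nodes
            adjacency[cited].add(citing)                # assumption: citations are
            adjacency[citing].add(cited)                # treated as undirected edges
    return features, labels, adjacency


def split_nodes(nodes: Iterable[str], train_pct: float = 0.1,
                seed: int = 42) -> Tuple[List[str], List[str]]:
    """Randomly select a train_pct fraction of nodes for training; the rest validate."""
    shuffled = sorted(nodes)  # sort first so the split is reproducible
    random.Random(seed).shuffle(shuffled)
    n_train = round(len(shuffled) * train_pct)
    return shuffled[:n_train], shuffled[n_train:]


def sample_neighbors(adjacency: Dict[str, Set[str]], node: str,
                     sample_size: int = 20, seed: int = 0) -> List[str]:
    """Fixed-size uniform neighbor sample, drawn with replacement."""
    neighbors = sorted(adjacency[node])
    if not neighbors:
        return []
    rng = random.Random(seed)
    return [rng.choice(neighbors) for _ in range(sample_size)]
```

A fixed sample size keeps the per-node cost of aggregating neighbor features constant, regardless of how many papers a node cites or is cited by.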
Using ktrain on Google Colab? See this simple demo of Multiclass Text Classification with BERT.
Additional examples can be found here.
Installation
- Make sure pip is up-to-date:
pip3 install -U pip
- Ensure TensorFlow 1.14 or TensorFlow 2 is installed if it is not already:
For GPU:
pip3 install "tensorflow_gpu>=1.14,<=2"
For CPU:
pip3 install "tensorflow>=1.14,<=2"
- Install ktrain:
pip3 install ktrain
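Before installing ktrain, you may want to confirm which TensorFlow version is present. The sketch below (assuming Python 3.8+ for importlib.metadata) looks up the installed tensorflow or tensorflow-gpu distribution and checks its major.minor version against the versions this README lists as supported (1.14 and 2.0). The helper names are hypothetical.

```python
from importlib import metadata  # standard library in Python 3.8+
from typing import Optional

# Versions listed as supported in this README, as (major, minor).
SUPPORTED_TF = {(1, 14), (2, 0)}


def is_supported_tf(version: str) -> bool:
    """True if a version string such as '1.14.0' or '2.0.0rc1' is 1.14.x or 2.0.x."""
    parts = version.split(".")
    try:
        major_minor = (int(parts[0]), int(parts[1]))
    except (IndexError, ValueError):
        return False  # unrecognized version string
    return major_minor in SUPPORTED_TF


def installed_tf_version() -> Optional[str]:
    """Version of the installed TensorFlow distribution, or None if absent."""
    for dist in ("tensorflow", "tensorflow-gpu"):
        try:
            return metadata.version(dist)
        except metadata.PackageNotFoundError:
            continue
    return None


if __name__ == "__main__":
    version = installed_tf_version()
    if version is None:
        print("TensorFlow is not installed")
    else:
        print(version, "is supported" if is_supported_tf(version) else "is not in the supported list")
```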
Some things to note:
- The ktrain package can be used with either TensorFlow 2.0 or TensorFlow 1.14. If using TensorFlow 2.0, ktrain presently runs in 1.x mode using tf.compat.v1.disable_v2_behavior. In the future, this 1.x mode will be removed and only TensorFlow 2 will be supported.
- Since some ktrain dependencies have not yet been migrated to tf.keras in TensorFlow 2 (or may have other issues), ktrain is temporarily using forked versions of some libraries. Specifically, ktrain uses forked versions of eli5 and stellargraph. If these are not installed, ktrain will complain when a method or function needing either of these libraries is invoked. To install these forked versions, you can do the following:
pip3 install git+https://github.com/amaiya/eli5@tfkeras_0_10_1
pip3 install git+https://github.com/amaiya/stellargraph@no_tf_dep_082
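To find out ahead of time whether these optional libraries are present, the hypothetical helper below uses importlib.util.find_spec to check whether eli5 and stellargraph can be found on the import path and, for any that are missing, prints the corresponding install commands listed above. Note that it only detects whether a package is installed, not whether the installed copy is the forked version.

```python
import importlib.util
from typing import Dict, Iterable, List

# Forked packages and their install targets, taken from the commands above.
FORKS: Dict[str, str] = {
    "eli5": "git+https://github.com/amaiya/eli5@tfkeras_0_10_1",
    "stellargraph": "git+https://github.com/amaiya/stellargraph@no_tf_dep_082",
}


def missing_modules(names: Iterable[str]) -> List[str]:
    """Return the top-level module names that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def install_commands(forks: Dict[str, str] = FORKS) -> List[str]:
    """pip3 commands for every fork whose package is not installed."""
    return [f"pip3 install {forks[name]}" for name in missing_modules(forks)]


if __name__ == "__main__":
    for cmd in install_commands():
        print(cmd)
```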
This code was tested on Ubuntu 18.04 LTS using TensorFlow 1.14 and TensorFlow 2 (Keras version 2.2.4-tf).
Creator: Arun S. Maiya
Email: arun [at] maiya [dot] net