
Primary language: Python. License: MIT.

TFSnippet


TFSnippet is a set of utilities for writing and testing TensorFlow models.

TFSnippet follows a non-interfering design philosophy. It aims to provide a set of useful utilities that can be used alongside any other TensorFlow libraries and frameworks.

Dependencies

TensorFlow >= 1.5

Installation

Documentation

Examples

Quick Tutorial

To get started, you might import TFSnippet as:

Distributions

When you use TFSnippet distribution classes to draw random samples, you get enhanced tensor objects from which the log-likelihood can be computed simply by calling log_prob().
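The idea behind these enhanced samples can be sketched in plain Python. Each sample object keeps a reference to the distribution that drew it, so its log-likelihood is a single log_prob() call away. The classes `Normal` and `StochasticSample` below are hypothetical stand-ins written only for illustration. They use the standard library rather than TensorFlow, and they do not reproduce TFSnippet's actual classes or signatures.

```python
import math
import random


class StochasticSample:
    """A sample value that keeps a reference to the distribution that drew it."""

    def __init__(self, value, distribution):
        self.value = value
        self.distribution = distribution

    def log_prob(self):
        # Delegate to the originating distribution, so callers never have
        # to carry the distribution and the sample around separately.
        return self.distribution.log_prob(self.value)


class Normal:
    """A univariate normal distribution (hypothetical, for illustration only)."""

    def __init__(self, mean, std):
        if std <= 0:
            raise ValueError("std must be positive")
        self.mean = mean
        self.std = std

    def log_prob(self, x):
        # log N(x | mean, std^2) = -0.5 * z^2 - log(std) - 0.5 * log(2 * pi)
        z = (x - self.mean) / self.std
        return -0.5 * z * z - math.log(self.std) - 0.5 * math.log(2 * math.pi)

    def sample(self, rng=random):
        # Wrap the raw draw so it can report its own log-likelihood.
        return StochasticSample(rng.gauss(self.mean, self.std), self)


normal = Normal(mean=0.0, std=1.0)
x = normal.sample(random.Random(42))
print(x.value, x.log_prob())
```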

Distributions from ZhuSuan can be cast into a TFSnippet distribution class, in case TFSnippet does not yet provide a wrapper for a particular ZhuSuan distribution:
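Casting a foreign distribution is essentially the adapter pattern. The minimal sketch below assumes only that the foreign object exposes `sample(rng)` and `log_prob(x)`. `ExternalNormal` plays the role of a third-party (for example ZhuSuan) distribution. `WrappedDistribution`, `WrappedSample` and `as_distribution` are hypothetical names chosen for this illustration. Real ZhuSuan and TFSnippet classes operate on tensors and have different signatures.

```python
import math
import random


class ExternalNormal:
    """Hypothetical stand-in for a distribution from another library.

    It exposes sample() and log_prob(x), but its samples are bare floats.
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def sample(self, rng):
        return rng.gauss(self.mean, self.std)

    def log_prob(self, x):
        z = (x - self.mean) / self.std
        return -0.5 * z * z - math.log(self.std) - 0.5 * math.log(2 * math.pi)


class WrappedSample:
    """A sample that remembers which wrapped distribution produced it."""

    def __init__(self, value, distribution):
        self.value = value
        self.distribution = distribution

    def log_prob(self):
        return self.distribution.log_prob(self.value)


class WrappedDistribution:
    """Adapter that gives any sample()/log_prob() object enhanced samples."""

    def __init__(self, inner):
        self.inner = inner

    def sample(self, rng):
        # Draw from the wrapped distribution, then attach log_prob() support.
        return WrappedSample(self.inner.sample(rng), self)

    def log_prob(self, x):
        # Delegate the density computation to the wrapped distribution.
        return self.inner.log_prob(x)


def as_distribution(obj):
    """Return obj unchanged if it is already wrapped, otherwise wrap it."""
    if isinstance(obj, WrappedDistribution):
        return obj
    return WrappedDistribution(obj)


dist = as_distribution(ExternalNormal(1.0, 2.0))
s = dist.sample(random.Random(0))
print(s.value, s.log_prob())
```

Because `as_distribution` returns already-wrapped objects unchanged, the cast is idempotent. You can therefore call it on any distribution without first checking its type.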

Data Flows

It is common practice to iterate through a dataset in mini-batches. tfsnippet.DataFlow provides a unified interface for assembling mini-batch iterators.
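To illustrate what such an iterator does, the following sketch uses only the standard library. It yields aligned mini-batches from several equal-length sequences. The function `iterate_minibatches` is hypothetical. Its `batch_size`, `shuffle` and `skip_incomplete` parameters are modelled on the arguments passed to `spt.DataFlow.arrays` in the training example, but this is not TFSnippet's implementation.

```python
import random


def iterate_minibatches(arrays, batch_size, shuffle=False,
                        skip_incomplete=False, rng=None):
    """Yield aligned mini-batches (tuples of lists) from equal-length sequences."""
    if not arrays:
        raise ValueError("at least one array is required")
    n = len(arrays[0])
    if any(len(a) != n for a in arrays):
        raise ValueError("all arrays must have the same length")
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")

    indices = list(range(n))
    if shuffle:
        # Shuffle one shared index list so rows stay aligned across arrays.
        (rng or random).shuffle(indices)

    for start in range(0, n, batch_size):
        batch_idx = indices[start:start + batch_size]
        if skip_incomplete and len(batch_idx) < batch_size:
            break  # drop the trailing, smaller batch
        yield tuple([a[i] for i in batch_idx] for a in arrays)


xs = list(range(10))
ys = [x * x for x in xs]
for bx, by in iterate_minibatches([xs, ys], batch_size=4):
    print(bx, by)
# [0, 1, 2, 3] [0, 1, 4, 9]
# [4, 5, 6, 7] [16, 25, 36, 49]
# [8, 9] [64, 81]
```

The sketch shuffles a single index list rather than each sequence separately. This keeps inputs and labels paired.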

Training

After you have built the model and obtained the training operation, you can quickly run a training loop with the utilities from TFSnippet:

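import tensorflow as tf
import tfsnippet as spt

input_x = ...  # the input x placeholder
input_y = ...  # the input y placeholder
loss = ...  # the training loss
params = tf.trainable_variables()  # the trainable parameters
train_x = ...  # the training inputs (e.g. a NumPy array)
train_y = ...  # the training labels
valid_x = ...  # the validation inputs
valid_y = ...  # the validation labels
max_epoch = ...  # the maximum number of training epochs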

# Use learning-rate annealing: the initial learning rate is 0.001, and it
# is annealed by a factor of 0.99995 after every step.
learning_rate = spt.AnnealingVariable('learning_rate', 0.001, 0.99995)

# Build the training operation with AdamOptimizer
optimizer = tf.train.AdamOptimizer(learning_rate)
train_op = optimizer.minimize(loss, var_list=params)

# Build the training data-flow
train_flow = spt.DataFlow.arrays(
    [train_x, train_y], batch_size=64, shuffle=True, skip_incomplete=True)
# Build the validation data-flow
valid_flow = spt.DataFlow.arrays([valid_x, valid_y], batch_size=256)

with spt.TrainLoop(params, max_epoch=max_epoch, early_stopping=True) as loop:
    trainer = spt.Trainer(loop, train_op, [input_x, input_y], train_flow,
                          metrics={'loss': loss})
    # Anneal the learning-rate after every step by 0.99995.
    trainer.anneal_after_steps(learning_rate, freq=1)
    # Do validation and apply early-stopping after every epoch.
    trainer.evaluate_after_epochs(
        spt.Evaluator(loop, loss, [input_x, input_y], valid_flow),
        freq=1
    )
    # You may log the learning-rate after every epoch by registering an
    # event handler.  You may also register any other handlers.
    trainer.events.on(
        spt.EventKeys.AFTER_EPOCH,
        lambda epoch: trainer.loop.collect_metrics(lr=learning_rate),
    )
    # Print training metrics after every epoch.
    trainer.log_after_epochs(freq=1)
    # Run all the training epochs and steps.
    trainer.run()
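The control flow that `spt.Trainer` automates above can be approximated without TensorFlow. The toy loop below is a sketch, not TFSnippet's implementation, and it rests on these assumptions:

- The hypothetical callables `train_one_step(lr)` and `validate()` stand in for running `train_op` and the evaluator.
- The learning rate is multiplied by `ratio` after every step, so after k steps it equals 0.001 × 0.99995^k.
- After-epoch handlers receive the epoch number and the current learning rate.
- The best validation loss is recorded, which is the bookkeeping that early stopping relies on.

```python
class AfterEpochEvents:
    """Minimal event registry; handlers are called as handler(epoch, lr)."""

    def __init__(self):
        self._handlers = []

    def on(self, handler):
        self._handlers.append(handler)

    def fire(self, epoch, lr):
        for handler in self._handlers:
            handler(epoch, lr)


def run_training(train_one_step, validate, batches_per_epoch, max_epoch,
                 init_lr=0.001, ratio=0.99995, events=None):
    """Toy training loop with per-step learning-rate annealing.

    Returns (best_epoch, best_valid_loss, final_lr).
    """
    lr = init_lr
    best_epoch, best_loss = None, float("inf")
    for epoch in range(1, max_epoch + 1):
        for _ in range(batches_per_epoch):
            train_one_step(lr)
            lr *= ratio  # anneal after every step (freq=1)
        valid_loss = validate()  # evaluate after every epoch (freq=1)
        if valid_loss < best_loss:
            # Early-stopping bookkeeping: remember the best epoch so far.
            best_epoch, best_loss = epoch, valid_loss
        if events is not None:
            events.fire(epoch, lr)
    return best_epoch, best_loss, lr


# Usage: a no-op "step" and a scripted sequence of validation losses.
lr_log = []
events = AfterEpochEvents()
events.on(lambda epoch, lr: lr_log.append((epoch, lr)))
valid_losses = iter([0.9, 0.5, 0.6, 0.4, 0.45])
best_epoch, best_loss, final_lr = run_training(
    train_one_step=lambda lr: None,
    validate=lambda: next(valid_losses),
    batches_per_epoch=100,
    max_epoch=5,
    events=events,
)
print(best_epoch, best_loss)  # 4 0.4
```

The annealed learning rate is updated in place rather than recomputed from the power formula. This mirrors how a multiplicative annealing variable is typically applied once per step.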