
A TensorFlow implementation of mutual information estimation methods

Yet to do

  • Include more architectures for the critic network
  • Add more comments to functions
  • Test the code when X and Z are images
  • Add unit tests

Variational Bounds on Mutual Information

This repo offers a ready-to-use TensorFlow implementation of state-of-the-art variational methods for mutual information estimation. This includes:

  • MINE estimator [1], based on the Donsker-Varadhan representation of the KL divergence
  • NWJ estimator [2], based on the f-divergence variational representation of the KL divergence
  • NCE estimator [3], based on the noise contrastive estimation principle
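
To make the three bounds concrete, here is a minimal pure-Python sketch that evaluates each of them from a K x K matrix of critic scores. It assumes the convention scores[i][j] = T(x_i, z_j), with diagonal entries treated as samples from the joint and off-diagonal entries as samples from the product of marginals. The function names and this matrix convention are illustrative assumptions, not this repo's API.

```python
import math


def _logmeanexp(values):
    """Numerically stable log(mean(exp(values)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))


def _split(scores):
    """Split a K x K score matrix into joint (diagonal) and marginal (off-diagonal) scores."""
    k = len(scores)
    joint = [scores[i][i] for i in range(k)]
    marginal = [scores[i][j] for i in range(k) for j in range(k) if i != j]
    return joint, marginal


def dv_bound(scores):
    """Donsker-Varadhan bound used by MINE: E_joint[T] - log E_marginal[exp(T)]."""
    joint, marginal = _split(scores)
    return sum(joint) / len(joint) - _logmeanexp(marginal)


def nwj_bound(scores):
    """NWJ bound: E_joint[T] - E_marginal[exp(T - 1)]."""
    joint, marginal = _split(scores)
    penalty = sum(math.exp(t - 1.0) for t in marginal) / len(marginal)
    return sum(joint) / len(joint) - penalty


def nce_bound(scores):
    """InfoNCE bound: log K + mean over rows of the log-softmax of the diagonal entry."""
    k = len(scores)
    total = 0.0
    for i in range(k):
        m = max(scores[i])
        log_norm = m + math.log(sum(math.exp(s - m) for s in scores[i]))
        total += scores[i][i] - log_norm
    return math.log(k) + total / k
```

Note that the NCE bound can never exceed log K, so it saturates when the true MI is larger than the log of the batch size, whereas the other two bounds are not capped in this way.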

Getting Started

Defining the estimator

To define an estimator, one needs to provide a few arguments.

    from estimator import Mi_estimator
    my_mi_estimator = Mi_estimator(regu_name = ...,
                                   dim_x = ...,
                                   dim_z = ...,
                                   batch_size= ...,
                                   critic_layers=[256, 256, 256])

Measuring the MI between static data

For simple MI estimation between two random variables:

    import numpy as np
    from estimator import Mi_estimator

    data_size = 10000  # number of samples (illustrative value)
    dim_x = 10
    dim_z = 15
    x_data = np.random.rand(data_size, dim_x)
    z_data = np.random.rand(data_size, dim_z)
    my_mi_estimator = Mi_estimator(regu_name = 'mine',
                                   dim_x = [dim_x],
                                   dim_z = [dim_z],
                                   batch_size= 128,
                                   critic_layers=[256, 256, 256])
    mi_estimate = my_mi_estimator.fit(x_data, z_data)

Using MI as a regularizer in TensorFlow graph

For use in a TensorFlow graph (as a regularization term, for instance):

    import tensorflow as tf
    from estimator import Mi_estimator

    dim = 10
    batch_size = 128
    # tf.placeholder takes the dtype as its first positional argument
    x = tf.placeholder(tf.float32, shape=[batch_size, dim])
    z = tf.placeholder(tf.float32, shape=[batch_size, dim])
    my_mi_estimator = Mi_estimator(regu_name = 'mine',
                                   dim_x = [dim],
                                   dim_z = [dim],
                                   batch_size= batch_size,
                                   critic_layers=[256, 256, 256])
    estimator_train_op, estimator_quantities = my_mi_estimator(x, z)
    # loss_unregularized, beta, optimizer and sess are assumed to be defined elsewhere
    # Define regularized loss
    loss = loss_unregularized + beta * estimator_quantities['mi_for_grads']
    # Define main training op
    main_train_op = optimizer.minimize(loss, var_list=...)
    # At every training iteration:
    # Perform main training operation to optimize loss
    # (feed any other inputs your model needs as well)
    sess.run([main_train_op], feed_dict={x: ..., z: ...})
    # Then update estimator
    sess.run(estimator_train_op, feed_dict={x: ..., z: ...})

End-to-End Examples

We provide code showcasing the two use cases described above.

Estimation of mutual information for static data

For correlated Gaussian random variables, the true mutual information is known in closed form, which makes them a convenient benchmark for the estimators.
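
As an illustration of the closed-form value, assume that X and Z are d-dimensional and that each coordinate pair (X_i, Z_i) is an independent bivariate standard normal with correlation rho. This is a common benchmark setup, and the exact construction used in this example may differ. Under that assumption, I(X;Z) = -(d/2) * log(1 - rho^2) nats. The helper names below are hypothetical:

```python
import math


def gaussian_mi(rho, dim):
    """Closed-form MI (in nats) for `dim` independent pairs with correlation `rho`."""
    if not -1.0 < rho < 1.0:
        raise ValueError("rho must lie strictly between -1 and 1")
    return -0.5 * dim * math.log(1.0 - rho ** 2)


def rho_for_mi(mi, dim):
    """Invert the formula: the correlation that yields a target MI (in nats)."""
    if mi < 0:
        raise ValueError("mutual information cannot be negative")
    return math.sqrt(1.0 - math.exp(-2.0 * mi / dim))
```

The inverse is handy when sweeping over target MI values: pick the MI you want to test, derive rho, and compare each estimator's output against the known value.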

First, go to the example directory:

cd examples/correlated_gaussians/

Then run the tests to check that the code works as expected:

python3 run_test.py

Finally, to run the experiments, check the available options in "demo_gaussian.py" and sweep over any parameters by editing the header of run_exp.py. When that is done, simply run:

python3 run_exp.py

After training, plots are available in /plots/. You should obtain something like:

The plots above present the MI estimated by each of the three methods after 10 epochs of training on a 100k-sample dataset with batch size 128.

Using MI as a regularization term in a loss

The most interesting use case of these bounds is mutual information maximization. A typical example is reducing mode collapse in GANs. There, the mutual information I(X;Z) is used as a proxy for the entropy H(X) of the generator output, where X is the output of the generator and Z the noise vector. Since I(X;Z) = H(X) - H(X|Z), maximizing I(X;Z) encourages a larger H(X). The regularized loss is:

loss_regularized = loss_gan - beta * I(X;Z)

Use adaptive gradient clipping!

Naively differentiating the loss_regularized term defined above will probably not yield the expected result. The most likely reason is that the gradients of I(X;Z) have a much larger scale than those of loss_gan. To keep the gradient scales balanced, one must use the adaptive gradient clipping proposed in [1]. This simple trick scales the MI term so that its gradient norm does not exceed that of the original loss. Concretely, one must instead minimize:

scale = min( ||G(loss_gan)||, ||G(I(X;Z))|| ) / ||G(I(X;Z))||
loss_regularized = loss_gan - beta * scale * I(X;Z)

where G(.) denotes the gradient operator with respect to the specified parameters.
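
The scaling rule can be sketched on plain Python lists of flattened gradients. This is a minimal illustration under two assumptions: the scale is treated as a constant (no gradient flows through it), and a small eps guards against division by zero. The function names are hypothetical and are not part of this repo.

```python
import math


def l2_norm(grads):
    """Euclidean norm of a flat list of gradient values."""
    return math.sqrt(sum(g * g for g in grads))


def adaptive_scale(loss_grads, mi_grads, eps=1e-12):
    """scale = min(||G(loss_gan)||, ||G(I)||) / ||G(I)||, always in [0, 1]."""
    loss_norm = l2_norm(loss_grads)
    mi_norm = l2_norm(mi_grads)
    return min(loss_norm, mi_norm) / (mi_norm + eps)


def combined_gradient(loss_grads, mi_grads, beta):
    """Gradient of loss_gan - beta * scale * I(X;Z), with scale held constant."""
    scale = adaptive_scale(loss_grads, mi_grads)
    return [gl - beta * scale * gm for gl, gm in zip(loss_grads, mi_grads)]
```

When the MI gradient is already smaller than the loss gradient, the scale is (up to eps) equal to 1 and the MI term is left untouched. Otherwise it is shrunk so that both gradients have the same norm.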

We provide a simple 2D example, referred to as the "25 Gaussians" experiment, where the target distribution is:
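
A target of this kind is commonly built as a 5 x 5 grid of narrow isotropic Gaussians. The sketch below samples from such a mixture and counts how many modes a set of points covers, which is a simple way to quantify mode collapse. The grid spacing, standard deviation, and function names are assumptions for illustration and are not taken from this repo.

```python
import math
import random


def grid_centers(spacing=2.0):
    """Centers of a 5 x 5 grid, e.g. coordinates in {-4, -2, 0, 2, 4}."""
    return [(spacing * i, spacing * j) for i in range(-2, 3) for j in range(-2, 3)]


def sample_25_gaussians(n, spacing=2.0, std=0.05, seed=0):
    """Draw n points from a uniform mixture of 25 isotropic Gaussians."""
    rng = random.Random(seed)
    centers = grid_centers(spacing)
    samples = []
    for _ in range(n):
        cx, cy = rng.choice(centers)  # pick a mode uniformly at random
        samples.append((rng.gauss(cx, std), rng.gauss(cy, std)))
    return samples


def modes_covered(points, spacing=2.0):
    """Number of distinct grid modes that are the nearest center of at least one point."""
    centers = grid_centers(spacing)
    hit = set()
    for x, y in points:
        nearest = min(range(len(centers)),
                      key=lambda k: math.hypot(x - centers[k][0], y - centers[k][1]))
        hit.add(nearest)
    return len(hit)
```

Applying modes_covered to generator samples gives a single number to compare a plain GAN against its MI-regularized counterpart.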

With the provided generator and discriminator architectures, a plain GAN produces distributions like:

While the above plot clearly exposes mode collapse, one can easily reduce it by adding an MI regularization term:

To see the code for this example, first go to the example directory:

cd examples/gan/

Then run the tests to check that the code works as expected:

python3 run_test.py

Finally, to run the experiments, check the available options in "demo_gaussian.py" and sweep over any parameters by editing the header of run_exp.py. Then simply run:

python3 run_exp.py


