secondmind-labs/GPflux

GPflux for text classification?


Hey, many thanks for this project!
I am currently investigating GPs for binary (and one-class) classification tasks and have run some initial experiments using pre-trained sentence embeddings as feature representations, PCA for dimensionality reduction, and GPs (GPflow) for classification.
Using a text embedding, some dense layers and a GP in an end-to-end fashion sounds promising.
At first glance, GPflux seems to offer this. After checking the GPflux tutorials (Hybrid Deep GP models), however, I am not sure how to define the inducing variables. It seems like they have to cover the expected data range in each latent-space dimension, right? Furthermore, I am not sure whether GPflux offers variational inference for binary classification. Any comments, suggestions, or links that could help to build hybrid models are appreciated. Many thanks!
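
For context, here is a rough sketch of the current (pre-GPflux) setup described above; the data names (texts, labels) are placeholders and this is not my exact code:

import numpy as np
import gpflow
import tensorflow_hub as hub
from sklearn.decomposition import PCA

# texts: list of N raw strings; labels: (N, 1) float array of 0/1 values (assumed given)
embed = hub.load("https://tfhub.dev/google/universal-sentence-encoder-large/5")
features = np.asarray(embed(texts))                       # (N, 512) sentence embeddings
features = PCA(n_components=16).fit_transform(features)   # reduce to 16 dimensions

# variational GP classifier on the reduced features
model = gpflow.models.VGP(
    data=(features, labels),
    kernel=gpflow.kernels.SquaredExponential(),
    likelihood=gpflow.likelihoods.Bernoulli(),
)
gpflow.optimizers.Scipy().minimize(model.training_loss, model.trainable_variables)
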
Kind regards
Jens

Hi Jens,
Thanks for your interest in GPflux! First, to answer your question about inducing variables: the inducing inputs you provide serve only as an initialization, since the inducing locations are optimized when fit is called. The simplest thing to do might therefore be to use normally distributed points in the post-NN embedding space; however, better performance would probably be obtained with a slightly more involved initialization, such as running k-means on these post-NN embeddings (a rough sketch follows below).
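
For illustration, a minimal sketch of the k-means initialization, assuming nn_embeddings is an (N, input_dim) array of post-NN features for (a subset of) your training data:

import gpflow
from sklearn.cluster import KMeans

num_inducing = 100  # number of inducing points (placeholder)

# cluster the post-NN embeddings and use the centroids as initial inducing inputs
z_init = KMeans(n_clusters=num_inducing).fit(nn_embeddings).cluster_centers_
inducing_variable = gpflow.inducing_variables.InducingPoints(z_init)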

For your second question, GPflux does in fact support VI for classification! Any likelihood that can be used in GPflow can be used in GPflux. For instance, following the "Two-dimensional model" example at https://gpflow.readthedocs.io/en/master/notebooks/basics/classification.html, you can replace the model definition and training by the following two-layer DGP:

import gpflow
import gpflux
import tensorflow as tf

# Build model (X and Y come from the GPflow notebook linked above)
num_data = len(X)
Z = X.copy()  # use X to initialize the inducing variable locations
kernel1 = gpflow.kernels.SquaredExponential()
inducing_variable1 = gpflow.inducing_variables.InducingPoints(Z.copy())
gp_layer1 = gpflux.layers.GPLayer(
    kernel1, inducing_variable1, num_data=num_data, num_latent_gps=2  # Choose a width-2 intermediate layer
)

kernel2 = gpflow.kernels.SquaredExponential()
inducing_variable2 = gpflow.inducing_variables.InducingPoints(Z.copy())
gp_layer2 = gpflux.layers.GPLayer(
    kernel2,
    inducing_variable2,
    num_data=num_data,
    num_latent_gps=1,
    mean_function=gpflow.mean_functions.Zero()
)

# Note Bernoulli likelihood!
likelihood_layer = gpflux.layers.LikelihoodLayer(gpflow.likelihoods.Bernoulli())
two_layer_dgp = gpflux.models.DeepGP([gp_layer1, gp_layer2], likelihood_layer)

# Train model
training_model = two_layer_dgp.as_training_model()
training_model.compile(tf.optimizers.Adam(0.01))

training_model.fit({"inputs": X, "targets": Y}, batch_size=10, epochs=200, verbose=0)

Following the same GPflow notebook, the prediction step would then become:

prediction_model = two_layer_dgp.as_prediction_model()
prediction = prediction_model(Xplot)
p = prediction.y_mean  # we only care about the mean

As a final note, if you only want a single-layer GP model, you might find that plain GPflow offers you more flexibility. For instance, you can mix TensorFlow models with GPflow as described in https://gpflow.readthedocs.io/en/master/notebooks/tailor/gp_nn.html, which gives you more control over which variational scheme you use and how you define your inducing variables. The downside is that this approach has no Keras integration, so you wouldn't be able to use the Keras fit method.
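
To make that concrete, here is a minimal sketch, not the notebook's exact code, of training a small Keras encoder jointly with a GPflow SVGP; the architecture, dimensions and data names are assumptions:

import numpy as np
import tensorflow as tf
import gpflow

input_dim, feature_dim, num_inducing, num_train = 512, 16, 100, 1000  # placeholders

# small NN mapping the inputs to the feature space the GP operates on
encoder = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation="relu", input_shape=(input_dim,)),
    tf.keras.layers.Dense(feature_dim),
])

# SVGP binary classifier defined in the NN feature space
svgp = gpflow.models.SVGP(
    kernel=gpflow.kernels.SquaredExponential(),
    likelihood=gpflow.likelihoods.Bernoulli(),
    inducing_variable=gpflow.inducing_variables.InducingPoints(
        np.random.randn(num_inducing, feature_dim)
    ),
    num_data=num_train,
)

optimizer = tf.optimizers.Adam(0.01)
variables = list(encoder.trainable_variables) + list(svgp.trainable_variables)

@tf.function
def train_step(x_batch, y_batch):
    # one joint gradient step on the NN weights and the GP's kernel/variational parameters
    with tf.GradientTape() as tape:
        loss = svgp.training_loss((encoder(x_batch), y_batch))
    optimizer.apply_gradients(zip(tape.gradient(loss, variables), variables))
    return loss

Because the NN and the GP are updated in the same gradient step, you keep full control over the variational scheme and the inducing variables while still training end-to-end, at the cost of writing your own training loop instead of using Keras fit.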

I hope this helps!
Best wishes,
Sebastian

Great - thanks for the detailed explanations!!
I have implemented a regression model and a VI classification model for my input data, and both versions seem to work well. I also used the examples from the GPflux page and other GitHub issues, and ended up doing it in a slightly different way than your example above, without using a DeepGP instance.

Binary classification case:

import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import gpflow
import gpflux
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.optimizers import Adam

# my input X is an array of N strings
# y has dimensions (N, 2), since we have a binary classification case (one-hot vectors)
n_samples = len(X)

# the number of latent dimensions that are used as input for the GP
input_dim = 16
n_inducing_points = 100
n_out_dims = 2
n_ind = [n_inducing_points] * n_out_dims

# pre-trained embedding URL
url = "https://tfhub.dev/google/universal-sentence-encoder-large/5"

# inducing point initialization with shape [output_dim, M, input_dim]
z_init = np.random.randn(n_out_dims, n_inducing_points, input_dim)

k = gpflow.kernels.SquaredExponential(variance=1.0, lengthscales=1.0)
kernel = gpflux.helpers.construct_basic_kernel(k, output_dim=n_out_dims)

inducing_variable = gpflux.helpers.construct_basic_inducing_variables(
    n_ind,
    input_dim,
    n_out_dims,
    share_variables=False,
    z_init=z_init.copy(),
)

gp_layer = gpflux.layers.GPLayer(
    kernel,
    inducing_variable,
    num_data=n_samples,
    num_latent_gps=1,
    mean_function=gpflow.mean_functions.Zero(),
)

likelihood = gpflow.likelihoods.Bernoulli()
likelihood_container = gpflux.layers.TrackableLayer()
likelihood_container.likelihood = likelihood

# put the sequential model together
embed = hub.KerasLayer(hub.load(url))

model = tf.keras.Sequential([
    Input(shape=(), dtype=tf.string),
    embed,
    Dense(256, activation='relu'),
    Dense(128, activation='relu'),
    Dense(64, activation='relu'),
    Dense(16, activation='relu'),
    gp_layer,
    likelihood_container,
])

opt = Adam()
loss = gpflux.losses.LikelihoodLoss(likelihood)
model.compile(loss=loss, optimizer=opt)
model.summary()

# make prediction
out = model.predict(X_test)

With my implementation, it seems that I only have access to the mean values and not the variances. Should I therefore use a DeepGP object instead?

Thanks again! Best regards,
Jens