Quantile Regression

A simple method to estimate uncertainty in Machine Learning


Motivation

When generating predictions about an output, it is sometimes useful to get a confidence score or, similarly, a range of values around the expected value in which the actual value might be found. Practical examples include estimating upper and lower bounds when predicting an ETA or a stock price: you care not only about the average outcome but also about the best-case and worst-case scenarios when trying to minimize risk, e.g. to avoid arriving late or losing money.

While most Machine Learning techniques do not provide a natural way of doing this, in this article, we will be exploring Quantile Regression as a means of doing so. This technique will allow us to learn some critical statistical properties of our data: the quantiles.

Install Dependencies
# uncomment to install dependencies
# ! curl -Ls https://raw.githubusercontent.com/cgarciae/quantile-regression/master/requirements.txt > requirements.txt
# ! pip install -qr requirements.txt
# ! pip install -U matplotlib

To begin our journey into quantile regression, we will first get hold of some data:

import numpy as np
import matplotlib.pyplot as plt
import os

plt.rcParams["figure.dpi"] = int(os.environ.get("FIGURE_DPI", 150))
plt.rcParams["figure.facecolor"] = os.environ.get("FIGURE_FACECOLOR", "white")
np.random.seed(69)


def create_data(multimodal: bool):
    x = np.random.uniform(0.3, 10, 1000)
    y = np.log(x) + np.random.exponential(0.1 + x / 20.0)

    if multimodal:
        x = np.concatenate([x, np.random.uniform(5, 10, 500)])
        y = np.concatenate([y, np.random.normal(6.0, 0.3, 500)])

    return x[..., None], y[..., None]


multimodal: bool = False

x, y = create_data(multimodal)

fig = plt.figure()
plt.scatter(x[..., 0], y[..., 0], s=20, facecolors="none", edgecolors="k")
plt.close()
fig


Here we have a simple 2D dataset; however, notice that y has some very peculiar statistical properties:

  1. Its noise is not normally distributed; it follows an exponential distribution.
  2. As a consequence, the noise is not symmetric.
  3. Its variance is not constant; it increases as x increases.
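
These properties can be checked numerically. The following is a minimal sketch that samples only the noise term used in create_data (exponential with scale 0.1 + x / 20) at two values of x; the helper names sample_noise and skewness, the seed, and the sample size are assumptions made for this illustration.

```python
import numpy as np

rng = np.random.default_rng(0)  # seed chosen arbitrarily for this sketch


def sample_noise(x, n=200_000):
    # Same noise model as create_data: exponential with scale 0.1 + x / 20
    return rng.exponential(0.1 + x / 20.0, size=n)


def skewness(a):
    # Standardized third moment; it is 0 for a symmetric distribution
    z = (a - a.mean()) / a.std()
    return float((z ** 3).mean())


noise_low, noise_high = sample_noise(1.0), sample_noise(9.0)

print("std at x=1:", noise_low.std())   # close to the scale, 0.15
print("std at x=9:", noise_high.std())  # close to the scale, 0.55
print("skewness at x=1:", skewness(noise_low))  # clearly positive (right-skewed)
```

The spread grows with x and the skewness is far from zero, which is why a single symmetric error bar around the mean would describe this data poorly.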

When making predictions for this kind of data, we might be very interested in knowing the range of values our data falls in, so that we can judge whether a specific outcome is expected, what the best and worst-case scenarios are, and so on.

Quantile Loss

The only thing special about quantile regression is its loss function. Instead of the usual MAE or MSE losses, for quantile regression we use the following function:

$$ \begin{aligned} E &= y - f(x) \\ L_q &= \begin{cases} q E, & E \ge 0 \\ (1 - q) (-E), & E \lt 0 \end{cases} \end{aligned} $$

Here $E$ is the error term, and $L_q$ is the loss function for the quantile $q$. So what do we mean by this? Concretely, it means that $L_q$ will bias $f(x)$ to output the value of the $q$th quantile instead of the usual mean or median statistic. The big question is: how does it do it?

First, let's notice that this formula can be rewritten as follows:

$$ \begin{aligned} E &= y - f(x) \\ L_q &= \max\big(q E,\ (q - 1) E\big) \end{aligned} $$
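
To see why the two forms agree, here is a short case check. It only relies on the fact that the two candidate terms differ by exactly $E$, so the sign of $E$ decides which one is larger:

$$ \begin{aligned} q E - (q - 1) E &= E \\ E \ge 0 &\Rightarrow q E \ge (q - 1) E \Rightarrow L_q = q E \\ E \lt 0 &\Rightarrow q E \lt (q - 1) E \Rightarrow L_q = (q - 1) E = (1 - q)(-E) \end{aligned} $$

As a worked example with $q = 0.8$: for $E = 2$ the two terms are $1.6$ and $-0.4$, so $L_q = 1.6$; for $E = -2$ they are $-1.6$ and $0.4$, so $L_q = 0.4$.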

Using $\max$ instead of a conditional statement will make it more straightforward to implement in tensor/array libraries. We will do this next in jax.

import jax
import jax.numpy as jnp


def quantile_loss(q, y_true, y_pred):
    e = y_true - y_pred
    return jnp.maximum(q * e, (q - 1.0) * e)
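
To get a feel for the asymmetry, the sketch below evaluates the same formula on a couple of hand-picked errors. It uses a NumPy stand-in (the hypothetical name quantile_loss_np) instead of the jax function so that it runs without jax installed; the numbers are illustrative only.

```python
import numpy as np


def quantile_loss_np(q, y_true, y_pred):
    # NumPy stand-in for quantile_loss (hypothetical helper for this sketch)
    e = y_true - y_pred
    return np.maximum(q * e, (q - 1.0) * e)


y_true = np.array([10.0])

# Under-prediction by 2 (E = +2) versus over-prediction by 2 (E = -2) at q = 0.8
under = quantile_loss_np(0.8, y_true, np.array([8.0]))
over = quantile_loss_np(0.8, y_true, np.array([12.0]))
print(under, over)  # under is about 1.6, over is about 0.4

# At q = 0.5 both directions cost the same: half the absolute error
symmetric = quantile_loss_np(0.5, y_true, np.array([8.0]))
print(symmetric)  # about 1.0
```

With q = 0.8, predicting too low costs four times as much as predicting too high by the same amount, which pushes the prediction upward.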

Loss Landscape

Now that we have this function, let us explore the error landscape for a particular set of predictions. Here we will generate values for y_true in the range $[10, 20]$, and for a particular value of $q$ (0.8 by default), we will compute the mean loss you would get for each value y_pred could take. Ideally, we want to find the value of y_pred where the loss is the smallest.

def calculate_error(q):
    y_true = np.linspace(10, 20, 100)
    y_pred = np.linspace(10, 20, 200)

    loss = jax.vmap(quantile_loss, in_axes=(None, None, 0))(q, y_true, y_pred)
    loss = loss.mean(axis=1)

    return y_true, y_pred, loss


q = 0.8
y_true, y_pred, loss = calculate_error(q)
q_true = np.quantile(y_true, q)


fig = plt.figure()
plt.plot(y_pred, loss)
plt.vlines(q_true, 0, loss.max(), linestyles="dashed", colors="k")
plt.gca().set_xlabel("y_pred")
plt.gca().set_ylabel("loss")
plt.title(f"Q({q:.2f}) = {q_true:.1f}")
plt.close()
fig


If we plot the error, we see that the quantile loss reaches its minimum at the value of the $q$th quantile. It achieves this because the quantile loss is not symmetrical: for quantiles above 0.5 it penalizes positive errors more strongly than negative errors, and the opposite is true for quantiles below 0.5. In particular, quantile 0.5 is the median, and its loss is equivalent to the MAE up to a constant factor of 0.5.
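
One way to reason about this: nudging a constant prediction upward lowers the loss at rate $q$ for every point above it and raises it at rate $1 - q$ for every point below it, and the two effects balance when a fraction $q$ of the points lies below the prediction. The following sketch checks this numerically on a skewed sample; the sample, the grid of candidate predictions, and the helper name mean_quantile_loss are assumptions made for this illustration.

```python
import numpy as np


def mean_quantile_loss(q, y_true, y_pred):
    # Mean quantile loss of a single constant prediction y_pred
    e = y_true - y_pred
    return np.maximum(q * e, (q - 1.0) * e).mean()


rng = np.random.default_rng(0)
sample = rng.exponential(1.0, size=5_000)  # skewed, so mean and median differ

candidates = np.linspace(0.0, 5.0, 2_001)  # grid of constant predictions
argmins = {}
for q in (0.1, 0.5, 0.9):
    losses = np.array([mean_quantile_loss(q, sample, c) for c in candidates])
    argmins[q] = candidates[losses.argmin()]
    print(f"q={q}: argmin={argmins[q]:.3f}, np.quantile={np.quantile(sample, q):.3f}")
```

For each q, the grid point with the lowest loss should land next to the empirical quantile reported by np.quantile, up to the grid resolution.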

Deep Quantile Regression

Generally, we would need to create a model per quantile. However, if we use a neural network, we can output the predictions for all the quantiles simultaneously. Here we will use elegy to create a neural network with two hidden layers with relu activations and a linear output layer with n_quantiles units.

import elegy


class QuantileRegression(elegy.Module):
    def __init__(self, n_quantiles: int):
        super().__init__()
        self.n_quantiles = n_quantiles

    def call(self, x):
        x = elegy.nn.Linear(128)(x)
        x = jax.nn.relu(x)
        x = elegy.nn.Linear(64)(x)
        x = jax.nn.relu(x)
        x = elegy.nn.Linear(self.n_quantiles)(x)

        return x

Now we will properly define a QuantileLoss class that is parameterized by a set of user-defined quantiles.

class QuantileLoss(elegy.Loss):
    def __init__(self, quantiles):
        super().__init__()
        self.quantiles = np.array(quantiles)

    def call(self, y_true, y_pred):
        loss = jax.vmap(quantile_loss, in_axes=(0, None, -1), out_axes=1)(
            self.quantiles, y_true[:, 0], y_pred
        )
        return jnp.sum(loss, axis=-1)

Notice that we use the same quantile_loss that we created previously, along with some jax.vmap magic to properly vectorize the function. Finally, we will create a simple function that creates and trains our model for a set of quantiles using elegy.
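
To make the shapes explicit, here is a NumPy sketch that computes the same per-sample quantity with plain broadcasting instead of jax.vmap. It is meant as an illustration of what the vectorized call produces, not as a replacement for the class above; multi_quantile_loss_np and the toy inputs are hypothetical.

```python
import numpy as np


def multi_quantile_loss_np(quantiles, y_true, y_pred):
    # quantiles: (n_quantiles,), y_true: (batch,), y_pred: (batch, n_quantiles)
    q = np.asarray(quantiles)[None, :]       # (1, n_quantiles)
    e = y_true[:, None] - y_pred             # (batch, n_quantiles)
    loss = np.maximum(q * e, (q - 1.0) * e)  # loss per sample and per quantile
    return loss.sum(axis=-1)                 # (batch,), summed over quantiles


quantiles = (0.1, 0.5, 0.9)
y_true = np.array([1.0, 2.0])
y_pred = np.array([[0.5, 1.0, 1.5],
                   [2.5, 2.0, 1.0]])

per_sample = multi_quantile_loss_np(quantiles, y_true, y_pred)
print(per_sample.shape)  # (2,)
print(per_sample)        # about [0.1, 1.35]
```

Each output column is scored against its own quantile, and the per-quantile losses are summed so that a single scalar per sample drives all outputs at once.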

import optax


def train_model(quantiles, epochs: int, lr: float, eager: bool):
    model = elegy.Model(
        QuantileRegression(n_quantiles=len(quantiles)),
        loss=QuantileLoss(quantiles),
        optimizer=optax.adamw(lr),
        run_eagerly=eager,
    )

    model.fit(x, y, epochs=epochs, batch_size=64, verbose=0)

    return model


if not multimodal:
    quantiles = (0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 0.95)
else:
    quantiles = np.linspace(0.05, 0.95, 9)

model = train_model(quantiles=quantiles, epochs=3001, lr=1e-4, eager=False)
model.summary(x)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer                        ┃ Outputs Shape        ┃ Trainable        ┃ Non-trainable ┃
┃                              ┃                      ┃ Parameters       ┃ Parameters    ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ Inputs                       │ (1000, 1)    float64 │                  │               │
├──────────────────────────────┼──────────────────────┼──────────────────┼───────────────┤
│ linear    Linear             │ (1000, 128)  float32 │ 256      1.0 KB  │               │
├──────────────────────────────┼──────────────────────┼──────────────────┼───────────────┤
│ linear_1  Linear             │ (1000, 64)   float32 │ 8,256    33.0 KB │               │
├──────────────────────────────┼──────────────────────┼──────────────────┼───────────────┤
│ linear_2  Linear             │ (1000, 7)    float32 │ 455      1.8 KB  │               │
├──────────────────────────────┼──────────────────────┼──────────────────┼───────────────┤
│ *         QuantileRegression │ (1000, 7)    float32 │                  │               │
├──────────────────────────────┼──────────────────────┼──────────────────┼───────────────┤
│                              │                Total │ 8,967    35.9 KB │               │
└──────────────────────────────┴──────────────────────┴──────────────────┴───────────────┘
                                                                                          
                            Total Parameters: 8,967   35.9 KB                             

Now that we have a model, let us generate some test data that spans the entire domain and compute the predicted quantiles.

x_test = np.linspace(x.min(), x.max(), 100)
y_pred = model.predict(x_test[..., None])

fig = plt.figure()
plt.scatter(x, y, s=20, facecolors="none", edgecolors="k")

for i, q_values in enumerate(np.split(y_pred, len(quantiles), axis=-1)):
    plt.plot(x_test, q_values[:, 0], linewidth=2, label=f"Q({quantiles[i]:.2f})")

plt.legend()
plt.close()
fig


Amazing! Notice how the first few quantiles are tightly packed together while the last ones spread out, capturing the behavior of the exponential distribution. We can also visualize the region between the highest and lowest quantiles, and this gives us some bounds on our predictions.

median_idx = np.where(np.isclose(quantiles, 0.5))[0]

fig = plt.figure()
plt.fill_between(x_test, y_pred[:, -1], y_pred[:, 0], alpha=0.5, color="b")
plt.scatter(x, y, s=20, facecolors="none", edgecolors="k")
plt.plot(
    x_test,
    y_pred[:, median_idx],
    color="r",
    linestyle="dashed",
    label="Q(0.5)",
)
plt.legend()
plt.close()
fig
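
A simple sanity check for such a band is its empirical coverage: if the lowest and highest quantiles are 0.05 and 0.95, roughly 90% of the points should fall between them. The sketch below illustrates the check on synthetic data drawn like create_data, using the exact conditional quantiles of that data as a stand-in for model predictions; empirical_coverage is a hypothetical helper, and the seed and sample size are arbitrary.

```python
import numpy as np


def empirical_coverage(y, lower, upper):
    # Fraction of observations that fall inside [lower, upper]
    return float(np.mean((y >= lower) & (y <= upper)))


rng = np.random.default_rng(0)
x = rng.uniform(0.3, 10, 20_000)
scale = 0.1 + x / 20.0
y = np.log(x) + rng.exponential(scale)

# Exact conditional quantiles of this synthetic data, standing in for model output:
# the q-quantile of an exponential with the given scale is -scale * log(1 - q)
lower = np.log(x) - scale * np.log(1 - 0.05)
upper = np.log(x) - scale * np.log(1 - 0.95)

coverage = empirical_coverage(y, lower, upper)
print(coverage)  # close to 0.95 - 0.05 = 0.90
```

With a trained model, lower and upper would instead be the first and last columns of the predictions evaluated at the observed x, and a coverage far from 0.90 would suggest the quantiles are miscalibrated.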


Having multiple quantile values also allows us to estimate the density of the data. Since the difference between two adjacent quantile levels represents the probability that a point lies between the corresponding quantile values, we can construct a piecewise constant function that approximates the density of the data.

def get_pdf(quantiles, q_values):
    densities = []

    for i in range(len(quantiles) - 1):
        area = quantiles[i + 1] - quantiles[i]
        b = q_values[i + 1] - q_values[i]
        a = area / b

        densities.append(a)

    return densities


def piecewise(xs):
    return [xs[i + j] for i in range(len(xs) - 1) for j in range(2)]


def doubled(xs):
    return [np.clip(xs[i], 0, 3) for i in range(len(xs)) for _ in range(2)]
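
To illustrate the idea behind get_pdf on a case where the answer is known, the sketch below applies the same mass-over-width computation to the exact quantiles of a standard exponential distribution. The function name piecewise_density and the choice of distribution are assumptions made for this example; the quantile values stand in for model predictions.

```python
import math


def piecewise_density(quantiles, q_values):
    # Same idea as get_pdf: probability mass of each bin divided by its width
    return [
        (quantiles[i + 1] - quantiles[i]) / (q_values[i + 1] - q_values[i])
        for i in range(len(quantiles) - 1)
    ]


quantiles = (0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 0.95)
# Exact quantile values of a standard exponential, standing in for model predictions
q_values = [-math.log(1.0 - q) for q in quantiles]

densities = piecewise_density(quantiles, q_values)
widths = [q_values[i + 1] - q_values[i] for i in range(len(q_values) - 1)]

# The bins only cover the mass between the lowest and the highest quantile
total_mass = sum(d * w for d, w in zip(densities, widths))
print(total_mass)  # about 0.95 - 0.05 = 0.90
print(densities)   # decreasing from bin to bin, like the density exp(-y)
```

Note that the estimate says nothing about the tails outside the lowest and highest quantiles, so the area under the piecewise function is 0.90 here rather than 1.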

For a given x, we can compute the quantile values and then use these to compute the conditional piecewise density function of y given x.

xi = 7.0

q_values = model.predict(np.array([[xi]]))[0].tolist()

densities = get_pdf(quantiles, q_values)

fig = plt.figure()
plt.title(f"x = {xi}")
plt.fill_between(piecewise(q_values), 0, doubled(densities))
# plt.fill_between(q_values, 0, densities + [0])
# plt.plot(q_values, densities + [0], color="k")
plt.xlim(0, y.max())
plt.gca().set_xlabel("y")
plt.gca().set_ylabel("p(y)")
plt.close()
fig


One of the exciting properties of Quantile Regression is that we did not need to know the output distribution a priori, and training is easy compared to other methods.

Recap

  • Quantile Regression is a simple and effective method for learning some statistics about the output distribution.
  • It is useful for establishing bounds on the predictions of a model when risk management is desired.
  • The Quantile Loss function is simple and easy to implement.
  • Quantile Regression can be efficiently implemented using Neural Networks since a single model can predict all the quantiles.
  • The quantiles can be used to estimate the conditional density of the data.

Next Steps

  • Try running this notebook with multimodal = True.
  • Take a look at Mixture Density Networks.
  • Learn more about jax and elegy.

Acknowledgments

Many thanks to David Cardozo for proofreading and for getting the notebook to run in Colab.