dashee87/blogScripts

Dixon-Coles implementation is extremely slow

Opened this issue · 1 comment

Hi,

Thanks for your very useful blog posts! We've been using them to teach our software practicals. One thing that regularly comes up is that the Dixon-Coles implementation is extremely slow. This is due to the use of Python loops.

Performance can be improved significantly by vectorizing the computations and by pulling calculations that can be done ahead of time out of the optimization objective function. Most prominently, this includes looking up the correct team indices for each match. There are a few other tricks, with diminishing returns. In total, the runtime can be reduced by a factor of about 1000.

I get the impression that this repository is no longer actively maintained, so I'm not going to the effort of opening a PR; I'm just posting this for other visitors.

My own implementation is below (it uses somewhat different DataFrame column names and a different parameter dict):

import numpy as np
import pandas as pd
from scipy.stats import poisson
from scipy.optimize import minimize

from datetime import datetime


def predict_outcome_probs(
    params: dict,
    team1: str,
    team2: str,
) -> np.ndarray:
    """
    Return ``[p_home, p_draw, p_away]`` under the Dixon-Coles model.

    Args:
        params: fitted parameters as produced by ``train_parameters``
        team1: name of the home side
        team2: name of the away side

    Returns:
        Length-3 array obtained by summing regions of the score matrix from
        ``predict_score_probs``. That matrix is truncated at a maximum goal
        count, so the three values may add up to slightly less than 1.
    """
    score_matrix = predict_score_probs(params, team1, team2)
    # Row index = team1 goals, column index = team2 goals: below the diagonal
    # team1 wins, on the diagonal the match is drawn, above it team2 wins.
    home_win = np.tril(score_matrix, k=-1).sum()
    draw = np.diag(score_matrix).sum()
    away_win = np.triu(score_matrix, k=1).sum()
    return np.array([home_win, draw, away_win])


def _params_to_vector(params: dict, teams) -> np.ndarray:
    """Flatten a dict returned by ``train_parameters`` into the optimizer layout."""
    # Layout mirrors train_parameters: [attack..., defense..., home, rho].
    # Teams missing from the dict start at 0. 'base' is not optimized (the
    # trained model always fixes it at 0), so it is ignored here.
    attack = params.get('attack', {})
    defense = params.get('defense', {})
    return np.array(
        [attack.get(team, 0.0) for team in teams]
        + [defense.get(team, 0.0) for team in teams]
        + [params.get('home', 0.0), params.get('rho', 0.0)],
        dtype=float,
    )


def train_parameters(
    matches: pd.DataFrame,
    tau: float = None,
    t0: datetime = None,
    initial: "dict | np.ndarray" = None,
    options: dict = None,
) -> dict:
    """
    Train Dixon-Coles model and return trained parameter values.

    Args:
        matches: dataframe with columns Team1, Team2, Score1, Score2, Datetime
            (Datetime is only read when ``tau`` is given)
        tau: mean lifetime, in days, for weight decay of the match importance;
            ``None`` weights every match equally
        t0: reference time for weight decay (defaults to ``datetime.now()``)
        initial: initial parameter values. Either a flat array laid out as
            ``[attack (n_teams), defense (n_teams), home, rho]`` with teams in
            sorted order, or a dict in the format returned by this function
            (teams missing from it start at 0). Random values if ``None``.
        options: optimizer options passed to ``scipy.optimize.minimize``
            (defaults to ``{"maxiter": 100}``)

    Returns:
        dict with keys ``base`` (always 0), ``attack`` and ``defense``
        (team name -> coefficient), ``home`` and ``rho``.
    """
    # Avoid a shared mutable default argument.
    if options is None:
        options = {"maxiter": 100}

    teams = np.unique([matches.Team1, matches.Team2])
    # Map team names to integer codes once, so the objective can gather the
    # per-match coefficients with fancy indexing instead of name lookups.
    team1 = pd.Categorical(matches.Team1, categories=teams).codes
    team2 = pd.Categorical(matches.Team2, categories=teams).codes
    score1 = matches.Score1.to_numpy()
    score2 = matches.Score2.to_numpy()

    if tau is None:
        weights = 1
    else:
        if t0 is None:
            t0 = datetime.now()
        age_days = (t0 - matches.Datetime).dt.days.to_numpy(dtype=float)
        # Exponential decay: a match tau days before t0 counts 1/e as much as
        # one played at t0. Plain ndarray avoids pandas overhead per call.
        weights = np.exp(-age_days / tau)

    n_teams = len(teams)
    # Parameter vector layout: [attack (n_teams), defense (n_teams), home, rho]
    if initial is None:
        initial = np.concatenate((
            np.random.uniform(0, 1, n_teams),   # attack
            np.random.uniform(0, -1, n_teams),  # defense
            np.array([1.0, 0.0]),               # home, rho
        ))
    elif isinstance(initial, dict):
        initial = _params_to_vector(initial, teams)
    else:
        initial = np.asarray(initial, dtype=float)

    def objective(params):
        attack_coefs = params[:n_teams]
        defend_coefs = params[n_teams:-2]
        home, rho = params[-2:]
        # home/rho are passed by keyword: dc_log_likelihood declares them in
        # the order (rho, home), so positional passing would swap them.
        return -dc_log_likelihood(
            score1, score2,
            attack_coefs[team1],
            defend_coefs[team1],
            attack_coefs[team2],
            defend_coefs[team2],
            home=home, rho=rho, weights=weights,
        ).sum()

    opt_output = minimize(
        objective,
        initial,
        options=options,
        method='L-BFGS-B',
    )
    fitted = opt_output.x
    return {
        'base': 0,
        'attack': dict(zip(teams, fitted[:n_teams])),
        'defense': dict(zip(teams, fitted[n_teams:-2])),
        'home': fitted[-2],
        'rho': fitted[-1],
    }


def rho_correction(goals1, goals2, lambda1, lambda2, rho):
    """
    Correction term for the 0:0, 0:1, 1:0, 1:1 probabilities.

    Evaluates the Dixon-Coles factor tau(goals1, goals2) elementwise:
    ``1 - lambda1*lambda2*rho`` for 0:0, ``1 + lambda1*rho`` for 0:1,
    ``1 + lambda2*rho`` for 1:0, ``1 - rho`` for 1:1 and ``1`` otherwise.

    Args:
        goals1, goals2: goals scored by the first/second team (scalars or
            broadcast-compatible arrays)
        lambda1, lambda2: expected goals of the first/second team (scalars or
            arrays; plain Python floats are accepted)
        rho: dependence parameter; clipped into the range in which every
            correction factor stays non-negative for all supplied rates
    """
    # tau >= 0 requires rho >= max(-1/lambda1, -1/lambda2) (0:1 and 1:0
    # cells) and rho <= min(1/(lambda1*lambda2), 1) (0:0 and 1:1 cells).
    # Using the extreme rates over all entries gives one common rho range;
    # the 1e-3 margin keeps tau away from 0 so that log(tau) stays finite.
    # np.max (unlike the .max() method) also works for plain Python floats.
    rho_min = -1 / np.maximum(np.max(lambda1), np.max(lambda2))
    rho_max = 1 / np.maximum(np.max(np.multiply(lambda1, lambda2)), 1.0)
    rho = np.clip(rho, rho_min + 1e-3, rho_max - 1e-3)

    x_corr = np.where(goals1 == 0, -lambda1, 1.0)
    y_corr = np.where(goals2 == 0, lambda2, -1.0)
    # Only scores with both teams below 2 goals are corrected.
    r_corr = np.where((goals1 < 2) & (goals2 < 2), rho, 0.0)
    return 1 + x_corr * y_corr * r_corr


def dc_log_likelihood(
    goals1, goals2,
    attack1, defend1,
    attack2, defend2,
    rho, home,
    weights=1,
):
    """
    Return the per-match weighted log likelihood of the Dixon-Coles model.

    The scoring rates are ``exp(attack1 + defend2 + home)`` for the first
    (home) team and ``exp(attack2 + defend1)`` for the second team. The
    log-likelihood is the sum of two independent Poisson log-pmfs plus the
    log of the low-score correction from ``rho_correction``, scaled
    elementwise by ``weights``.

    Note the argument order ``rho, home``; pass them by keyword when in doubt.
    """
    rate1 = np.exp(attack1 + defend2 + home)
    rate2 = np.exp(attack2 + defend1)
    log_tau = np.log(rho_correction(goals1, goals2, rate1, rate2, rho))
    log_p1 = poisson.logpmf(goals1, rate1)
    log_p2 = poisson.logpmf(goals2, rate2)
    return weights * (log_tau + log_p1 + log_p2)


def predict_score_probs(params, team1, team2, max_goals=10):
    """
    Return matrix of score probabilities for the Dixon-Coles model.

    Entry ``[i, j]`` is the probability that ``team1`` scores ``i`` and
    ``team2`` scores ``j`` goals, for ``0 <= i, j <= max_goals``. Scores
    above ``max_goals`` are dropped, so the matrix may sum to less than 1.
    """
    output_matrix, (avg1, avg2) = predict_score_probs_poisson(
        params, team1, team2, max_goals)
    # The Dixon-Coles correction only touches the 0:0, 0:1, 1:0 and 1:1
    # cells. Clamp to the matrix size so that max_goals=0 does not try to
    # broadcast a 2x2 correction onto a 1x1 matrix.
    n_corr = min(2, output_matrix.shape[0])
    goals = np.arange(n_corr)
    output_matrix[:n_corr, :n_corr] *= rho_correction(
        goals[:, None],
        goals[None, :],
        avg1, avg2,
        params['rho'],
    )
    return output_matrix


def predict_score_probs_poisson(params, team1, team2, max_goals=10):
    """
    Score-probability matrix for two independent Poisson scorers.

    Returns:
        ``(matrix, rates)`` where ``matrix[i, j]`` is P(team1 scores i) *
        P(team2 scores j) for ``0 <= i, j <= max_goals`` and ``rates`` is the
        array from ``get_expected_goals``.
    """
    rates = get_expected_goals(params, team1, team2)
    goal_counts = np.arange(max_goals + 1)
    # Broadcasting (1, k) goal counts against (2, 1) rates evaluates both
    # teams' pmfs in a single call: row 0 is team1, row 1 is team2.
    pmf_table = poisson.pmf(goal_counts[np.newaxis, :], rates[:, np.newaxis])
    home_pmf, away_pmf = pmf_table
    return np.outer(home_pmf, away_pmf), rates


def get_expected_goals(params, team1, team2):
    """
    Compute the Poisson scoring rates of both sides in a fixture.

    ``team1`` receives the ``home`` bonus. A team that is absent from the
    ``attack``/``defense`` tables contributes a coefficient of 0.

    Returns:
        ``np.array([rate_team1, rate_team2])``
    """
    base, home = params['base'], params['home']
    att, dfn = params['attack'], params['defense']

    home_rate = np.exp(base + att.get(team1, 0) + dfn.get(team2, 0) + home)
    away_rate = np.exp(base + dfn.get(team1, 0) + att.get(team2, 0))
    return np.array([home_rate, away_rate])

This can be made slightly faster still, e.g. by using JAX:

--- dixon_coles.py	2022-01-06 15:31:29.235639125 +0100
+++ dixon_coles_jax.py	2022-01-06 15:40:24.100362012 +0100
@@ -1,11 +1,16 @@
+import jax
+import jax.numpy as jnp
 import numpy as np
 import pandas as pd
-from scipy.stats import poisson
+from jax.scipy.stats import poisson
 from scipy.optimize import minimize
 
 from datetime import datetime
 
 
+jax.config.update("jax_enable_x64", True)
+
+
 def predict_outcome_probs(
     params: dict,
     team1: str,
@@ -57,7 +62,7 @@
         if t0 is None:
             t0 = datetime.now()
         time = (t0 - matches.Datetime).dt.days
-        weights = np.exp(-time / tau)
+        weights = jnp.array(np.exp(-time / tau))
 
     n_teams = len(teams)
     if initial is None:
@@ -67,7 +72,9 @@
             np.array([1.0, 0.0]),               # home, rho
         ))
 
-    def objective(params):
+    @jax.jit
+    @jax.value_and_grad
+    def objective_jax(params):
         attack_coefs = params[:n_teams]
         defend_coefs = params[n_teams:-2]
         home, rho = params[-2:]
@@ -80,11 +87,17 @@
             home, rho, weights,
         ).sum()
 
+    def objective(x):
+        loss, grad = objective_jax(jnp.array(x, dtype=jnp.float64))
+        return (np.array(loss, dtype=np.float64),
+                np.array(grad, dtype=np.float64))
+
     opt_output = minimize(
         objective,
         initial,
         options=options,
         method='L-BFGS-B',
+        jac=True,
     )
     return {
         'base': 0,
@@ -99,13 +112,13 @@
     """Correction term for the 0:0, 0:1, 1:0, 1:1 probabilities."""
     # rho_min = max(-1/lambda1, -1/lambda2)
     # rho_max = min(1/(lambda1 * lambda2), 1)
-    rho_min = -1 / np.maximum(lambda1.max(), lambda2.max())
-    rho_max = 1 / np.maximum((lambda1 * lambda2).max(), 1.0)
-    rho = np.clip(rho, rho_min + 1e-3, rho_max - 1e-3)
-
-    x_corr = np.where(goals1 == 0, -lambda1, 1.0)
-    y_corr = np.where(goals2 == 0, lambda2, -1.0)
-    r_corr = np.where((goals1 < 2) & (goals2 < 2), rho, 0.0)
+    rho_min = -1 / jnp.maximum(lambda1.max(), lambda2.max())
+    rho_max = 1 / jnp.maximum((lambda1 * lambda2).max(), 1.0)
+    rho = jnp.clip(rho, rho_min + 1e-3, rho_max - 1e-3)
+
+    x_corr = jnp.where(goals1 == 0, -lambda1, 1.0)
+    y_corr = jnp.where(goals2 == 0, lambda2, -1.0)
+    r_corr = jnp.where((goals1 < 2) & (goals2 < 2), rho, 0.0)
     return 1 + x_corr * y_corr * r_corr
 
 
@@ -117,10 +130,10 @@
     weights=1,
 ):
     """Log likelihood of the Dixon-Coles model."""
-    lambda1 = np.exp(attack1 + defend2 + home)
-    lambda2 = np.exp(attack2 + defend1)
+    lambda1 = jnp.exp(attack1 + defend2 + home)
+    lambda2 = jnp.exp(attack2 + defend1)
     return weights * (
-        np.log(rho_correction(goals1, goals2, lambda1, lambda2, rho)) +
+        jnp.log(rho_correction(goals1, goals2, lambda1, lambda2, rho)) +
         poisson.logpmf(goals1, lambda1) +
         poisson.logpmf(goals2, lambda2)
     )