patrick-kidger/optimistix

Parallel multi-start

Opened this issue · 10 comments

Hello,

First of all, I wanted to say that this is a really nice library. I also wanted to ask whether it is possible to do something like a parallel multi-start optimization with filter_vmap. That is, if I want to run the minimization from several different initial parameter guesses, can this be parallelized?

I have implemented something like the code below, but it doesn't seem to work:

def optimize_batch(initial_guesses, compute_mse):
    def single_optimization(initial_guess):
        solver = optx.BFGS(rtol=1e-5, atol=1e-5)
        try:
            sol = optx.minimise(compute_mse, solver=solver, y0=initial_guess, max_steps=5000)
            loss = compute_mse(sol.value, None)
            success = True
        except Exception:
            loss = jnp.inf
            success = False
        return loss, success

    vectorized_optimization = eqx.filter_vmap(single_optimization)
    return vectorized_optimization(initial_guesses)

def main():

    # Generate initial guesses
    num_samples = 1024
    sobol = qmc.Sobol(d=len(simulation.flexible_params_indices), scramble=True)
    initial_guesses = jnp.array(sobol.random(n=num_samples))

    # Run the batch optimization
    losses, successes = optimize_batch(initial_guesses, compute_mse)

    losses = jnp.where(successes, losses, jnp.inf).astype('float64')
    
    # Find the best result
    best_index = jnp.argmin(losses)
    best_loss = losses[best_index]
    best_params = initial_guesses[best_index]

    print(f"Best loss: {best_loss} at index {best_index}")
    

if __name__ == "__main__":
    main()

Yup, it should be possible to vmap over the initial condition.
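
Roughly, the pattern looks like the following -- a minimal sketch, with a toy quadratic loss standing in for your compute_mse:

import equinox as eqx
import jax.numpy as jnp
import optimistix as optx

def loss(y, args):
    # Toy stand-in for your compute_mse, purely for illustration.
    return jnp.sum((y - 2.0) ** 2)

@eqx.filter_vmap
def run(y0):
    # Each batch element runs its own BFGS minimisation from a different starting guess.
    solver = optx.BFGS(rtol=1e-5, atol=1e-5)
    sol = optx.minimise(loss, solver=solver, y0=y0, max_steps=1000)
    return sol.value, loss(sol.value, None)

initial_guesses = jnp.linspace(-3.0, 3.0, 16)[:, None]  # shape (batch, dim)
values, losses = run(initial_guesses)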

Unfortunately your example isn't runnable (I don't know what qmc is).

Hello, thanks for the response. It is scipy.stats.qmc :)
Here is a runnable example:

import optimistix as optx
import numpy as np
import equinox as eqx
from jax import config
from scipy.stats import qmc
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import Array, Float
import diffrax as dfx
from contextlib import contextmanager
import warnings
from tqdm import tqdm
import matplotlib.pyplot as plt
# JAX configuration
config.update("jax_enable_x64", True)

@contextmanager
def suppress_warnings():
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        yield

def vector_field(
    t, y: Float[Array, "2"], parameters: Float[Array, "4"]
) -> Float[Array, "2"]:
    prey, predator = y
    α, β, γ, δ = parameters
    d_prey = α * prey - β * prey * predator
    d_predator = -γ * predator + δ * prey * predator
    d_y = jnp.stack([d_prey, d_predator])
    return d_y

def solve(
    parameters: Float[Array, "4"], y0: Float[Array, "2"], saveat: dfx.SaveAt
) -> Float[Array, "ts"]:
    """Solve a single ODE."""
    term = dfx.ODETerm(vector_field)
    solver = dfx.Tsit5()
    t0 = saveat.subs.ts[0]
    t1 = saveat.subs.ts[-1]
    dt0 = 0.1
    sol = dfx.diffeqsolve(
        term,
        solver,
        t0,
        t1,
        dt0,
        y0,
        args=parameters,
        saveat=saveat,
        adjoint=dfx.DirectAdjoint(),
    )
    return sol.ys

# Generate noisy measurement data
def generate_noisy_data(true_params, y0, ts, noise_level=0.1):
    true_solution = solve(true_params, y0, dfx.SaveAt(ts=ts))
    key = jr.PRNGKey(0)
    noise = jr.normal(key, true_solution.shape) * noise_level
    return true_solution + noise

# Compute MSE loss
def compute_mse(params, data):
    y0, ts, noisy_data = data
    predicted = solve(params, y0, dfx.SaveAt(ts=ts))
    return jnp.mean((predicted - noisy_data)**2)


# def single_optimization(initial_guess, data):
#     solver = optx.BFGS(rtol=1e-5, atol=1e-5)
#     try:
#         sol = optx.minimise(compute_mse, solver=solver, y0=initial_guess, args = data, max_steps=5000)
#         loss = compute_mse(sol.value, data)
#         best_params = sol.value
#         success = True
#     except:
#         loss = jnp.inf
#         success = False
#         best_params = initial_guess
#     return best_params, loss, success

# def main():
#     # True parameters and initial condition
#     true_params = jnp.array([0.5, 0.025, 0.5, 0.005])
#     y0 = jnp.array([10.0, 5.0])
#     ts = jnp.linspace(0, 30, 100)

#     # Generate noisy data
#     noisy_data = generate_noisy_data(true_params, y0, ts)

#     # Package data for optimization
#     data = (y0, ts, noisy_data)

#     # Generate initial guesses
#     num_samples = 100
#     sobol = qmc.Sobol(d=4, scramble=True)
#     initial_guesses = jnp.array(sobol.random(n=num_samples))

#     losses = np.zeros(num_samples)
#     successes = []
#     solutions = np.zeros_like(initial_guesses)

#     # Run the batch optimization
#     for i, initial_guess in tqdm(enumerate(initial_guesses), total=num_samples, desc="Optimizing", ncols=100):
#         with suppress_warnings():
#             try:
#                 best_params, loss, success = single_optimization(initial_guess, data)
#                 successes.append(success)
#                 solutions[i, :] = best_params if success else np.inf
#             except Exception as e:
#                 loss = np.inf
#                 successes.append(False)
#                 solutions[i, :] = initial_guess
#         losses[i] = loss
    
#     print(f"Lowest loss: {np.nanmin(losses)}")
#     print(f"Lowest index: {np.nanargmin(losses)}")

#     best_parameters = solutions[np.nanargmin(losses),:]
    
#     # Generate the best-fit trajectory
#     best_trajectory = solve(best_parameters, y0, dfx.SaveAt(ts=ts))

#     # Plotting
#     fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))

#     # Plot prey
#     ax1.scatter(ts, noisy_data[:, 0], color='red', alpha=0.5, label='Noisy Data (Prey)')
#     ax1.plot(ts, best_trajectory[:, 0], color='blue', label='Best Fit (Prey)')
#     ax1.set_ylabel('Prey Population')
#     ax1.legend()
#     ax1.set_title('Prey Population: Noisy Data vs Best Fit')

#     # Plot predator
#     ax2.scatter(ts, noisy_data[:, 1], color='green', alpha=0.5, label='Noisy Data (Predator)')
#     ax2.plot(ts, best_trajectory[:, 1], color='orange', label='Best Fit (Predator)')
#     ax2.set_xlabel('Time')
#     ax2.set_ylabel('Predator Population')
#     ax2.legend()
#     ax2.set_title('Predator Population: Noisy Data vs Best Fit')

#     plt.tight_layout()
#     plt.show()

#     print(f"True parameters: {true_params}")
#     print(f"Best parameters: {best_parameters}")

# if __name__ == "__main__":
#     main()

def optimize_batch(initial_guesses, compute_mse, data):
    def single_optimization(initial_guess):
        solver = optx.BFGS(rtol=1e-5, atol=1e-5)
        try:
            sol = optx.minimise(compute_mse, solver=solver, y0=initial_guess, args=data, max_steps=5000)
            loss = compute_mse(sol.value, data)
            success = True
            best_params = sol.value
        except Exception:
            loss = jnp.inf
            success = False
            best_params = initial_guess
        return best_params, loss, success

    vectorized_optimization = eqx.filter_vmap(single_optimization)
    return vectorized_optimization(initial_guesses)

def main():
    # True parameters and initial condition
    true_params = jnp.array([0.5, 0.025, 0.5, 0.005])
    y0 = jnp.array([10.0, 5.0])
    ts = jnp.linspace(0, 30, 100)

    # Generate noisy data
    noisy_data = generate_noisy_data(true_params, y0, ts)

    # Package data for optimization
    data = (y0, ts, noisy_data)

    # Generate initial guesses
    num_samples = 100
    sobol = qmc.Sobol(d=4, scramble=True)
    initial_guesses = jnp.array(sobol.random(n=num_samples))

    # Run the batch optimization
    best_params, losses, successes = optimize_batch(initial_guesses, compute_mse, data)
    
    # Find the best result
    best_index = jnp.argmin(losses)
    best_loss = losses[best_index]
    best_param = best_params[best_index]

    print(f"True parameters: {true_params}")
    print(f"Best parameters: {best_param}")
    print(f"Best loss: {best_loss} at index {best_index}")

if __name__ == "__main__":
    main()

In the commented-out part I added the for-loop version, which takes 12 seconds on my Mac. The vmap version never finishes.

This is quite a large MWE! I'd love to help, but it'd be great if you can condense this down to the most minimal thing that demonstrates your issue first. (Are qmc, tqdm, matplotlib all necessary to reproduce your issue? What about diffrax.diffeqsolve, or will a smaller non-diffrax function suffice? Etc.)

Hello, I think I have boiled it down to a smaller MWE. For me it makes a difference whether I use Diffrax or not:

import optimistix as optx
import equinox as eqx
from jax import config
from scipy.stats import qmc
import jax.numpy as jnp
import diffrax as dfx
# JAX configuration
config.update("jax_enable_x64", True)
import time

def lotka_volterra(t, y, parameters):
    prey, predator = y
    α, β, γ, δ = parameters
    d_prey = α * prey - β * prey * predator
    d_predator = -γ * predator + δ * prey * predator
    d_y = jnp.stack([d_prey, d_predator])
    return d_y

def solve(parameters):
    """Solve a single ODE."""
    term = dfx.ODETerm(lotka_volterra)
    solver = dfx.Tsit5()
    saveat = dfx.SaveAt(ts=jnp.linspace(0, 30, 100))
    sol = dfx.diffeqsolve(
        term, solver, 0.0, 30.0, 0.1, jnp.array([1.0, 1.0]),
        args=parameters, saveat=saveat, adjoint=dfx.DirectAdjoint(),
    )
    return sol.ys

# Compute MSE loss
def compute_mse(params, ts):
    predicted = solve(params)
    return jnp.mean((predicted - jnp.zeros_like(predicted))**2)


def optimize_batch(initial_guesses, compute_mse, ts):
    def single_optimization(initial_guess):
        solver = optx.BFGS(rtol=1e-5, atol=1e-5)
        try:
            sol = optx.minimise(compute_mse, solver=solver, y0=initial_guess, args=ts, max_steps=5000)
            loss = compute_mse(sol.value, ts)
            best_params = sol.value
        except Exception:
            loss = jnp.inf
            best_params = initial_guess
        return best_params, loss

    vectorized_optimization = eqx.filter_vmap(single_optimization)
    return vectorized_optimization(initial_guesses)

def main():
    # True parameters and initial condition
    ts = jnp.linspace(0, 30, 100)

    # Generate initial guesses
    num_samples = 10000
    sobol = qmc.Sobol(d=4, scramble=True)
    initial_guesses = jnp.array(sobol.random(n=num_samples))

    # Run the batch optimization
    best_params, losses = optimize_batch(initial_guesses, compute_mse, ts)


if __name__ == "__main__":
    main()

Try using adjoint=diffrax.RecursiveCheckpointAdjoint() (the default) rather than diffrax.DirectAdjoint(). The latter does weird and magical things under the hood that are not computationally efficient. When I try your code with that change, it completes in about a second.
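
That is, in your solve only the adjoint argument changes (and since RecursiveCheckpointAdjoint is the default, you can also simply drop the argument altogether):

sol = dfx.diffeqsolve(
    term, solver, 0.0, 30.0, 0.1, jnp.array([1.0, 1.0]),
    args=parameters, saveat=saveat,
    adjoint=dfx.RecursiveCheckpointAdjoint(),  # the default
)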

If you're combining Optimistix and Diffrax together in this way, by the way, then you might also like to take a look at #51 and #61. If you bump into any other weird edge cases then know that we're hoping to improve compatibility between them :)

Thanks for your answer. The thing is, if I set that adjoint it runs through but produces only one value, namely inf, instead of the whole list I would expect...
Thanks also for the mentioned discussions, I will have a look. In general I think Diffrax together with Optimistix has huge potential!
I am now experiencing something "weird" in my case where, as mentioned, I use Diffrax and Optimistix together to try and minimize my likelihood. I have observed that the forward problem, i.e. solving the ODE, runs through without a problem with a stepsize_controller and Kvaerno5, but whenever I additionally use an optimization algorithm with either of these instead of Tsit5 I get the error:

equinox._errors.EqxRuntimeError: The maximum number of solver steps was reached. Try increasing max_steps.
jax.pure_callback failed
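
For reference, the kind of forward-solve configuration I mean is roughly this (a sketch only; the Kvaerno5/PIDController tolerances and max_steps below are placeholders, not my exact settings):

def solve(parameters):
    term = dfx.ODETerm(lotka_volterra)
    solver = dfx.Kvaerno5()  # stiff solver instead of Tsit5
    stepsize_controller = dfx.PIDController(rtol=1e-6, atol=1e-6)
    saveat = dfx.SaveAt(ts=jnp.linspace(0, 30, 100))
    sol = dfx.diffeqsolve(
        term, solver, 0.0, 30.0, 0.1, jnp.array([1.0, 1.0]),
        args=parameters, saveat=saveat,
        stepsize_controller=stepsize_controller, max_steps=16**4,
    )
    return sol.ys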

Do you know why this is the case? Increasing max_steps only leads to a longer time until the error pops up

Do you have a MWE? I don't observe that with the code you have above.

For what it's worth, whilst looking at this I did encounter an entirely unrelated error -- which perhaps is affecting you! -- and is fixed in patrick-kidger/equinox#777 . I'd recommend using the version of Equinox from that branch and seeing if it helps?

I think I have fixed it now, but I am not sure what led to the fix of this behaviour... However, I still observe that vmap is relatively slow on its own. Would you know how to integrate sharding here, to leverage multiple CPUs/GPUs in the example above?
Also, something unrelated to this (maybe I can open up a separate issue for it): is there any possibility to easily include box constraints in the minimisation? Something in the flavour of SciPy's trust-constr.

On sharding -- indeed something like shard_map may help! I believe this decouples each piece completely. (Unlike vmap, which uses the same number of iterations across the whole batch.) I don't have a polished example for that, though.
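
An untested sketch of the idea might look something like the following. (Toy quadratic loss standing in for yours; shard_map splits the batch of initial guesses across the local devices, and each device then runs its own vmap over its shard.)

import functools

import equinox as eqx
import jax
import jax.numpy as jnp
import optimistix as optx
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

def loss(y, args):
    # Toy stand-in for the real objective.
    return jnp.sum((y - 3.0) ** 2)

def single_optimization(y0):
    solver = optx.BFGS(rtol=1e-5, atol=1e-5)
    # throw=False so that one failed solve doesn't abort the whole batch;
    # failures can be inspected via sol.result instead.
    sol = optx.minimise(loss, solver=solver, y0=y0, max_steps=1000, throw=False)
    return sol.value, loss(sol.value, None)

devices = jax.local_devices()
mesh = Mesh(devices, axis_names=("batch",))

@functools.partial(
    shard_map, mesh=mesh, in_specs=P("batch"), out_specs=P("batch"), check_rep=False
)
def sharded_optimize(guesses):
    # Each device independently vmaps over the shard of the batch it receives.
    return eqx.filter_vmap(single_optimization)(guesses)

# The leading batch dimension must divide evenly across the devices.
initial_guesses = jnp.linspace(-5.0, 5.0, 8 * len(devices))[:, None]
best_params, losses = sharded_optimize(initial_guesses)

You would then pick the best entry with jnp.argmin(losses) exactly as before.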

On box constraints -- we're starting to think about implementing those over in #64, but it's still early days yet.