Nicholaswogan/numbalsoda

Unexpected behavior of NumbaLSODA and nb.cfunc in parallel mode

amar-iastate opened this issue · 2 comments

Hi there,
Thanks for sharing this package. I am seeing up to a 20x speedup in my code. Since I am simulating the evolution of multiple initial conditions, I wanted to speed up the code further with Numba's parallel=True flag. I adapted the example you provided on Stack Overflow to use NumbaLSODA with parallel=True. However, the speedup was marginal even with 8 cores, and in some cases my parallel code was actually slower (!).

After debugging the 2-state example with the code below, I narrowed the cause down to how the parameters in the rhs function are defined. If the parameters are either global variables or passed via the data argument of lsoda, parallel=True speeds up the code almost linearly with the number of cores. However, if the parameters are defined locally, as in f_local below, the speedup is marginal. In fact, the parallel version using f_local is slower than the serial version using f_global or f_param.

I thought that declaring variables locally was good coding practice and would speed up execution, but this does not seem to be the case with Numba and NumbaLSODA. I do not know whether this behavior is caused by cfunc in Numba or by NumbaLSODA; it was definitely unexpected, so I thought I would bring it to your attention. Do you know the reason for it? Are local constants not compiled the same way as global constants? I have been unable to dig deeper and was wondering if you could help.
Best
Amar

from NumbaLSODA import lsoda_sig, lsoda
from matplotlib import pyplot as plt
import numpy as np
import numba as nb
import time

a_glob=np.array([1.5,1.5])

@nb.cfunc(lsoda_sig)
def f_global(t, u_, du, p): # variable a is global
    u = nb.carray(u_, (2,))
    du[0] = u[0]-u[0]*u[1]*a_glob[0]
    du[1] = u[0]*u[1]-u[1]*a_glob[1]

    
@nb.cfunc(lsoda_sig)
def f_local(t, u_, du, p): # variable a is local
    u = nb.carray(u_, (2,))
    a = np.array([1.5,1.5]) 
    du[0] = u[0]-u[0]*u[1]*a[0]
    du[1] = u[0]*u[1]-u[1]*a[1]

    
@nb.cfunc(lsoda_sig)
def f_param(t, u_, du, p): # pass in a as a parameter
    u = nb.carray(u_, (2,))
    du[0] = u[0]-u[0]*u[1]*p[0]
    du[1] = u[0]*u[1]-u[1]*p[1]

funcptr_glob = f_global.address
funcptr_local = f_local.address
funcptr_param = f_param.address
t_eval = np.linspace(0.0,20.0,201)
np.random.seed(0)
a = np.array([1.5,1.5])

@nb.njit(parallel=True)
def main(n, flag):
#     a = np.array([1.5,1.5])
    u1 = np.empty((n,len(t_eval)), np.float64)
    u2 = np.empty((n,len(t_eval)), np.float64)
    for i in nb.prange(n):
        u0 = np.empty((2,), np.float64)
        u0[0] = np.random.uniform(4.5,5.5)
        u0[1] = np.random.uniform(0.7,0.9)
        if flag ==1: # global
            usol, success = lsoda(funcptr_glob, u0, t_eval, rtol = 1e-8, atol = 1e-8)
        if flag ==2: # local
            usol, success = lsoda(funcptr_local, u0, t_eval, rtol = 1e-8, atol = 1e-8)
        if flag ==3: # param
            usol, success = lsoda(funcptr_param, u0, t_eval, data = a, rtol = 1e-8, atol = 1e-8)
        u1[i] = usol[:,0]
        u2[i] = usol[:,1]
    return u1, u2

@nb.njit(parallel=False)
def main_series(n, flag): # same function as above but with parallel flag = False
#     a = np.array([1.5,1.5])
    u1 = np.empty((n,len(t_eval)), np.float64)
    u2 = np.empty((n,len(t_eval)), np.float64)
    for i in nb.prange(n):
        u0 = np.empty((2,), np.float64)
        u0[0] = np.random.uniform(4.5,5.5)
        u0[1] = np.random.uniform(0.7,0.9)
        if flag ==1: # global
            usol, success = lsoda(funcptr_glob, u0, t_eval, rtol = 1e-8, atol = 1e-8)
        if flag ==2: # local
            usol, success = lsoda(funcptr_local, u0, t_eval, rtol = 1e-8, atol = 1e-8)
        if flag ==3: # param
            usol, success = lsoda(funcptr_param, u0, t_eval, data = a, rtol = 1e-8, atol = 1e-8)
        u1[i] = usol[:,0]
        u2[i] = usol[:,1]
    return u1, u2

n = 100
# call each case once first to trigger JIT compilation
u1, u2 = main(n,1)
u1, u2 = main(n,2)
u1, u2 = main(n,3)

u1, u2 = main_series(n,1)
u1, u2 = main_series(n,2)
u1, u2 = main_series(n,3)

# Run the timing comparison for a large number of initial conditions
n = 10000
a1 = time.time()
u1, u2 = main(n,1) # global
a2 = time.time()
print(a2 - a1) # this is fast


a1 = time.time()
u1, u2 = main(n,2) # local
a2 = time.time()
print(a2 - a1) # this is slow

a1 = time.time()
u1, u2 = main(n,3) # param
a2 = time.time()
print(a2 - a1) # this is fast and almost identical performance as global

a1 = time.time()
u1, u2 = main_series(n,1) # global
a2 = time.time()
print(a2 - a1) # this is faster than local + parallel

a1 = time.time()
u1, u2 = main_series(n,2) # local
a2 = time.time()
print(a2 - a1) # this is slow

a1 = time.time()
u1, u2 = main_series(n,3) # param
a2 = time.time()
print(a2 - a1) # this is fast and almost identical performance as global

The "local" cases are probably slower because the array a = np.array([1.5,1.5]) must be allocated every time the rhs function is called.

The "global" and "param" cases are faster because they have no allocations at all during integration.
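For intuition, here is a minimal timing sketch (not from your script; rhs_alloc and rhs_noalloc are hypothetical names) showing the cost of a small per-call allocation inside compiled code. The allocating loop goes through Numba's runtime allocator on every iteration, which can also become a point of contention between threads under prange:

import numba as nb
import numpy as np
import time

@nb.njit
def rhs_alloc(n):
    s = 0.0
    for i in range(n):
        a = np.array([1.5, 1.5])  # fresh heap allocation on every iteration
        s += a[0] + a[1]
    return s

@nb.njit
def rhs_noalloc(n, a):
    s = 0.0
    for i in range(n):
        s += a[0] + a[1]  # reuses the preallocated array, no allocation
    return s

a = np.array([1.5, 1.5])
rhs_alloc(1); rhs_noalloc(1, a)  # compile both
t1 = time.time(); rhs_alloc(10000000); print(time.time() - t1)
t1 = time.time(); rhs_noalloc(10000000, a); print(time.time() - t1)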

For constants in your rhs function, it's normally a good idea to do something like this:

def make_cfunc():
    a = np.array([1.5, 1.5])  # closure variable: frozen as a compile-time constant
    @nb.cfunc(lsoda_sig)
    def f(t, u_, du, p):
        u = nb.carray(u_, (2,))
        du[0] = u[0] - u[0]*u[1]*a[0]
        du[1] = u[0]*u[1] - u[1]*a[1]

    return f

In this case, a is treated as a constant and "baked" into the rhs function. No allocations occur during integration, but you still don't have an annoying global variable in the scope of the rest of your code.
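Usage would then look like this (a sketch, with u0 and t_eval defined as in your script):

f = make_cfunc()
funcptr = f.address  # pass this pointer to lsoda exactly as before
usol, success = lsoda(funcptr, u0, t_eval, rtol = 1e-8, atol = 1e-8)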

This approach only works for constants. If a instead changes many times over the course of many integrations, you should pass it in as a parameter via p.
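If you do pass parameters via p and have more than one or two of them, nb.carray also works on p (a sketch; f_p is just an illustrative name). It creates an array view onto the data you passed to lsoda, so there is still no copy or allocation:

@nb.cfunc(lsoda_sig)
def f_p(t, u_, du, p):
    u = nb.carray(u_, (2,))
    a = nb.carray(p, (2,))  # view onto the array passed via data=, no copy
    du[0] = u[0] - u[0]*u[1]*a[0]
    du[1] = u[0]*u[1] - u[1]*a[1]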

Thanks for this clarification. I was under the impression that Numba's LLVM compiler would constant-fold the line a = np.array([1.5,1.5]) so that it would not need a new allocation each time the function is called. It seems I was mistaken.
The make_cfunc() approach above makes sense, and I will use it in my code.