Numba jit with scipy

9,985

Solution 1

Numba is simply not a general-purpose library for speeding up code. There is a class of problems that numba can solve much faster (especially loops over arrays and number crunching), but everything else is either (1) not supported or (2) only slightly faster, or even a lot slower.

[...] would it even speed up the code?

SciPy is already a high-performance library, so in most cases I would expect numba to perform worse (or, rarely, slightly better). You could do some profiling to find out whether the bottleneck really is in the code that you jitted; if it is, you could get some improvements. But I suspect the bottleneck will be in SciPy's compiled code, and that code is probably already heavily optimized (so it's very unlikely that you would find an implementation that could even just compete with it).
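As a minimal sketch of that profiling step, the standard library's cProfile can show whether the time goes into quad itself or into the Python callback it invokes. The workload below (integrating A*sin(B*x) 200 times) is a hypothetical stand-in for the question's code, not the asker's actual program.

```python
import cProfile
import io
import pstats

import numpy as np
from scipy import integrate


def fitfunction(x, A, B):
    # Plain NumPy model; quad calls it once per evaluation point
    return A * np.sin(B * x)


def workload():
    # Hypothetical workload shaped like the question: many quad calls
    uppers = np.linspace(np.pi, 2.0 * np.pi, 200)
    return [integrate.quad(fitfunction, 0.0, u, args=(1.0, 1.0))[0] for u in uppers]


profiler = cProfile.Profile()
profiler.enable()
results = workload()
profiler.disable()

# Sort by cumulative time to see how much of it is spent in the Python
# callback (fitfunction) versus quad's compiled routine
stream = io.StringIO()
stats = pstats.Stats(profiler, stream=stream).sort_stats("cumulative")
stats.print_stats(10)
print(stream.getvalue())
```

If most of the cumulative time sits in `fitfunction` rather than in quad, speeding up the model function is where effort pays off.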

Is there a way to use jit with scipy.integrate.quad and curve_fit without manually deleting all try except structures from the scipy code?

As you correctly assumed, try and except are simply not supported by numba at this time.

2.6.1. Language

2.6.1.1. Constructs

Numba strives to support as much of the Python language as possible, but some language features are not available inside Numba-compiled functions. The following Python language features are not currently supported:

[...]

  • Exception handling (try .. except, try .. finally)

So the answer here is No.

Solution 2

Nowadays, try and except work with numba. However, numba and SciPy are still not compatible: yes, SciPy calls compiled C and Fortran, but it does so in a way that numba can't deal with.

Fortunately, there are alternatives to SciPy that work well with numba! Below I use NumbaQuadpack and NumbaMinpack to do some curve fitting and integration similar to your example code. Disclaimer: I put together these packages. Below, I also give an equivalent implementation in SciPy.

The SciPy implementation is ~18 times slower than the numba-compatible alternatives (NumbaQuadpack and NumbaMinpack).

Using Scipy alternatives (0.23 ms)

from NumbaQuadpack import quadpack_sig, dqags
from NumbaMinpack import minpack_sig, lmdif
import numpy as np
import numba as nb
import timeit
np.random.seed(0)

x = np.linspace(0,2*np.pi,100)
y = np.sin(x)+ np.random.rand(100)

@nb.jit
def fitfunction(x, A, B):
    return A*np.sin(B*x)

@nb.cfunc(minpack_sig)
def fitfunction_optimize(u_, fvec, args_):
    u = nb.carray(u_,(2,))
    args = nb.carray(args_,(200,))
    A, B = u
    x = args[:100]
    y = args[100:]
    for i in range(100):
        fvec[i] = fitfunction(x[i], A, B) - y[i] 
optimize_ptr = fitfunction_optimize.address

@nb.cfunc(quadpack_sig)
def fitfunction_integrate(x, data):
    A = data[0]
    B = data[1]
    return fitfunction(x, A, B)
integrate_ptr = fitfunction_integrate.address

@nb.njit
def fast_function():  
    try:
        neqs = 100
        u_init = np.array([2.0,.8],np.float64)
        args = np.append(x,y)
        fitparam, fvec, success, info = lmdif(optimize_ptr , u_init, neqs, args)
        if not success:
            raise Exception

        lower = 0.0
        uppers = np.linspace(np.pi,np.pi*2.0,200)
        solutions = np.empty(len(uppers))
        for i in range(len(uppers)):
            solutions[i], abserr, success = dqags(integrate_ptr, lower, uppers[i], data = fitparam)
            if not success:
                raise Exception
    except:
        print('doing something else')
        
fast_function()
iters = 1000
t_nb = timeit.Timer(fast_function).timeit(number=iters)/iters
print(t_nb)

Using Scipy (4.4 ms)

import scipy.integrate as integrate
from scipy.optimize import curve_fit
import numpy as np
import numba as nb
import timeit

np.random.seed(0)

x = np.linspace(0,2*np.pi,100)
y = np.sin(x)+ np.random.rand(100)

@nb.jit
def fitfunction(x, A, B):
    return A*np.sin(B*x)

def function():
    try:
        p0 = (2.0,.8)
        fit_param, fit_cov = curve_fit(fitfunction, x, y, p0=p0, maxfev=500)

        lower = 0.0
        uppers = np.linspace(np.pi,np.pi*2.0,200)
        solutions = np.empty(len(uppers))
        for i in range(len(uppers)):
            solutions[i], abserr = integrate.quad(fitfunction, lower, uppers[i], args = tuple(fit_param))
    except:
        print('do something else')

function()
iters = 1000
t_sp = timeit.Timer(function).timeit(number=iters)/iters
print(t_sp)

Author: Katermickie

Updated on September 15, 2022

Comments

  • Katermickie
    Katermickie over 1 year

    So I wanted to speed up a program I wrote with the help of numba jit. However, jit does not seem to be compatible with many scipy functions because they use try ... except ... structures that jit cannot handle (am I right about this?)

    A relatively simple solution I came up with is to copy the scipy source code I need and delete the try/except parts (I already know that it will not run into errors, so the try part will always work anyway).

    However, I do not like this solution and I am not sure if it will work.

    My code structure looks like the following:

    import scipy.integrate as integrate
    from scipy.optimize import curve_fit
    from numba import jit
    
    def fitfunction():
        ...
    
    @jit
    def function(x):
        # do some stuff
        try:
            fit_param, fit_cov = curve_fit(fitfunction, x, y, p0=(0,0,0), maxfev=500)
            for idx in some_list:
                integrated = integrate.quad(lambda x: fitfunction(fit_param), lower, upper)
        except:
            fit_param=(0,0,0)
            ...
    

    Now this results in the following error:

    LoweringError: Failed at object (object mode backend)

    I assume this is due to jit not being able to handle try/except (it also does not work if I put jit only on wrapper functions around the curve_fit and integrate.quad calls and keep my own try/except outside of them):

    import scipy.integrate as integrate
    from scipy.optimize import curve_fit
    from numba import jit
    
    def fitfunction():
        ...
    
    @jit
    def integral(lower, upper):
        return integrate.quad(lambda x: fitfunction(fit_param), lower, upper)
    
    @jit
    def fitting(x, y, pzero, max_fev):
        return curve_fit(fitfunction, x, y, p0=pzero, maxfev=max_fev)
    
    
    def function(x):
        # do some stuff
        try:
            fit_param, fit_cov = fitting(x, y, (0,0,0), 500)
            for idx in some_list:
                integrated = integral(lower, upper)
        except:
            fit_param=(0,0,0)
            ...
    

    Is there a way to use jit with scipy.integrate.quad and curve_fit without manually deleting all try except structures from the scipy code?

    And would it even speed up the code?

    • hpaulj
      hpaulj about 5 years
      Instead of trying to jit the scipy functions, why don't you focus on speeding up your own function, fitfunction? That's the one that quad and curve_fit call repeatedly. quad already uses compiled code, in the _quadpack module.
  • Katermickie
    Katermickie almost 3 years
    Nice! I will take a closer look
  • Nick Alger
    Nick Alger about 2 years
    The problem is that often one wishes to use scipy within a numba function. For example, often one has a loop filled with numerical calculations that numba is good at speeding up. But if one of your numerical calculations calls a scipy function (a common situation), now numba can't be used.
  • MSeifert
    MSeifert about 2 years
    @NickAlger Yeah, definitely inconvenient. I'm not saying that numba is (or was) great the way it is. However, it's a limitation, and often one that is easy to work around. E.g. there were only a few cases where I couldn't move the external function call outside of the numba function. Also, it's often not even beneficial to use numba anymore if you use external function calls, because the speedup from the numba loop and numerical calculations was simply insignificant compared to the time spent in the external functions.
  • MSeifert
    MSeifert about 2 years
    But it really depends on your specific case. It's just not possible to give a generic answer to that. Maybe give Cython a try if you require fast calculations and external function calls?