weighted quadrature
I use weighted quadrature in my research, and I was thinking of using quadax to replace scipy.integrate.quad in my work. I'd be interested in contributing a feature that handles weighted integration. Would you be open to a PR like that?
Yes definitely! I haven't had the time to dig into how it's done exactly but if you want to take a stab at it that would certainly be welcome.
Okay, I've been looking at weighted quadrature schemes and how they are implemented in quadpack. In quadpack, it looks like each type of weight function has its own adaptive interval-bisection scheme (i.e. a different fixed-order rule is called depending on whether the new bisected interval contains the singularity, and in some cases Clenshaw-Curtis and Gauss-Kronrod are used on different subregions). This is in contrast with this library, which has a single adaptive_quadrature function that handles the adaptive part of every quadrature method, given one rule. I've been trying to see how I can write a weighted quadrature method without having to write a specialized adaptive_* routine like it's done in quadpack, so it has been taking a while to get this up and running. The good news is that I think I have a general understanding of how some of the weighting schemes in quadpack work; it has just been a matter of translating this to work with quadax.
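Okay, here's my tentative plan for the algebraic-logarithmic weight function specifically.
Let's say we want to integrate $\int_a^b f(x)\,w(x)\,dx$ with $w(x) = (x-a)^\alpha (b-x)^\beta v(x)$, where $v(x)$ is $1$, $\log(x-a)$, $\log(b-x)$, or $\log(x-a)\log(b-x)$ (the quadpack/scipy forms). On an interval with a singular endpoint, the modified Clenshaw-Curtis (CC) method handles the singular factor through precomputed modified Chebyshev moments.
However, in the h-adaptive scheme, when we bisect the interval, only the subintervals that touch $a$ or $b$ still contain the singularity; on the others $w$ is smooth, and the basic CC rule applied to $f w$ is enough.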
So we need a rule that does either the modified CC method or the basic CC method, depending on the interval it's called on. This is what I envision for the adaptive method:
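```python
import functools

import jax
import jax.numpy as jnp

# adaptive_quadrature, fixed_quadcc and QuadratureInfo come from quadax;
# fixed_quadcc_alglogweight and alglogweightfn are the planned helpers (not written yet).


def quadcc_alglogweight(fun, interval, args=(), weightargs=None, full_output=False,
                        epsabs=None, epsrel=None, max_ninter=50, order=32, norm=jnp.inf):
    # compute modified Chebyshev moments based on weightargs
    chebmom = ...  # TODO: do something here

    def weightrule(fun, a, b, args, norm, n):
        return fixed_quadcc_alglogweight(
            fun, a, b, args, norm, n, weightargs=weightargs, chebmom=chebmom,
        )

    def defaultrule(fun, a, b, args, norm, n):
        # use a new name: rebinding `fun` here would make the lambda call itself
        weighted = lambda x, *args: fun(x, *args) * alglogweightfn(x, **weightargs)
        return fixed_quadcc(weighted, a, b, args, norm, n)

    @functools.partial(jax.jit, static_argnums=(0, 4, 5))
    def rule(fun, a, b, args, norm, n):
        # rule switches depending on interval;
        # `s == a or b` parses as `(s == a) or b`, so compare with each endpoint
        s = weightargs["singularity"]
        touches = (a == s) | (b == s)
        # lax.cond operands must be arrays/pytrees, so close over fun, norm and n
        return jax.lax.cond(
            touches,
            lambda a, b, args: weightrule(fun, a, b, args, norm, n),
            lambda a, b, args: defaultrule(fun, a, b, args, norm, n),
            a, b, args,
        )

    y, info = adaptive_quadrature(
        rule, fun, interval, args, full_output, epsabs, epsrel, max_ninter,
        n=order, norm=norm,
    )
    info = QuadratureInfo(info.err, info.neval * order, info.status, info.info)
    return y, info
```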
I'm not sure if this will work or be jitted properly, but I tried to avoid writing my own adaptive_quadrature routine. If you see any obvious limitations to this idea I'd love to hear them; otherwise I'll try writing it up. If that fails, I'll probably have to write a specific adaptive_quadrature to handle weighted integration.
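The plan above leaves fixed_quadcc_alglogweight unspecified. As a rough illustration of what a modified CC rule does, here is a NumPy sketch for the one-sided weight $(x-a)^\alpha$ only: interpolate $f$ at the Clenshaw-Curtis nodes, take its Chebyshev coefficients $c_k$, and form $((b-a)/2)^{\alpha+1}\sum''_{k=0}^{n} c_k M_k$ with modified moments $M_k = \int_{-1}^{1}(1+t)^\alpha T_k(t)\,dt$, instead of sampling $f w$ directly. The names alg_moments and modified_cc are hypothetical, the log factors and the $(b-x)^\beta$ end are omitted, and the moments come from a forward recurrence derived for this sketch (adequate for modest n), not from quadpack's own moment routine.

```python
import numpy as np


def alg_moments(alpha, n):
    """Modified Chebyshev moments M_k = int_{-1}^{1} (1 + t)**alpha * T_k(t) dt, k = 0..n.

    Forward three-term recurrence, valid for alpha > -1 (sketch: fine for modest n).
    """
    two = 2.0 ** (alpha + 1)
    m = np.zeros(n + 1)
    m[0] = two / (alpha + 1)
    if n >= 1:
        m[1] = 2 * two / (alpha + 2) - two / (alpha + 1)
    if n >= 2:
        m[2] = 8 * two / (alpha + 3) - 8 * two / (alpha + 2) + two / (alpha + 1)
    for k in range(2, n):
        m[k + 1] = (
            -2 * two / (k - 1)
            + m[k - 1] * (alpha + 2 - k) * (k + 1) / (k - 1)
            - 2 * (k + 1) * m[k]
        ) / (k + alpha + 2)
    return m


def modified_cc(f, a, b, alpha, n=16):
    """Approximate int_a^b (x - a)**alpha * f(x) dx with an (n+1)-point modified CC rule."""
    j = np.arange(n + 1)
    t = np.cos(j * np.pi / n)            # Chebyshev extrema (CC nodes) on [-1, 1]
    x = a + 0.5 * (b - a) * (t + 1)      # nodes mapped to [a, b]
    fx = np.asarray(f(x), dtype=float) * np.ones(n + 1)
    fx[[0, -1]] *= 0.5                   # halve the first and last samples
    c = (2.0 / n) * np.cos(np.outer(j, j) * np.pi / n) @ fx  # Chebyshev coefficients
    c[[0, -1]] *= 0.5                    # halve c_0 and c_n (double-prime sum)
    # (x - a)**alpha dx = ((b - a) / 2)**(alpha + 1) * (1 + t)**alpha dt
    return (0.5 * (b - a)) ** (alpha + 1) * float(c @ alg_moments(alpha, n))
```

For example, modified_cc(np.exp, 0.0, 2.0, -0.5) approximates $\int_0^2 x^{-1/2} e^x\,dx$ using only smooth samples of $e^x$; all the singular behaviour sits in the moments, which can be computed once per weight, much like chebmom in the plan.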
I think the basic approach makes sense, but a few comments:
- Instead of defining some custom rule, it might be cleaner to just specify the default rule for non-singular intervals, and then have a specialized rule for each weight function? Then you can put the switching logic here:
Lines 488 to 492 in 640afd5
```python
if a1 <= weight_func.singularity <= b1:
    ...  # use special rule
else:
    ...  # use regular rule
```
and same for a2, b2 etc.
- Similarly, it's probably cleanest to define a class for each weight function that also includes info about the specialized quadrature rule? (I was never crazy about how it was done in scipy with strings and arguments that mean different things for different weights). Something like
```python
import equinox as eqx


class AlgLogWeight(eqx.Module):  # equinox.Module is basically a callable pytree
    c1: float
    c2: float
    alpha: float
    beta: float

    def evaluate(self, x):
        """evaluate the weight function at a point x"""
        return (x - self.c1)**self.alpha * (self.c2 - x)**self.beta * ...

    def quadrature(self, fun, a, b):
        """Integrate fun*weight from a to b"""
        # or maybe just have it return modified nodes and weights?
        ...
```
Though I'm not sure whether there are multiple possible rules for a given weight function, or whether there is a standard rule that everyone uses?
Does your first point involve modifying the existing adaptive_quadrature method or writing a similar but new method for weighted functions? Edit: My concern is that different weight functions will have different switching conditions, e.g. for the algebraic-logarithmic weight function, we need to check whether there is a singularity at the endpoints:
```python
if a1 == weight_func.singularity1:
    ...  # use first special rule
elif b1 == weight_func.singularity2:
    ...  # use second special rule
else:
    ...  # use regular rule
```
Or for the Cauchy weight function, we need to check whether the singularity is contained within the interval:
```python
if a1 < weight_func.singularity < b1:
    ...  # use special rule
else:
    ...  # use regular rule
```
I'm not sure how to make a general adaptive_quadrature method that can decide which switching condition to use based on the weight function we are considering. On the other hand, copy-pasting the current adaptive_quadrature method and slightly modifying parts of it for specific weight functions doesn't seem like a great option to me. Thoughts? End edit.
I really like your second point of using a class instead of passing around a dict of weight function parameters etc.! I definitely think that's the way to go.
For your last question: for the algebraic-logarithmic weight function, quadpack/scipy (which I'm treating as the standard) uses slightly different rules depending on whether the singularity is at a or b (not both; in practice you just split the interval in half so each half has only one singular endpoint). I haven't looked much into the other weight functions (Cauchy or sin/cos weights) yet.
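The split-in-half step described above can be sketched in plain Python, assuming the scipy/quadpack definitions of the four algebraic-logarithmic weights. alglog_weight and split_at_midpoint are hypothetical helpers: they only evaluate the weight and build the two one-sided subproblems, where the factor that is smooth on each half is folded into the integrand; the special rule applied to each half is not shown.

```python
import math


def alglog_weight(x, a, b, alpha, beta, log_a=False, log_b=False):
    """w(x) = (x - a)**alpha * (b - x)**beta * v(x) for a < x < b.

    log_a / log_b switch on the log(x - a) / log(b - x) factors, mirroring
    scipy's 'alg', 'alg-loga', 'alg-logb' and 'alg-log' weights.
    """
    w = (x - a) ** alpha * (b - x) ** beta
    if log_a:
        w *= math.log(x - a)
    if log_b:
        w *= math.log(b - x)
    return w


def split_at_midpoint(a, b, alpha, beta, log_a=False, log_b=False):
    """Split [a, b] so that each half has a single singular endpoint.

    Each piece is (lo, hi, singular_end, smooth_factor); smooth_factor is the
    part of w that is smooth on that piece and can be multiplied into f.
    """
    m = 0.5 * (a + b)

    def left_smooth(x):  # (b - x)**beta [* log(b - x)] is smooth on [a, m]
        v = (b - x) ** beta
        return v * math.log(b - x) if log_b else v

    def right_smooth(x):  # (x - a)**alpha [* log(x - a)] is smooth on [m, b]
        v = (x - a) ** alpha
        return v * math.log(x - a) if log_a else v

    return (a, m, "a", left_smooth), (m, b, "b", right_smooth)
```

On the left half the special rule then only has to deal with $(x-a)^\alpha$ (times $\log(x-a)$ if present), and on the right half only with $(b-x)^\beta$, which matches the "one singular endpoint per interval" structure.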
Ah I see what you mean. What if we made the check a method of the weight function?
Like:
```python
if weight_function is not None and weight_function.special_interval(a, b):
    foo = weight_function.integrate(a, b)
else:
    foo = base_rule.integrate(a, b)
```
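A minimal sketch of how a special_interval method lets each weight carry its own switching condition, so a generic bisection loop never needs to know which weight it is handling. Plain dataclasses stand in for the eqx.Module classes suggested earlier, and dispatch is a hypothetical stand-in for the bisection in adaptive_quadrature, not its actual code. Inside jit, the Python `and`/`or` comparisons would have to become `&`/`|` (or jnp.logical_and/logical_or) because the bounds are traced.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AlgLogWeight:
    """Hypothetical: w(x) = (x - a)**alpha * (b - x)**beta * v(x) on the original [a, b]."""
    a: float
    b: float
    alpha: float
    beta: float
    log_a: bool = False
    log_b: bool = False

    def special_interval(self, lo, hi):
        # Only the original endpoints are singular, and only if they carry a singular factor.
        left = lo == self.a and (self.alpha != 0 or self.log_a)
        right = hi == self.b and (self.beta != 0 or self.log_b)
        return left or right


@dataclass(frozen=True)
class CauchyWeight:
    """Hypothetical: w(x) = 1 / (x - c)."""
    c: float

    def special_interval(self, lo, hi):
        # The pole has to lie inside the subinterval.
        return lo < self.c < hi


def dispatch(weight, lo, hi, depth):
    """Bisect [lo, hi] `depth` times and report which rule each piece would use."""
    pieces = [(lo, hi)]
    for _ in range(depth):
        pieces = [half for (l, h) in pieces
                  for half in ((l, 0.5 * (l + h)), (0.5 * (l + h), h))]
    return ["special" if weight is not None and weight.special_interval(l, h) else "base"
            for l, h in pieces]
```

With AlgLogWeight(0.0, 1.0, -0.5, -0.5) and two bisections, only the outermost quarters are routed to the special rule, while CauchyWeight(0.3) routes only the quarter containing 0.3; the loop itself is identical in both cases.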
I started working up some classes for the basic integration rules in #6; you should be able to use something similar for the weight functions.
Good idea! I'll start building off of the classes you've introduced.
Sorry for the delay, just wanted to give a short update: I have a version of the code that implements weighted quadrature (for just one specific weighting function for now), and next I want to implement some test integrals to make sure it works as expected.
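For test integrals, a few algebraic-logarithmic cases on [0, 1] have closed forms and could serve as references. The sketch below (alglog_test_cases is a hypothetical helper) lists the smooth factor f, the weight parameters, and the exact value; it uses $\int_0^1 x^\alpha\,dx = 1/(\alpha+1)$, $\int_0^1 x^\alpha \log x\,dx = -1/(\alpha+1)^2$, $\int_0^1 x \cdot x^\alpha\,dx = 1/(\alpha+2)$, and the Beta function $B(\alpha+1, \beta+1)$ for both endpoints singular.

```python
import math


def alglog_test_cases():
    """Closed-form references for int_0^1 f(x) * x**alpha * (1 - x)**beta [* log(x)] dx."""
    cases = []
    for alpha in (-0.5, 0.0, 1.5):
        # int_0^1 x**alpha dx = 1 / (alpha + 1)
        cases.append(dict(name=f"x^{alpha}", f=lambda x: 1.0, alpha=alpha, beta=0.0,
                          log_a=False, expected=1.0 / (alpha + 1)))
        # int_0^1 x**alpha * log(x) dx = -1 / (alpha + 1)**2
        cases.append(dict(name=f"x^{alpha} log x", f=lambda x: 1.0, alpha=alpha, beta=0.0,
                          log_a=True, expected=-1.0 / (alpha + 1) ** 2))
        # int_0^1 x * x**alpha dx = 1 / (alpha + 2), a non-constant f
        cases.append(dict(name=f"x * x^{alpha}", f=lambda x: x, alpha=alpha, beta=0.0,
                          log_a=False, expected=1.0 / (alpha + 2)))
    for alpha, beta in ((-0.5, -0.5), (-0.25, 0.5)):
        # int_0^1 x**alpha * (1 - x)**beta dx = B(alpha + 1, beta + 1)
        beta_fn = math.gamma(alpha + 1) * math.gamma(beta + 1) / math.gamma(alpha + beta + 2)
        cases.append(dict(name=f"x^{alpha} (1-x)^{beta}", f=lambda x: 1.0, alpha=alpha,
                          beta=beta, log_a=False, expected=beta_fn))
    return cases
```

Each case can be run through the new weighted routine and, as an independent cross-check, through scipy.integrate.quad(f, 0, 1, weight='alg-loga', wvar=(alpha, beta)) (or weight='alg' when log_a is False).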