Question regarding the data for the log-likelihood
Qazalbash opened this issue · 3 comments
Description
Hi, I am trying to run flowMC to sample from the inhomogeneous Poisson likelihood mentioned in equation 4 of the paper. I am using the following Monte-Carlo approximation for the integral.
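Schematically (in my own notation here), with $N_i$ samples $\lambda_i^{(k)}$ drawn from a distribution $p_i$ for event $i$, the integral is estimated by the usual Monte-Carlo average

$$\int f(\lambda)\,p_i(\lambda)\,\mathrm{d}\lambda \;\approx\; \frac{1}{N_i}\sum_{k=1}^{N_i} f\big(\lambda_i^{(k)}\big).$$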
You can notice that we have a different number of samples for each iteration, and JAX does not support ragged arrays where each sub-array has a different size. My next natural choice was therefore `jax.tree_map`, but you have vectorized `log_pdf` (which in my case is the log-likelihood) with `vmap`, so that does not work here.
Is there any other way to pass those pre-computed values?
I am attaching the code down here.
from __future__ import annotations
import jax
from jax import jit
from jax import numpy as jnp
from numpyro import distributions as dist
from gwkokab.vts.utils import interpolate_hdf5
from ..models import *
from ..utils.misc import get_key

@jit
def exp_rate(rate, *, pop_params) -> float:
    # Monte-Carlo average of exp(interpolate_hdf5(m1, m2)) over N draws from the
    # mass model, scaled by the rate.
    N = 1 << 13  # 2 ** 13
    lambdas = Wysocki2019MassModel(
        alpha_m=pop_params["alpha_m"],
        k=0,
        mmin=pop_params["mmin"],
        mmax=pop_params["mmax"],
    ).sample(get_key(), sample_shape=(N,))
    I = 0
    m1 = lambdas[..., 0]
    m2 = lambdas[..., 1]
    value = jnp.exp(interpolate_hdf5(m1, m2))
    F = jnp.sum(value)
    I_current = F / N
    I += rate * I_current
    return I

def log_inhomogeneous_poisson_likelihood(x, data=None):
    alpha = x[..., 0]
    mmin = x[..., 1]
    mmax = x[..., 2]
    rate = x[..., 3]
    alpha_prior = dist.LogUniform(-5.0, 5.0)
    mmin_prior = dist.LogUniform(5.0, 30.0)
    mmax_prior = dist.LogUniform(30.0, 100.0)
    # rate_prior = dist.Uniform(1, 500)
    expval = exp_rate(
        rate,
        pop_params={
            "alpha_m": alpha,
            "mmin": mmin,
            "mmax": mmax,
        },
    )
    mass_model = Wysocki2019MassModel(
        alpha_m=alpha,
        k=0,
        mmin=mmin,
        mmax=mmax,
    )
    log_integral = jnp.sum(
        jnp.asarray(
            jax.tree_map(
                lambda d: jax.nn.logsumexp(
                    mass_model.log_prob(d)
                    - alpha_prior.log_prob(alpha)
                    - mmin_prior.log_prob(mmin)
                    - mmax_prior.log_prob(mmax),
                )
                - jnp.log(len(d))
                + jnp.log(rate),
                data,
            )
        )
    )
    return log_integral - expval
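In other words, `exp_rate` returns the Monte-Carlo estimate (writing $\mu$ for its return value, $N = 2^{13}$ for the number of draws from `Wysocki2019MassModel`, and $w^{(i)}$ for `exp(interpolate_hdf5(m1, m2))` at the $i$-th draw)

$$\mu \;=\; \frac{\mathcal{R}}{N}\sum_{i=1}^{N} w^{(i)},$$

and `log_inhomogeneous_poisson_likelihood` then evaluates, with $\lambda_i^{(k)}$ the $N_i$ pre-computed samples of event $i$ and $\pi(\cdot)$ the LogUniform priors above,

$$\log\mathcal{L} \;=\; \sum_i \Big[\log\mathcal{R} - \log N_i + \operatorname{logsumexp}_k\big(\log p(\lambda_i^{(k)} \mid \alpha, m_\text{min}, m_\text{max}) - \log\pi(\alpha) - \log\pi(m_\text{min}) - \log\pi(m_\text{max})\big)\Big] - \mu.$$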
`Wysocki2019MassModel` is equation 7 of the same paper mentioned above. Also, I have a feeling that there is something wrong with the likelihood function itself, but I cannot pinpoint it.
@Qazalbash Regarding the sub-arrays being different sizes, you mean the number of samples per event, `d`, being different from event to event, right? Usually, the easiest way is to upsample or downsample each event to a common size; then one should be able to `vmap` over it without problems.
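Something along these lines, for example (just a sketch; `resample_to_common_size` and `n_common` are names I am making up here, and drawing with replacement means smaller events can also be upsampled):

```python
import jax
import jax.numpy as jnp


def resample_to_common_size(event_samples, n_common, key):
    """Draw n_common indices (with replacement) from each event's (N_i, 2) sample
    array and stack the results into one (n_events, n_common, 2) array."""
    keys = jax.random.split(key, len(event_samples))
    resampled = [
        d[jax.random.choice(k, d.shape[0], shape=(n_common,), replace=True)]
        for d, k in zip(event_samples, keys)
    ]
    return jnp.stack(resampled)


# Once every event has the same size, the per-event term can be vmapped:
# per_event_logL = jax.vmap(single_event_log_term)(stacked_samples)
```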
Regarding your likelihood, a number of things:

- It seems your `exp_rate` and likelihood function are constructing the `Wysocki2019MassModel` at every call, which I am not sure is intended. Depending on how the `Wysocki2019MassModel` is written, this can massively slow the computation down. The same goes for the priors. I would initialize those objects outside the likelihood and pass them in as data.
- The function signature of the likelihood should be `f(x: array, data: dict)`, where `x` is a 1D array of the parameters, meaning it should be something like `x = jnp.array([alpha, mmin, mmax, rate])` (see the sketch below).
- The `sample` in the mass model seems to be a stochastic process to me, and I am not sure you need it. I assume the lambdas are the hyperparameters, and the sampling is for computing the rate. In that case, `x` in the likelihood is your lambda, so you don't have to resample it within the likelihood.
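To make the first two points concrete, here is a minimal sketch. The import path, the `"samples"` key, and dropping the `exp_rate`/selection term are only for illustration; the point is the `f(x, data)` signature and building the priors once, outside the likelihood:

```python
import jax
import jax.numpy as jnp
from numpyro import distributions as dist

from gwkokab.models import Wysocki2019MassModel  # adjust to the actual import path

# Priors built once, outside the likelihood, instead of at every call.
ALPHA_PRIOR = dist.LogUniform(-5.0, 5.0)
MMIN_PRIOR = dist.LogUniform(5.0, 30.0)
MMAX_PRIOR = dist.LogUniform(30.0, 100.0)


def log_likelihood(x, data):
    """x = jnp.array([alpha, mmin, mmax, rate]); data is a dict of pre-computed arrays."""
    alpha, mmin, mmax, rate = x[0], x[1], x[2], x[3]
    samples = data["samples"]  # shape (n_events, n_common, 2), already resampled
    model = Wysocki2019MassModel(alpha_m=alpha, k=0, mmin=mmin, mmax=mmax)
    # One Monte-Carlo term per event, vmapped over the leading (event) axis.
    per_event = jax.vmap(
        lambda d: jax.nn.logsumexp(model.log_prob(d)) - jnp.log(d.shape[0])
    )(samples)
    # ALPHA_PRIOR / MMIN_PRIOR / MMAX_PRIOR can be used here wherever the priors
    # enter your posterior, without re-creating them on every call.
    return jnp.sum(per_event + jnp.log(rate))
```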
@kazewong Thank you for the points. I have a few uncertainties about certain aspects of the process, which I've outlined below. I hope this doesn't inconvenience you.

- The reason for constructing `Wysocki2019MassModel` inside `log_inhomogeneous_poisson_likelihood` is that, as I understand it, we are running flowMC to recover the parameters $\alpha$, $m_\text{min}$, $m_\text{max}$, $\mathcal{R}$. At each iteration these values change (we pass them in as `x` to `log_inhomogeneous_poisson_likelihood`), so the model should also be updated with these parameters. This is my understanding, and I have strong doubts about it.
- I am following one of the tutorials in the docs, and there the data is passed as a `jnp.array`. I also cannot grasp how we would pass different data to the sampler in the form of a `dict`.
This is the `Wysocki2019MassModel`:
# Copyright 2023 The GWKokab Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing_extensions import Optional
from jax import lax, numpy as jnp
from jax.random import uniform
from jaxtyping import Array
from numpyro.distributions import constraints, Distribution
from numpyro.distributions.util import promote_shapes, validate_sample
from ..utils.misc import get_key

class Wysocki2019MassModel(Distribution):
    r"""It is a double-sided truncated power law distribution, as
    described in equation 7 of the `paper <https://arxiv.org/abs/1805.06442>`__.

    .. math::
        p(m_1,m_2\mid\alpha,k,m_{\text{min}},m_{\text{max}},M_{\text{max}})\propto\frac{m_1^{-\alpha-k}m_2^k}{m_1-m_{\text{min}}}
    """

    arg_constraints = {
        "alpha_m": constraints.real,
        "k": constraints.nonnegative_integer,
        "mmin": constraints.positive,
        "mmax": constraints.positive,
    }
    support = constraints.real_vector
    reparametrized_params = ["m1", "m2"]

    def __init__(self, alpha_m: float, k: int, mmin: float, mmax: float, *, valid_args=None) -> None:
        r"""Initialize the power law distribution with a lower and upper mass limit.

        :param alpha_m: index of the power law distribution
        :param k: mass ratio power law index
        :param mmin: lower mass limit
        :param mmax: upper mass limit
        :param valid_args: If `True`, validate the input arguments.
        """
        self.alpha_m, self.k, self.mmin, self.mmax = promote_shapes(alpha_m, k, mmin, mmax)
        batch_shape = lax.broadcast_shapes(
            jnp.shape(alpha_m),
            jnp.shape(k),
            jnp.shape(mmin),
            jnp.shape(mmax),
        )
        super(Wysocki2019MassModel, self).__init__(
            batch_shape=batch_shape,
            validate_args=valid_args,
            event_shape=(2,),
        )

    @validate_sample
    def log_prob(self, value):
        return (
            -(self.alpha_m + self.k) * jnp.log(value[..., 0])
            + self.k * jnp.log(value[..., 1])
            - jnp.log(value[..., 0] - self.mmin)
        )

    def sample(self, key: Optional[Array | int], sample_shape: tuple = ()) -> Array:
        if key is None or isinstance(key, int):
            key = get_key(key)
        m2 = uniform(key=key, minval=self.mmin, maxval=self.mmax, shape=sample_shape + self.batch_shape)
        U = uniform(key=get_key(key), minval=0.0, maxval=1.0, shape=sample_shape + self.batch_shape)
        beta = 1 - (self.k + self.alpha_m)
        conditions = [beta == 0.0, beta != 0.0]
        choices = [
            jnp.exp(U * jnp.log(self.mmax) + (1.0 - U) * jnp.log(m2)),
            jnp.exp(jnp.power(beta, -1.0) * jnp.log(U * jnp.power(self.mmax, beta) + (1.0 - U) * jnp.power(m2, beta))),
        ]
        m1 = jnp.select(conditions, choices)
        return jnp.stack([m1, m2], axis=-1)

    def __repr__(self) -> str:
        string = f"Wysocki2019MassModel(alpha_m={self.alpha_m}, k={self.k}, "
        string += f"mmin={self.mmin}, mmax={self.mmax})"
        return string
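For context, `sample` first draws $m_2$ uniformly in $[m_\text{min}, m_\text{max}]$ and then draws $m_1$ by inverse-transform sampling of the power law $m_1^{-(\alpha_m + k)}$ on $[m_2, m_\text{max}]$: with $\beta = 1 - (\alpha_m + k)$ and $U \sim \mathrm{Uniform}(0, 1)$,

$$m_1 \;=\; \begin{cases}\big(U\,m_\text{max}^{\beta} + (1 - U)\,m_2^{\beta}\big)^{1/\beta}, & \beta \neq 0,\\ \exp\!\big(U \log m_\text{max} + (1 - U)\log m_2\big), & \beta = 0,\end{cases}$$

which is what the `jnp.select` above chooses between.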
I have pretty much got the idea of where I am wrong; your response will make it clearer. Thanks in advance.
I think this is resolved. Closing the issue.