Qazalbash/jaxampler

Accept-reject sampler not working for multivariate PDFs

Qazalbash opened this issue · 0 comments

Running the following script,

from jaxampler.sampler import AcceptRejectSampler
from jaxtro.models import Wysocki2019MassModel
from matplotlib import pyplot as plt

model = Wysocki2019MassModel(alpha=0.8, k=0, mmin=5.0, mmax=40.0, Mmax=80.0, name="Wysocki2019MassModel")
sampler = AcceptRejectSampler()

samples = sampler.sample(target_rv=model, proposal_rv=model, scale=1.05, N=1000)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(samples[:, 0], samples[:, 1], samples[:, 2])
plt.show()

we get this error:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/gradf/Desktop/test/test_sampler.py", line 8, in <module>
    samples = sampler.sample(target_rv=model, proposal_rv=model, scale=1.05, N=1000)
  File "/home/gradf/.local/lib/python3.10/site-packages/jaxampler/sampler/arsampler.py", line 43, in sample
    pdf_ratio = target_rv.pdf(V)
  File "/home/gradf/.local/lib/python3.10/site-packages/jaxampler/rvs/crvs/crvs.py", line 38, in pdf
    return jnp.exp(self.logpdf(*x))
TypeError: Wysocki2019MassModel.logpdf() missing 1 required positional argument: 'm2'
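
Reading the traceback, `sample` calls `target_rv.pdf(V)` with a single argument, and `pdf` forwards it as `self.logpdf(*x)`, so a two-argument `logpdf(m1, m2)` only ever receives one positional argument. Below is a minimal sketch of that mechanism in plain Python. `ToyTwoMassModel` and its density are hypothetical stand-ins, not the jaxampler or jaxtro classes, and treating `V` as a single two-coordinate sample is an assumption.

```python
import math


class ToyTwoMassModel:
    """Hypothetical stand-in for a bivariate model whose logpdf takes (m1, m2)."""

    def logpdf(self, m1, m2):
        # Illustrative density only: two independent standard normals.
        return -0.5 * (m1**2 + m2**2) - math.log(2.0 * math.pi)

    def pdf(self, *x):
        # Same forwarding pattern as crvs.py line 38 in the traceback.
        return math.exp(self.logpdf(*x))


model = ToyTwoMassModel()
V = (0.3, -1.2)  # one proposed sample with two coordinates (assumed layout)

# Passing V as one argument reproduces the failure: logpdf receives only m1.
try:
    model.pdf(V)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# Unpacking the coordinates at the call site gives logpdf both m1 and m2.
value = model.pdf(*V)
```

If the same pattern holds in the sampler, unpacking the coordinates of `V` before the `pdf` call (or having multivariate models accept a single array argument) might avoid the `TypeError`; which of the two fits the library better is a design choice this sketch does not settle.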