Accept-reject sampler not working for multivariate PDFs
Qazalbash opened this issue
For the following code:
```python
from jaxampler.sampler import AcceptRejectSampler
from jaxtro.models import Wysocki2019MassModel
from matplotlib import pyplot as plt

model = Wysocki2019MassModel(alpha=0.8, k=0, mmin=5.0, mmax=40.0, Mmax=80.0, name="Wysocki2019MassModel")

# Draw 1000 samples, using the same model as both target and proposal
sampler = AcceptRejectSampler()
samples = sampler.sample(target_rv=model, proposal_rv=model, scale=1.05, N=1000)

# Plot the samples in 3D
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(samples[:, 0], samples[:, 1], samples[:, 2])
plt.show()
```
we get this error:
```
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/gradf/Desktop/test/test_sampler.py", line 8, in <module>
    samples = sampler.sample(target_rv=model, proposal_rv=model, scale=1.05, N=1000)
  File "/home/gradf/.local/lib/python3.10/site-packages/jaxampler/sampler/arsampler.py", line 43, in sample
    pdf_ratio = target_rv.pdf(V)
  File "/home/gradf/.local/lib/python3.10/site-packages/jaxampler/rvs/crvs/crvs.py", line 38, in pdf
    return jnp.exp(self.logpdf(*x))
TypeError: Wysocki2019MassModel.logpdf() missing 1 required positional argument: 'm2'
```
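Judging from the traceback, `pdf` likely collects its input as `*x` and forwards it to `logpdf(*x)`. When the sampler calls `target_rv.pdf(V)` with all candidates stacked in one array `V`, `logpdf` receives that whole array as `m1` and nothing for `m2`. The sketch below reproduces that failure mode in NumPy and shows one possible fix: split the `(N, d)` candidate array into its columns before calling a `logpdf` that takes one argument per dimension. The toy density `p(x, y) = x + y` on the unit square, the uniform proposal, and the names `logpdf`, `pdf_buggy`, `pdf_fixed`, and `accept_reject` are hypothetical illustrations, not jaxampler or jaxtro code; the actual shape of `V` inside `AcceptRejectSampler` is also an assumption.

```python
import numpy as np


def logpdf(x, y):
    # Hypothetical 2-D target on the unit square, p(x, y) = x + y (integrates to 1).
    # Like Wysocki2019MassModel.logpdf, it takes one argument per dimension.
    return np.log(x + y)


def pdf_buggy(*x):
    # Mirrors the failing pattern: pdf(V) with V of shape (N, 2) gives x == (V,),
    # so logpdf receives a single argument and 'y' is missing.
    return np.exp(logpdf(*x))


def pdf_fixed(V):
    # Split the (N, d) array into d column arrays, one per logpdf argument.
    return np.exp(logpdf(*V.T))


def accept_reject(target_pdf, proposal_sample, proposal_pdf, scale, N, rng):
    # Accept candidate v when u * scale * q(v) <= p(v), with u ~ Uniform(0, 1).
    # Assumes p(v) <= scale * q(v) everywhere.
    batches, n_accepted = [], 0
    while n_accepted < N:
        V = proposal_sample(N, rng)  # (N, d) candidates
        u = rng.uniform(size=N)
        keep = u * scale * proposal_pdf(V) <= target_pdf(V)
        batches.append(V[keep])
        n_accepted += int(keep.sum())
    return np.concatenate(batches)[:N]


rng = np.random.default_rng(0)

# Same failure mode as the report: logpdf gets one argument instead of two.
try:
    pdf_buggy(rng.uniform(size=(5, 2)))
except TypeError as err:
    print("buggy pdf:", err)

# Uniform proposal on the unit square: q(v) = 1 and max p = 2, so scale = 2.
samples = accept_reject(
    target_pdf=pdf_fixed,
    proposal_sample=lambda n, gen: gen.uniform(size=(n, 2)),
    proposal_pdf=lambda V: np.ones(len(V)),
    scale=2.0,
    N=1000,
    rng=rng,
)
print(samples.shape)  # (1000, 2)
```

The column split (`*V.T`) applies to any `pdf` wrapper that forwards to a per-dimension `logpdf`; the alternative is to have `logpdf` accept a single `(N, d)` array. Which convention jaxampler intends is not stated in this report.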