Qazalbash/jaxampler

Add `verbose`

Qazalbash opened this issue

Description

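Many frameworks offer a `verbose` option so users can see what is going on during long-running operations. We should add an optional keyword argument `verbose=False` to specific functions, such as `MetropolisHastingSampler.sample`, so that progress (for example burn-in and per-chain progress bars) is reported only when the caller asks for it.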
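A minimal sketch of the proposed pattern follows. It is written in plain Python so it runs without jaxampler or JAX. The names `sample_two_peak`, `two_peak_pdf`, `stream`, and `report_every` are hypothetical and are not jaxampler API. The sketch is a random-walk Metropolis sampler for the same two-peak density as `TwoPeakNormal` in the example below. It writes progress lines only when `verbose=True`, so the default call stays silent and returns exactly the same samples.

```python
import math
import random
import sys
from typing import List, Optional, TextIO


def two_peak_pdf(x: float) -> float:
    """Equal-weight mixture of N(-2, 1) and N(2, 1), matching TwoPeakNormal."""
    c = 1.0 / math.sqrt(2.0 * math.pi)
    return 0.5 * c * (math.exp(-0.5 * (x + 2.0) ** 2) + math.exp(-0.5 * (x - 2.0) ** 2))


def sample_two_peak(
    n: int,
    burn_in: int,
    sigma: float = 0.4,
    x0: float = 0.0,
    seed: Optional[int] = None,
    verbose: bool = False,
    stream: TextIO = sys.stderr,
    report_every: int = 100,
) -> List[float]:
    """Hypothetical sampler (not part of jaxampler) with an opt-in `verbose` flag."""
    rng = random.Random(seed)
    x, px = x0, two_peak_pdf(x0)
    samples: List[float] = []
    total = burn_in + n
    for i in range(total):
        # The Gaussian proposal is symmetric, so the Hastings correction cancels.
        y = rng.gauss(x, sigma)
        py = two_peak_pdf(y)
        if rng.random() * px < py:  # accept with probability min(1, py / px)
            x, px = y, py
        if i >= burn_in:
            samples.append(x)
        # Progress is written only when the caller opts in; RNG use is unaffected.
        if verbose and ((i + 1) % report_every == 0 or i + 1 == total):
            phase = "Burn-in" if i < burn_in else "Sampling"
            stream.write(f"{phase:<9}: {i + 1}/{total}\n")
    return samples
```

Because `verbose` only gates the reporting branch, turning it on never changes the sampling logic or the returned values.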

Code

>>> from functools import partial
>>> from typing import Optional
>>>
>>> from jax import jit
>>> from jax.scipy.stats import norm
>>> from jaxampler.rvs import Normal, ContinuousRV
>>> from jaxampler.typing import Numeric
>>> from jaxampler.sampler import MetropolisHastingSampler
>>>
>>>
>>> class TwoPeakNormal(ContinuousRV):
...     def __init__(self, name: Optional[str] = None) -> None:
...         super().__init__(name)
...
...     @partial(jit, static_argnums=(0,))
...     def pdf_x(self, x: Numeric) -> Numeric:
...         return 0.5 * (norm.pdf(x, loc=-2.0, scale=1.0) + norm.pdf(x, loc=2.0, scale=1.0))
>>>
>>>
>>> sampler = MetropolisHastingSampler(name="forTwoPeakNormal")
>>> p = TwoPeakNormal(name="TwoPeakNormal")
>>> q = lambda x: Normal(mu=x, sigma=0.4, name="Normal")
>>>
>>> samples = sampler.sample(
...     p=p,
...     q=q,
...     N=1000,
...     burn_in=1000,
...     n_chains=3,
...     key=None,
...     hasting_ratio=True,
...     x0=q(0.0).rvs(shape=(3,)),
...     verbose=True,
... )

Output

Burn-in        : 100%|#####################################################################################################################| 1.00k/1.00k [00:00<00:00, 2.69ksamples/s]
chain      0   :  19%|######################6                                                                                                 | 189/1.00k [00:08<00:32, 24.8samples/s]
chain      1   :  19%|######################6                                                                                                 | 189/1.00k [00:08<00:36, 22.2samples/s]
chain      2   :  19%|#######################                                                                                                 | 192/1.00k [00:08<00:39, 20.4samples/s]
Total          :  19%|######################8                                                                                                 | 571/3.00k [00:08<00:36, 66.1samples/s]
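The output above appears to come from tqdm progress bars configured with `ascii=True`, `unit="samples"`, and `unit_scale=True`. The issue does not name the backend, so tqdm is an assumption here. The sketch below shows one way a `verbose` flag could drive a burn-in bar, one bar per chain, and a total bar. `run_chains`, `_bar`, `_NullBar`, and the `step` callable (standing in for one Metropolis-Hastings transition) are hypothetical names, not jaxampler API.

```python
from typing import Callable, List

try:
    from tqdm import tqdm  # optional dependency, used only for progress bars
except ImportError:
    tqdm = None


class _NullBar:
    """Stand-in used when progress output is off or tqdm is unavailable."""

    def update(self, n: int = 1) -> None:
        pass

    def close(self) -> None:
        pass


def _bar(verbose: bool, desc: str, total: int, position: int):
    if not verbose or tqdm is None:
        return _NullBar()
    # Same look as the output above: '#' fill and k-scaled sample counts.
    return tqdm(desc=desc, total=total, position=position,
                unit="samples", unit_scale=True, ascii=True)


def run_chains(
    step: Callable[[float], float],
    x0: List[float],
    n: int,
    burn_in: int,
    verbose: bool = False,
) -> List[List[float]]:
    """Hypothetical driver: advance every chain with `step`, bars only if verbose."""
    xs = list(x0)
    burn = _bar(verbose, "Burn-in", burn_in, 0)
    for _ in range(burn_in):
        xs = [step(x) for x in xs]  # burn-in advances all chains together
        burn.update(1)
    burn.close()

    k = len(xs)
    chains = [_bar(verbose, f"chain {c}", n, c + 1) for c in range(k)]
    total = _bar(verbose, "Total", n * k, k + 1)
    samples: List[List[float]] = [[] for _ in range(k)]
    for _ in range(n):
        for c in range(k):
            xs[c] = step(xs[c])
            samples[c].append(xs[c])
            chains[c].update(1)
            total.update(1)
    for b in chains + [total]:
        b.close()
    return samples
```

The bars are created only when `verbose` is true, so the default path never touches tqdm and it can remain an optional dependency.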