Qazalbash/jaxampler

`rvs` function contains unnecessary details which can be abstracted away!

Qazalbash opened this issue

Description

The `rvs(...)` method in `GenericRV` and its subclasses contains details that could be hidden inside the `GenericRV` class. For example, `jaxampler.rvs.Exponential.rvs(...)` is implemented as:

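```python
def rvs(self, shape: tuple[int, ...], key: Optional[Array] = None) -> Array:
    # Boilerplate repeated in every subclass: default key and full output shape
    if key is None:
        key = self.get_key()
    new_shape = shape + self._shape
    # Distribution-specific part: inverse-transform sampling of an exponential
    U = jax.random.uniform(key, shape=new_shape)
    rvs_val = self._loc - self._scale * jnp.log(U)
    return rvs_val
```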

The first three lines of the body are repeated in every `GenericRV.rvs` implementation. It is tedious for users who implement new distributions to handle the key and the shape every time. Instead, they should only need to implement a method such as:

```python
def _rvs(self, shape: tuple[int, ...], key: Array) -> Array: ...
```

This method would then be called from `GenericRV.rvs`, which handles the key and the shape once for all subclasses:

```python
def rvs(self, shape: tuple[int, ...], key: Optional[Array] = None) -> Array:
    if key is None:
        key = self.get_key()
    new_shape = shape + self._shape
    # Pass the full shape (batch shape + the variable's own shape) to the subclass
    return self._rvs(shape=new_shape, key=key)
```

This design would remove the repeated boilerplate from every subclass and make adding new distributions easier.
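A minimal, self-contained sketch of the proposed split (a template-method pattern) is given below. It assumes a simplified `GenericRV`: the constructor arguments, the `seed`-based key state, and this `get_key` implementation are hypothetical stand-ins, not jaxampler's actual code. Subclasses implement only `_rvs`, while the base-class `rvs` fills in a default key and prepends the batch shape.

```python
from abc import ABC, abstractmethod
from typing import Optional

import jax
import jax.numpy as jnp
from jax import Array


class GenericRV(ABC):
    """Hypothetical simplified base class that owns the key/shape boilerplate."""

    def __init__(self, shape: tuple[int, ...] = (), seed: int = 0) -> None:
        self._shape = shape
        # Hypothetical internal key state; jaxampler's real get_key may differ.
        self._key = jax.random.PRNGKey(seed)

    def get_key(self) -> Array:
        # Split the stored key so each call returns a fresh subkey.
        self._key, subkey = jax.random.split(self._key)
        return subkey

    def rvs(self, shape: tuple[int, ...], key: Optional[Array] = None) -> Array:
        if key is None:
            key = self.get_key()
        new_shape = shape + self._shape
        return self._rvs(shape=new_shape, key=key)

    @abstractmethod
    def _rvs(self, shape: tuple[int, ...], key: Array) -> Array:
        """Draw samples of exactly `shape` using `key`."""


class Exponential(GenericRV):
    def __init__(
        self,
        loc: float = 0.0,
        scale: float = 1.0,
        shape: tuple[int, ...] = (),
        seed: int = 0,
    ) -> None:
        super().__init__(shape=shape, seed=seed)
        self._loc = loc
        self._scale = scale

    def _rvs(self, shape: tuple[int, ...], key: Array) -> Array:
        # Only the distribution-specific logic lives here.
        U = jax.random.uniform(key, shape=shape)  # U in [0, 1)
        # 1 - U is in (0, 1], so log1p(-U) never evaluates log(0).
        return self._loc - self._scale * jnp.log1p(-U)


rv = Exponential(loc=1.0, scale=2.0, shape=(3,))
samples = rv.rvs((5,))  # batch shape (5,) + variable shape (3,)
```

Declaring `_rvs` with `abc.abstractmethod` means a subclass that forgets to implement it fails with a `TypeError` at instantiation rather than at sampling time. The sketch uses `-log1p(-U)` instead of the issue's `-log(U)`; both yield exponential samples, but the former avoids an infinite value when `U` is exactly 0.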