`rvs` function contains unnecessary details which can be abstracted away!
Qazalbash opened this issue.
Description
The `rvs(...)` method in `GenericRV` and its subclasses contains boilerplate that could be hidden in the `GenericRV` class. For example, `jaxampler.rvs.Exponential.rvs(...)` is implemented as:
```python
def rvs(self, shape: tuple[int, ...], key: Optional[Array] = None) -> Array:
    if key is None:
        key = self.get_key()
    new_shape = shape + self._shape
    U = jax.random.uniform(key, shape=new_shape)
    rvs_val = self._loc - self._scale * jnp.log(U)
    return rvs_val
```
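The sampling line above is inverse-transform sampling, and it is the only part that is specific to the exponential distribution. A short derivation, assuming `_loc` is the location $\mu$, `_scale` is the scale $\beta > 0$, and `U` is uniform on $(0, 1)$:

$$
X = \mu - \beta \log U, \qquad U \sim \mathrm{Uniform}(0, 1)
$$

$$
P(X > x) = P\!\left(\log U < -\frac{x - \mu}{\beta}\right) = P\!\left(U < e^{-(x - \mu)/\beta}\right) = e^{-(x - \mu)/\beta}, \qquad x \ge \mu
$$

This is the survival function of an exponential distribution shifted by $\mu$ with scale $\beta$. Everything else in the method is key and shape bookkeeping.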
The first three lines of the body are repeated in every `GenericRV.rvs` implementation. Having to handle `key` and `shape` in every subclass is tedious and error-prone for users. Instead, subclasses should only need to implement a method like:
```python
def _rvs(self, shape: tuple[int, ...], key: Array) -> Array: ...
```
This method would then be called from `GenericRV.rvs`, which handles the defaults once for all subclasses:
```python
def rvs(self, shape: tuple[int, ...], key: Optional[Array] = None) -> Array:
    if key is None:
        key = self.get_key()
    new_shape = shape + self._shape
    # Pass the combined shape, otherwise new_shape is computed but never used
    return self._rvs(shape=new_shape, key=key)
```
This design removes the repeated `key` and `shape` handling from every subclass, which makes new distributions easier to implement.
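A self-contained sketch of the proposed split (a template-method pattern), assuming `jax` is installed. The seed-based key state, `get_key`, and the constructors below are hypothetical stand-ins; jaxampler's actual `GenericRV` constructor and `get_key` may differ. As in the original `Exponential.rvs`, `_rvs` receives the full output shape `shape + self._shape`.

```python
from abc import ABC, abstractmethod
from typing import Optional

import jax
import jax.numpy as jnp
from jax import Array


class GenericRV(ABC):
    """Base class: handles key and shape bookkeeping once for all subclasses."""

    def __init__(self, shape: tuple[int, ...] = (), seed: int = 0) -> None:
        self._shape = shape
        # Hypothetical internal PRNG state; not jaxampler's actual implementation
        self._key = jax.random.PRNGKey(seed)

    def get_key(self) -> Array:
        # Hypothetical: split the stored key so each call gets fresh randomness
        self._key, subkey = jax.random.split(self._key)
        return subkey

    def rvs(self, shape: tuple[int, ...], key: Optional[Array] = None) -> Array:
        if key is None:
            key = self.get_key()
        new_shape = shape + self._shape
        return self._rvs(shape=new_shape, key=key)

    @abstractmethod
    def _rvs(self, shape: tuple[int, ...], key: Array) -> Array:
        """Draw samples of exactly `shape` using `key`."""


class Exponential(GenericRV):
    def __init__(
        self,
        loc: float = 0.0,
        scale: float = 1.0,
        shape: tuple[int, ...] = (),
        seed: int = 0,
    ) -> None:
        super().__init__(shape=shape, seed=seed)
        self._loc = loc
        self._scale = scale

    def _rvs(self, shape: tuple[int, ...], key: Array) -> Array:
        # Only the distribution-specific logic remains: inverse-transform sampling
        U = jax.random.uniform(key, shape=shape)
        return self._loc - self._scale * jnp.log(U)


dist = Exponential(loc=1.0, scale=2.0, shape=(3,))
samples = dist.rvs((1000,))
print(samples.shape)  # (1000, 3)
```

Marking `_rvs` as abstract makes Python raise `TypeError` when a subclass forgets to implement it, so the contract is enforced at instantiation rather than at the first sampling call.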