econ-ark/HARK

`DiscreteDistribution.dist_of_func`

Closed this issue · 3 comments

Mv77 commented

There is momentum gathering behind an xarray revolution. After checking out the library and @alanlujan91's `ConsLabeledModel`, I am happy to join that revolution.

A crucial part, I think, is getting the distribution machinery to work with xarrays, which I believe we do with the `DiscreteDistributionLabeled` class.

I was playing around with its `dist_of_func` method and ran into behavior that is either unintuitive or a bug. Consider the following example code, which tries to take the expectation of a labeled function in two different ways:

a) Directly, using the `.expected` method.
b) By first creating an object that represents the distribution of the function, and only then taking its expectation.

        import numpy as np
        import xarray as xr
        from HARK.distribution import DiscreteDistributionLabeled

        # Create a basic labeled distribution
        base_dist = DiscreteDistributionLabeled(
            pmv=np.array([0.5, 0.5]), atoms=np.array([[1.0, 2.0], [3.0, 4.0]]), var_names=["a", "b"]
        )

        # Define a transition function
        def transition(shocks, state):
            state_new = {}
            state_new["m"] = state["m"]*shocks["a"]
            state_new["n"] = state["n"]*shocks["b"]
            return state_new

        m = xr.DataArray(np.linspace(0, 10, 11), name="m", dims=("grid",))
        n = xr.DataArray(np.linspace(0, -10, 11), name="n", dims=("grid",))
        state_grid = xr.Dataset({"m": m, "n": n})

        # Evaluate labeled transformation

        # This works
        new_state_exp = base_dist.expected(transition, state=state_grid)
        # This does not work
        new_state_dstn = base_dist.dist_of_func(transition, state=state_grid)
        new_state_dstn.expected()

Method a) works, but b) raises the following error:

        new_state_dstn.expected()
          File "/home/mvg/GitHub/HARK/HARK/distribution.py", line 1243, in expected
            return super().expected()
          File "/home/mvg/GitHub/HARK/HARK/distribution.py", line 936, in expected
            f_query = self.atoms
        AttributeError: 'DiscreteDistributionLabeled' object has no attribute 'atoms'

I think applying the function is not yielding an object with an atoms dimension over which to integrate. It might be that something else needs to be added to the function. But it is odd that `.expected()` can handle it and `.dist_of_func` cannot.
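To make the expected behavior of route b) concrete, here is a minimal sketch in plain xarray, without HARK. It assumes the atoms sit along an `atom` dimension with the probabilities used as weights on that dimension, and it reads `base_dist` above as `a = [1, 2]`, `b = [3, 4]`. The names `shocks`, `dist_of_f`, `route_a`, and `route_b` are illustrative only. If the distribution of `f(X)` keeps the atom dimension, taking its expectation afterwards gives the same answer as computing the expectation in one step.

```python
import numpy as np
import xarray as xr

# Assumption: atoms live along an "atom" dimension, and the probabilities are a
# DataArray on that same dimension.
pmv = xr.DataArray(np.array([0.5, 0.5]), dims=("atom",))
shocks = xr.Dataset(
    {"a": ("atom", np.array([1.0, 2.0])), "b": ("atom", np.array([3.0, 4.0]))}
)
state = xr.Dataset(
    {"m": ("grid", np.linspace(0, 10, 11)), "n": ("grid", np.linspace(0, -10, 11))}
)


def transition(shocks, state):
    return {"m": state["m"] * shocks["a"], "n": state["n"] * shocks["b"]}


# Route (a): expectation in one step, summing p_i * f(x_i) atom by atom.
route_a = sum(
    float(pmv[i]) * xr.Dataset(transition(shocks.isel(atom=i), state))
    for i in range(pmv.sizes["atom"])
)

# Route (b): first build the distribution of f(X). Broadcasting by dimension
# name keeps "atom", so each new variable has dims ("grid", "atom").
dist_of_f = xr.Dataset(transition(shocks, state))
route_b = dist_of_f.weighted(pmv).mean("atom")

xr.testing.assert_allclose(route_a, route_b)
```

Here `E[a] = 1.5` and `E[b] = 3.5`, so both routes scale `m` by 1.5 and `n` by 3.5 on every grid point.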

I can look into it at some point, but wanted to confirm whether this is a bug, a gap in the documentation, or simply a feature that has yet to be implemented.

Mv77 commented

@alanlujan91 I'm taking a stab at making this work and wanted to run something by you.

I am running into issues where the properties and types of the objects used and returned by `DiscreteDistributionLabeled`'s `dist_of_func` and `expected` methods vary a lot depending on their arguments. See, for instance, this fragment of `expected`:

`HARK/HARK/distribution.py`, lines 1236 to 1245 in 37134b9:

        if len(kwargs):
            f_query = func(self.dataset, *args, **kwargs)
            ldd = DiscreteDistributionLabeled.from_dataset(f_query, self.probability)
            return ldd._weighted.mean("atom")
        else:
            if func is None:
                return super().expected()
            else:
                return super().expected(func_wrapper, *args)

If I:

  • Pass a function with kwargs, it returns an xarray-like object.
  • Pass a function without kwargs, or no function at all, it calls `DiscreteDistribution.expected()` and (if that call succeeds) returns a numpy array.

Was there a reason for this behavior to depend on the arguments? Would it make more sense to homogenize everything, or to make the return type depend on `f(atoms)` instead? I.e., if `f(atoms)` is a dictionary or an xarray object, work with xarray, and if it is a numpy array or a float, work with numpy.
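A rough sketch of what dispatching on the type of `f(atoms)` could look like. `expected_by_result_type` is a hypothetical helper, not part of HARK; it assumes the atoms are an `xr.Dataset` with a shared `atom` dimension and that an unlabeled result has the atoms on its last axis.

```python
import numpy as np
import xarray as xr


def expected_by_result_type(func, atoms, pmv):
    """Hypothetical helper (not HARK's API): pick the return type from f(atoms).

    atoms: xr.Dataset whose variables share an "atom" dimension.
    pmv:   1-D numpy array with one probability per atom.
    """
    result = func(atoms)
    if isinstance(result, dict):
        # A dict of labeled arrays is promoted to a Dataset.
        result = xr.Dataset(result)
    if isinstance(result, (xr.Dataset, xr.DataArray)):
        # Labeled result: stay in xarray and integrate over "atom".
        weights = xr.DataArray(pmv, dims=("atom",))
        return result.weighted(weights).mean("atom")
    # Unlabeled result: assume the atoms run along the last axis.
    return np.dot(np.asarray(result), pmv)


atoms = xr.Dataset(
    {"a": ("atom", np.array([1.0, 2.0])), "b": ("atom", np.array([3.0, 4.0]))}
)
pmv = np.array([0.5, 0.5])

# A dict of DataArrays comes back as a Dataset; a raw array comes back as numpy.
labeled = expected_by_result_type(lambda x: {"ab": x["a"] * x["b"]}, atoms, pmv)
unlabeled = expected_by_result_type(lambda x: x["a"].values ** 2, atoms, pmv)
```

With this design the caller's choice of keyword or positional arguments no longer decides the return type; only what the function itself returns does.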

There are probably better ways of doing this, but the reason for handling the cases differently depending on the arguments is as follows:

`if len(kwargs)`: the function depends on keyword arguments, which can be applied directly to xarrays. We can then create a labeled distribution, and the expectation is the mean of a weighted dataset.

`if func is None`: there is no function to apply, so it is faster to compute the expectation directly.

`else`: the function depends on positional arguments, so we need to convert those arguments to be broadcastable.
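To illustrate why positional arguments need that conversion, here is a minimal numpy sketch. It is not HARK's `func_wrapper`; `broadcast_expected` and `f` are hypothetical, and it assumes atoms are stored as `(n_vars, n_atoms)` with the atoms on the last axis. Adding a trailing length-1 axis to each argument lets it broadcast against the atom axis, after which the expectation is a dot product with the probabilities.

```python
import numpy as np


def broadcast_expected(func, atoms, pmv, *args):
    """Hypothetical helper (not HARK's func_wrapper).

    atoms: array of shape (n_vars, n_atoms); pmv: shape (n_atoms,).
    Each positional argument gets a trailing length-1 axis so it broadcasts
    against the atom axis instead of clashing with it.
    """
    b_args = [np.asarray(arg)[..., np.newaxis] for arg in args]
    values = func(atoms, *b_args)  # shape (*arg_shape, n_atoms)
    return values @ pmv  # contract the atom axis with the probabilities


atoms = np.array([[1.0, 2.0], [3.0, 4.0]])  # rows: variables a, b
pmv = np.array([0.5, 0.5])
m_grid = np.linspace(0.0, 10.0, 11)


def f(x, m):
    return m * x[0] + x[1]  # m * a + b


expected = broadcast_expected(f, atoms, pmv, m_grid)  # 1.5 * m + 3.5 per grid point
```

Calling `f(atoms, m_grid)` directly would fail, because shapes `(11,)` and `(2,)` do not broadcast; the reshaped argument of shape `(11, 1)` produces an `(11, 2)` array of values, one column per atom.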