pymc-devs/pymc4

Changes in PR #271 don't work with BatchStacker

alexlyttle opened this issue · 2 comments

Summary

PR #271, which changed uses of self.conditions to self._distribution, doesn't work with BatchStacker. If a distribution that was changed in #271 is initialised with batch_stack=<some_number>, the model fails because BatchStacker does not have the attributes once held in conditions. Could BatchStacker be changed to inherit the attributes of the stacked distribution, or should the changes in that PR be reverted? A sketch of the first option follows.
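One way to keep #271 might be to have BatchStacker forward unknown attribute lookups to the distribution it wraps. Below is a minimal, self-contained sketch of the idea; AttrForwardingStacker is a toy stand-in, and the attribute name distribution is an assumption about where the wrapped distribution lives, not necessarily BatchStacker's actual field:

import tensorflow_probability as tfp
tfd = tfp.distributions

class AttrForwardingStacker:
    """Toy stand-in for BatchStacker: wraps a distribution and forwards
    unknown attribute lookups to it (hypothetical fix, not pymc4 code)."""

    def __init__(self, distribution, batch_stack):
        self.distribution = distribution
        self.batch_stack = batch_stack

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails, so the
        # wrapper's own attributes (batch_stack, etc.) still win.
        try:
            wrapped = object.__getattribute__(self, "distribution")
        except AttributeError:
            raise AttributeError(name)
        return getattr(wrapped, name)

stacked = AttrForwardingStacker(tfd.Uniform(low=0.0, high=1.0), batch_stack=10)
stacked.high  # resolves on the wrapped Uniform instead of raising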

Example

import pymc4 as pm

@pm.model
def model():
    x = yield pm.Uniform('x', 0.0, 1.0, batch_stack=10)

pm.sample(model())

gives

AttributeError: 'BatchStacker' object has no attribute 'high'
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-2-f1540b4b3836> in <module>
      3     x = yield pm.Uniform('x', 0.0, 1.0, batch_stack=10)
      4 
----> 5 pm.sample(model())

~/Source/pymc4/pymc4/inference/sampling.py in sample(model, num_samples, num_chains, burn_in, step_size, observed, state, nuts_kwargs, adaptation_kwargs, sample_chain_kwargs, xla, use_auto_batching)
    118         state=state,
    119         observed=observed,
--> 120         collect_reduced_log_prob=use_auto_batching,
    121     )
    122     init_state = list(init.values())

~/Source/pymc4/pymc4/inference/sampling.py in build_logp_and_deterministic_functions(model, num_chains, observed, state, collect_reduced_log_prob)
    201         raise ValueError("Can't use both `state` and `observed` arguments")
    202 
--> 203     state, deterministic_names = initialize_sampling_state(model, observed=observed, state=state)
    204 
    205     if not state.all_unobserved_values:

~/Source/pymc4/pymc4/inference/utils.py in initialize_sampling_state(model, observed, state)
     25         The list of names of the model's deterministics
     26     """
---> 27     _, state = flow.evaluate_meta_model(model, observed=observed, state=state)
     28     deterministic_names = list(state.deterministics)
     29 

~/Source/pymc4/pymc4/flow/executor.py in evaluate_model(self, model, state, _validate_state, values, observed, sample_shape)
    485                         try:
    486                             return_value, state = self.proceed_distribution(
--> 487                                 dist, state, sample_shape=sample_shape
    488                             )
    489                         except EvaluationError as error:

~/Source/pymc4/pymc4/flow/meta_executor.py in proceed_distribution(self, dist, state, sample_shape)
     88                 )
     89             else:
---> 90                 return_value = state.untransformed_values[scoped_name] = dist.get_test_sample()
     91         state.distributions[scoped_name] = dist
     92         return return_value, state

~/Source/pymc4/pymc4/distributions/distribution.py in get_test_sample(self, sample_shape, seed)
    152         """
    153         sample_shape = tf.TensorShape(sample_shape)
--> 154         return tf.broadcast_to(self.test_value, sample_shape + self.batch_shape + self.event_shape)
    155 
    156     def log_prob(self, value):

~/Source/pymc4/pymc4/distributions/distribution.py in test_value(self)
    104     @property
    105     def test_value(self):
--> 106         return tf.broadcast_to(self._test_value, self.batch_shape + self.event_shape)
    107 
    108     def sample(self, sample_shape=(), seed=None):

~/Source/pymc4/pymc4/distributions/distribution.py in _test_value(self)
    296     @property
    297     def _test_value(self):
--> 298         return 0.5 * (self.upper_limit() + self.lower_limit())
    299 
    300 

~/Source/pymc4/pymc4/distributions/continuous.py in upper_limit(self)
   1290 
   1291     def upper_limit(self):
-> 1292         return self._distribution.high
   1293 
   1294 

AttributeError: 'BatchStacker' object has no attribute 'high'
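The failing frame is the access pattern #271 introduced; roughly (the pre-#271 body below is a sketch inferred from the summary, not the actual diff):

# After #271 (from the traceback): the bound is read off the backend
# tfp distribution, which breaks when self._distribution is a
# BatchStacker rather than a plain tfd.Uniform.
def upper_limit(self):
    return self._distribution.high

# Before #271 (sketch, per the summary): the bound was read from the
# conditions dict, which BatchStacker left untouched.
def upper_limit(self):
    return self.conditions["high"]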

Oh damn! I'll see if we can somehow salvage #271. I prefer using the tfp distribution parameters for a few reasons:

  1. They have already been converted to tensors with an adequate dtype.
  2. They have been potentially sanitized by tfp.
  3. Some shape wrangling, like broadcasting, could have gone on in the tfp backend.
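For instance, point 1 with plain tfp (illustration only, outside pymc4):

import tensorflow_probability as tfp
tfd = tfp.distributions

# Python floats passed as parameters come back as tensors with a
# concrete dtype:
dist = tfd.Uniform(low=0.0, high=1.0)
dist.high  # <tf.Tensor: shape=(), dtype=float32, numpy=1.0>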

I know it will be a hard fix, and this also highlights that things initialised with an event_stack won't work either. @twiecki, what do you think is the best way forward? Revert #271 now, and try to work out a better solution in a future PR?

I just wanted to add that a proper fix to this would also involve #288.