Changes in PR #271 don't work with BatchStacker
alexlyttle opened this issue · 2 comments
Summary
PR #271, which changed the use of `self.conditions` to `self._distribution`, doesn't work with `BatchStacker`. If a distribution that was changed in #271 is initialised with `batch_stack=<some_number>`, the model fails because `BatchStacker` doesn't have the attributes once held in `conditions`. Could `BatchStacker` be changed to inherit the attributes of the stacked distribution, or should the changes in that PR be reverted?
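One way the "inherit the attributes" idea could look is attribute delegation via `__getattr__`. The sketch below is purely illustrative: `DelegatingStacker` and `InnerUniform` are hypothetical stand-ins, not PyMC4's actual `BatchStacker` or a real tfp distribution.

```python
class InnerUniform:
    """Stand-in for a tfp Uniform carrying `low`/`high` parameters."""

    def __init__(self, low, high):
        self.low = low
        self.high = high


class DelegatingStacker:
    """Minimal BatchStacker-like wrapper that forwards unknown
    attribute lookups to the wrapped distribution."""

    def __init__(self, distribution, batch_stack):
        self.distribution = distribution
        self.batch_stack = batch_stack

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails, so
        # `distribution` and `batch_stack` resolve normally and
        # everything else falls through to the inner distribution.
        return getattr(self.distribution, name)


stacker = DelegatingStacker(InnerUniform(0.0, 1.0), batch_stack=10)
print(stacker.high)         # resolves on the wrapped distribution
print(stacker.batch_stack)  # resolves on the wrapper itself
```

With this kind of delegation, `upper_limit` reading `self._distribution.high` would keep working even when `_distribution` is a stacker.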
Example
```python
import pymc4 as pm

@pm.model
def model():
    x = yield pm.Uniform('x', 0.0, 1.0, batch_stack=10)

pm.sample(model())
```
gives
```
AttributeError: 'BatchStacker' object has no attribute 'high'
```
Show full traceback
```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-2-f1540b4b3836> in <module>
      3     x = yield pm.Uniform('x', 0.0, 1.0, batch_stack=10)
      4
----> 5 pm.sample(model())

~/Source/pymc4/pymc4/inference/sampling.py in sample(model, num_samples, num_chains, burn_in, step_size, observed, state, nuts_kwargs, adaptation_kwargs, sample_chain_kwargs, xla, use_auto_batching)
    118         state=state,
    119         observed=observed,
--> 120         collect_reduced_log_prob=use_auto_batching,
    121     )
    122     init_state = list(init.values())

~/Source/pymc4/pymc4/inference/sampling.py in build_logp_and_deterministic_functions(model, num_chains, observed, state, collect_reduced_log_prob)
    201         raise ValueError("Can't use both `state` and `observed` arguments")
    202
--> 203     state, deterministic_names = initialize_sampling_state(model, observed=observed, state=state)
    204
    205     if not state.all_unobserved_values:

~/Source/pymc4/pymc4/inference/utils.py in initialize_sampling_state(model, observed, state)
     25         The list of names of the model's deterministics
     26     """
---> 27     _, state = flow.evaluate_meta_model(model, observed=observed, state=state)
     28     deterministic_names = list(state.deterministics)
     29

~/Source/pymc4/pymc4/flow/executor.py in evaluate_model(self, model, state, _validate_state, values, observed, sample_shape)
    485                     try:
    486                         return_value, state = self.proceed_distribution(
--> 487                             dist, state, sample_shape=sample_shape
    488                         )
    489                     except EvaluationError as error:

~/Source/pymc4/pymc4/flow/meta_executor.py in proceed_distribution(self, dist, state, sample_shape)
     88                 )
     89             else:
---> 90                 return_value = state.untransformed_values[scoped_name] = dist.get_test_sample()
     91         state.distributions[scoped_name] = dist
     92         return return_value, state

~/Source/pymc4/pymc4/distributions/distribution.py in get_test_sample(self, sample_shape, seed)
    152         """
    153         sample_shape = tf.TensorShape(sample_shape)
--> 154         return tf.broadcast_to(self.test_value, sample_shape + self.batch_shape + self.event_shape)
    155
    156     def log_prob(self, value):

~/Source/pymc4/pymc4/distributions/distribution.py in test_value(self)
    104     @property
    105     def test_value(self):
--> 106         return tf.broadcast_to(self._test_value, self.batch_shape + self.event_shape)
    107
    108     def sample(self, sample_shape=(), seed=None):

~/Source/pymc4/pymc4/distributions/distribution.py in _test_value(self)
    296     @property
    297     def _test_value(self):
--> 298         return 0.5 * (self.upper_limit() + self.lower_limit())
    299
    300

~/Source/pymc4/pymc4/distributions/continuous.py in upper_limit(self)
   1290
   1291     def upper_limit(self):
-> 1292         return self._distribution.high
   1293
   1294

AttributeError: 'BatchStacker' object has no attribute 'high'
```
Oh damn! I'll see if we can somehow salvage #271. I prefer using the tfp distribution parameters for a few reasons:
- They have already been converted to tensors with an adequate dtype.
- They have been potentially sanitized by tfp.
- Some shape wrangling, like broadcasting, could have gone on in the tfp backend.
I know it will be a hard fix, and this also highlights that things that had an `event_stack` won't work either. @twiecki, what do you think is the best way to go forward? Revert #271 now, and try to work on a better solution than #271 turned out to be in a future PR?
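An alternative to a full revert would be to make the limit accessors tolerant of wrapping, by unwrapping one level before reading the tfp parameter. This is a hypothetical sketch only: it assumes (not confirmed here) that the stacker stores the wrapped distribution in a `.distribution` attribute, and `FakeUniform`/`FakeBatchStacker` are illustrative stand-ins.

```python
class FakeUniform:
    """Stand-in for a tfp Uniform: parameters live directly on it."""
    high = 1.0
    low = 0.0


class FakeBatchStacker:
    """Stand-in for a BatchStacker-style wrapper that hides the
    wrapped distribution behind a `.distribution` attribute."""

    def __init__(self, distribution):
        self.distribution = distribution


def upper_limit(dist):
    # Unwrap one level if this is a stacker-style wrapper;
    # otherwise read the parameter directly.
    inner = getattr(dist, "distribution", dist)
    return inner.high


# Works both with and without the wrapper.
assert upper_limit(FakeUniform()) == 1.0
assert upper_limit(FakeBatchStacker(FakeUniform())) == 1.0
```

This keeps the #271 approach of reading tfp-sanitized parameters while surviving the `batch_stack` code path; a real fix would also need to handle `event_stack` and the shape concerns raised in #288.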
I just wanted to add that a proper fix to this would also involve #288.