Sampling fails when `num_chains=1`
tirthasheshpatel opened this issue · 5 comments
While debugging #309, I found that sampling fails when `num_chains=1` on the latest TensorFlow nightly and TensorFlow Probability nightly. Here's a minimal reproducible example:
>>> import pymc4 as pm
>>> @pm.model
... def model():
... x = yield pm.Normal("x", 0., 1.)
... return x
...
>>> m = model()
>>> pm.sample(m, num_chains=1, num_samples=10, burn_in=10)
Error:
Auto-assigning NUTS sampler
Traceback (most recent call last):
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\util\nest.py", line 403, in assert_same_structure
_pywrap_utils.AssertSameStructure(nest1, nest2, check_types,
ValueError: The two structures don't have the same nested structure.
First structure: type=list str=[<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/add:0' shape=() dtype=int32>, [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/add_1:0' shape=(1,) dtype=float32>], [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/add:0' shape=(1,) dtype=float32>], <tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/pfor/Tile:0' shape=(1,) dtype=float32>, [<tensorflow.python.framework.indexed_slices.IndexedSlices object at 0x000001E28DA833D0>]]
Second structure: type=list str=[<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/iter:0' shape=() dtype=int32>, [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/add:0' shape=(1,) dtype=float32>], [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/Placeholder_6:0' shape=(1,) dtype=float32>], <tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/Placeholder_7:0' shape=(1,) dtype=float32>, [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/Placeholder_8:0' shape=(1,) dtype=float32>]]
More specifically: Substructure "type=IndexedSlices str=IndexedSlices(indices=Tensor("mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/gradients/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/loop_body/GatherV2_grad/Reshape_1:0", shape=(1,), dtype=int32), values=Tensor("mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/gradients/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/loop_body/GatherV2_grad/Reshape:0", shape=(1,), dtype=float32), 
dense_shape=Tensor("mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/gradients/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/loop_body/GatherV2_grad/Cast:0", shape=(1,), dtype=int32))" is a sequence, while substructure "type=Tensor str=Tensor("mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/Placeholder_8:0", shape=(1,), dtype=float32)" is not
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\pymc4\inference\sampling.py", line 168, in sample
return sampler(
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\pymc4\mcmc\samplers.py", line 225, in __call__
return self._sample(*args, **kwargs)
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\pymc4\mcmc\samplers.py", line 132, in _sample
results, sample_stats = self._run_chains(init_state, burn_in)
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\eager\def_function.py", line 786, in __call__
result = self._call(*args, **kwds)
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\eager\def_function.py", line 829, in _call
self._initialize(args, kwds, add_initializers_to=initializers)
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\eager\def_function.py", line 716, in _initialize
self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\eager\function.py", line 2955, in _get_concrete_function_internal_garbage_collected
graph_function, _ = self._maybe_define_function(args, kwargs)
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\eager\function.py", line 3351, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\eager\function.py", line 3190, in _create_graph_function
func_graph_module.func_graph_from_py_func(
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\framework\func_graph.py", line 987, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\eager\def_function.py", line 625, in wrapped_fn
out = weak_wrapped_fn().__wrapped__(*args, **kwds)
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\eager\function.py", line 3875, in bound_method_wrapper
return wrapped_fn(weak_instance(), *args, **kwargs)
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\pymc4\mcmc\samplers.py", line 172, in _run_chains
results, sample_stats = mcmc.sample_chain(
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\sample.py", line 361, in sample_chain
(_, _, final_kernel_results), (all_states, trace) = mcmc_util.trace_scan(
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\internal\util.py", line 460, in trace_scan
_, final_state, _, trace_arrays = tf.while_loop(
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\util\deprecation.py", line 574, in new_func
return func(*args, **kwargs)
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2489, in while_loop_v2
return while_loop(
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2687, in while_loop
return while_v2.while_loop(
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\while_v2.py", line 188, in while_loop
body_graph = func_graph_module.func_graph_from_py_func(
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\framework\func_graph.py", line 987, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\while_v2.py", line 174, in wrapped_body
outputs = body(*_pack_sequence_as(orig_loop_vars, args))
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\internal\util.py", line 450, in _body
state = loop_fn(state, elem)
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\sample.py", line 354, in _trace_scan_fn
seed, next_state, current_kernel_results = mcmc_util.smart_for_loop(
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\internal\util.py", line 349, in smart_for_loop
return tf.while_loop(
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\util\deprecation.py", line 574, in new_func
return func(*args, **kwargs)
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2489, in while_loop_v2
return while_loop(
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2687, in while_loop
return while_v2.while_loop(
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\while_v2.py", line 188, in while_loop
body_graph = func_graph_module.func_graph_from_py_func(
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\framework\func_graph.py", line 987, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\while_v2.py", line 174, in wrapped_body
outputs = body(*_pack_sequence_as(orig_loop_vars, args))
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\internal\util.py", line 351, in <lambda>
body=lambda i, *args: [i + 1] + list(body_fn(*args)),
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\sample.py", line 351, in _seeded_one_step
kernel.one_step(*state_and_results, **one_step_kwargs))
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\dual_averaging_step_size_adaptation.py", line 456, in one_step
new_state, new_inner_results = self.inner_kernel.one_step(
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\nuts.py", line 419, in one_step
_, _, _, new_step_metastate = tf.while_loop(
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\util\deprecation.py", line 574, in new_func
return func(*args, **kwargs)
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2489, in while_loop_v2
return while_loop(
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2687, in while_loop
return while_v2.while_loop(
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\while_v2.py", line 188, in while_loop
body_graph = func_graph_module.func_graph_from_py_func(
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\framework\func_graph.py", line 987, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\while_v2.py", line 174, in wrapped_body
outputs = body(*_pack_sequence_as(orig_loop_vars, args))
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\nuts.py", line 423, in <lambda>
body=lambda iter_, seed, state, metastate: self._loop_tree_doubling( # pylint: disable=g-long-lambda
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\nuts.py", line 597, in _loop_tree_doubling
] = self._build_sub_tree(
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\nuts.py", line 777, in _build_sub_tree
] = tf.while_loop(
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\util\deprecation.py", line 574, in new_func
return func(*args, **kwargs)
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2489, in while_loop_v2
return while_loop(
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2687, in while_loop
return while_v2.while_loop(
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\while_v2.py", line 188, in while_loop
body_graph = func_graph_module.func_graph_from_py_func(
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\framework\func_graph.py", line 987, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\while_v2.py", line 174, in wrapped_body
outputs = body(*_pack_sequence_as(orig_loop_vars, args))
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\nuts.py", line 785, in <lambda>
self._loop_build_sub_tree(
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\nuts.py", line 838, in _loop_build_sub_tree
] = integrator(prev_tree_state.momentum,
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\internal\leapfrog_integrator.py", line 282, in __call__
] = tf.while_loop(
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\util\deprecation.py", line 574, in new_func
return func(*args, **kwargs)
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2489, in while_loop_v2
return while_loop(
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2687, in while_loop
return while_v2.while_loop(
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\while_v2.py", line 188, in while_loop
body_graph = func_graph_module.func_graph_from_py_func(
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\framework\func_graph.py", line 987, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\while_v2.py", line 179, in wrapped_body
nest.assert_same_structure(list(outputs), list(orig_loop_vars),
File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\util\nest.py", line 408, in assert_same_structure
raise type(e)("%s\n"
ValueError: The two structures don't have the same nested structure.
First structure: type=list str=[<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/add:0' shape=() dtype=int32>, [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/add_1:0' shape=(1,) dtype=float32>], [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/add:0' shape=(1,) dtype=float32>], <tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/pfor/Tile:0' shape=(1,) dtype=float32>, [<tensorflow.python.framework.indexed_slices.IndexedSlices object at 0x000001E28DA833D0>]]
Second structure: type=list str=[<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/iter:0' shape=() dtype=int32>, [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/add:0' shape=(1,) dtype=float32>], [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/Placeholder_6:0' shape=(1,) dtype=float32>], <tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/Placeholder_7:0' shape=(1,) dtype=float32>, [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/Placeholder_8:0' shape=(1,) dtype=float32>]]
More specifically: Substructure "type=IndexedSlices str=IndexedSlices(indices=Tensor("mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/gradients/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/loop_body/GatherV2_grad/Reshape_1:0", shape=(1,), dtype=int32), values=Tensor("mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/gradients/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/loop_body/GatherV2_grad/Reshape:0", shape=(1,), dtype=float32), 
dense_shape=Tensor("mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/gradients/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/loop_body/GatherV2_grad/Cast:0", shape=(1,), dtype=int32))" is a sequence, while substructure "type=Tensor str=Tensor("mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/Placeholder_8:0", shape=(1,), dtype=float32)" is not
Entire first structure:
[., [.], [.], ., [.]]
Entire second structure:
[., [.], [.], ., [.]]
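The two "Entire structure" summaries print identically, yet the check fails. A toy model of the comparison may help explain why. The check in `while_v2` (traceback line 179) compares the loop body's outputs with the original loop variables. If composite values such as `IndexedSlices` are expanded into their components (indices, values, dense_shape) before that comparison, the gradient slot becomes a sequence while the matching loop variable stays a single tensor leaf. The message reporting the `IndexedSlices` "is a sequence" suggests this is presumably what happens here. The sketch below is a hypothetical, simplified stand-in for `tf.nest.assert_same_structure`, not TensorFlow's implementation. `FakeIndexedSlices`, `expand`, `same_structure` and `summary` are illustrative names.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class FakeIndexedSlices:
    """Hypothetical stand-in for a sparse gradient made of three components."""
    indices: Any
    values: Any
    dense_shape: Any


def expand(obj: Any, expand_composites: bool) -> Any:
    """Convert nested lists/tuples to lists; optionally split composites into components."""
    if isinstance(obj, FakeIndexedSlices):
        if expand_composites:
            return [obj.indices, obj.values, obj.dense_shape]
        return obj  # kept as a single opaque leaf
    if isinstance(obj, (list, tuple)):
        return [expand(o, expand_composites) for o in obj]
    return obj


def _same(a: Any, b: Any) -> bool:
    a_seq, b_seq = isinstance(a, list), isinstance(b, list)
    if a_seq != b_seq:
        return False  # one side is a sequence, the other a leaf
    if not a_seq:
        return True  # two leaves: their values are not compared
    return len(a) == len(b) and all(_same(x, y) for x, y in zip(a, b))


def same_structure(a: Any, b: Any, expand_composites: bool = True) -> bool:
    """Toy structural comparison, loosely modelled on tf.nest.assert_same_structure."""
    return _same(expand(a, expand_composites), expand(b, expand_composites))


def summary(obj: Any) -> str:
    """Render nesting with '.' per leaf and no composite expansion."""
    if isinstance(obj, (list, tuple)):
        return "[" + ", ".join(summary(o) for o in obj) + "]"
    return "."


# Shapes mirror the traceback: int counter, [part], [part], target, [gradient].
outputs = ["i", ["m"], ["s"], "t", [FakeIndexedSlices("idx", "val", "shape")]]
loop_vars = ["i", ["m"], ["s"], "t", ["g"]]
print(summary(outputs), summary(loop_vars))  # [., [.], [.], ., [.]] [., [.], [.], ., [.]]
print(same_structure(outputs, loop_vars))  # False
print(same_structure(outputs, loop_vars, expand_composites=False))  # True
```

The printed summaries match because they do not expand composites. The structural check does expand them, and that is where the mismatch surfaces.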
Versions
- TensorFlow nightly: 2.4.0-dev20200828
- TensorFlow Probability nightly: 0.12.0-dev20200830
- NumPy: 1.19.0
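When comparing nightlies, it helps to report the exact installed builds. Here is a small sketch using only the standard library. It assumes the nightlies were installed from the `tf-nightly` and `tfp-nightly` PyPI distributions. If the stable `tensorflow` / `tensorflow-probability` packages are installed instead, pass those names. Inside a session, `tf.__version__` and `tfp.__version__` report the same information.

```python
from importlib import metadata


def report_versions(packages=("tf-nightly", "tfp-nightly", "numpy")):
    """Return {distribution name: installed version, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Print in the same "- name: version" style as the list above.
    for name, version in report_versions().items():
        print(f"- {name}: {version or 'not installed'}")
```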
@rrkarim could you take a look?
Everything seems to be fine on `tf==2.4.0-dev20200731`, `tfp==0.12.0-dev20200830` (the newest). Checking the newest tf nightly.
[UPD] It breaks on `tf==2.4.0-dev20200830`.
So the issue is caused by this commit, not by the recent discrete/compound merge (just checked). Also, everything works with RWM (random-walk Metropolis), so I assume the problem is somewhere in the leapfrog integration.
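The narrowing-down above (dev20200731 works, dev20200830 breaks) can be continued as a binary search over the nightly builds in between. Each step then costs one reinstall and one run of the reproducer. The following is a minimal sketch under stated assumptions. `is_broken(build)` is a hypothetical callback that would install the given `tf-nightly` build and rerun `pm.sample(m, num_chains=1, ...)`. The helpers `nightly_tags` and `first_bad` are illustrative, not part of PyMC4 or TensorFlow. The date list assumes one nightly per day, which may not hold for every date.

```python
from datetime import date, timedelta
from typing import Callable, List, Sequence


def nightly_tags(start: date, end: date) -> List[str]:
    """Daily tags in this thread's format, e.g. '2.4.0-dev20200731', start..end inclusive."""
    days = (end - start).days
    return [f"2.4.0-dev{start + timedelta(days=d):%Y%m%d}" for d in range(days + 1)]


def first_bad(builds: Sequence[str], is_broken: Callable[[str], bool]) -> str:
    """Binary search for the first broken build.

    Assumes builds are ordered oldest to newest, builds[0] is good,
    builds[-1] is bad, and every build after the first bad one is also bad.
    """
    lo, hi = 0, len(builds) - 1  # invariant: builds[lo] is good, builds[hi] is bad
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if is_broken(builds[mid]):
            hi = mid
        else:
            lo = mid
    return builds[hi]


# Hypothetical usage: pretend the regression landed on 2020-08-15.
builds = nightly_tags(date(2020, 7, 31), date(2020, 8, 30))
print(first_bad(builds, lambda b: b >= "2.4.0-dev20200815"))  # 2.4.0-dev20200815
```

For the 31 daily builds in that window, this needs about five reproducer runs instead of up to thirty.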
Will return to the issue later. @tirthasheshpatel @junpenglao
It has something to do with `vectorized_map` and `tf.function`; however, I am struggling to find a minimal reproducible example.
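The traceback points at a `GatherV2_grad` inside the `pfor` part of `value_and_gradients`. `pfor` is the machinery behind `vectorized_map`, and that gradient comes back as an `IndexedSlices` rather than a dense tensor. If the sparse gradient is the root cause, one direction that might sidestep the mismatch is to densify it before it re-enters the leapfrog `tf.while_loop`. In TensorFlow, `tf.convert_to_tensor` accepts an `IndexedSlices` and returns the dense equivalent. This is an untested assumption, not a fix confirmed in this thread. The pure-Python sketch below only illustrates the densification semantics: values are scattered into zeros at their indices, and duplicates are summed. `densify` is a hypothetical helper that handles 1-D gradients only.

```python
from typing import List, Sequence


def densify(indices: Sequence[int], values: Sequence[float], dense_length: int) -> List[float]:
    """Dense 1-D equivalent of a sparse (indices, values) gradient.

    Mirrors IndexedSlices semantics: positions not listed are zero, and
    values that share an index are summed.
    """
    dense = [0.0] * dense_length
    for i, v in zip(indices, values):
        dense[i] += v
    return dense


# In the traceback, indices, values and dense_shape all have shape (1,):
# one chain, one scalar parameter. The numbers here are made up.
print(densify([0], [0.5], 1))  # [0.5]
print(densify([2, 0, 2], [1.0, 2.0, 3.0], 4))  # [2.0, 0.0, 4.0, 0.0]
```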