pymc-devs/pymc4

Sampling fails when `num_chains=1`

tirthasheshpatel opened this issue · 5 comments

While debugging #309, I found that sampling fails when `num_chains=1` on the latest TensorFlow nightly and TensorFlow Probability nightly. Here's a minimal reproducible example:

>>> import pymc4 as pm
>>> @pm.model
... def model():
...  x = yield pm.Normal("x", 0., 1.)
...  return x
...
>>> m = model()
>>> pm.sample(m, num_chains=1, num_samples=10, burn_in=10)

Error:

Auto-assigning NUTS sampler
Traceback (most recent call last):
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\util\nest.py", line 403, in assert_same_structure
    _pywrap_utils.AssertSameStructure(nest1, nest2, check_types,
ValueError: The two structures don't have the same nested structure.

First structure: type=list str=[<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/add:0' shape=() dtype=int32>, [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/add_1:0' shape=(1,) dtype=float32>], [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/add:0' shape=(1,) dtype=float32>], <tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/pfor/Tile:0' shape=(1,) dtype=float32>, [<tensorflow.python.framework.indexed_slices.IndexedSlices object at 0x000001E28DA833D0>]]

Second structure: type=list str=[<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/iter:0' shape=() dtype=int32>, [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/add:0' shape=(1,) dtype=float32>], [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/Placeholder_6:0' shape=(1,) dtype=float32>], <tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/Placeholder_7:0' shape=(1,) dtype=float32>, [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/Placeholder_8:0' shape=(1,) dtype=float32>]]

More specifically: Substructure "type=IndexedSlices str=IndexedSlices(indices=Tensor("mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/gradients/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/loop_body/GatherV2_grad/Reshape_1:0", shape=(1,), dtype=int32), values=Tensor("mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/gradients/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/loop_body/GatherV2_grad/Reshape:0", shape=(1,), dtype=float32), 
dense_shape=Tensor("mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/gradients/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/loop_body/GatherV2_grad/Cast:0", shape=(1,), dtype=int32))" is a sequence, while substructure "type=Tensor str=Tensor("mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/Placeholder_8:0", shape=(1,), dtype=float32)" is not

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\pymc4\inference\sampling.py", line 168, in sample
    return sampler(
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\pymc4\mcmc\samplers.py", line 225, in __call__
    return self._sample(*args, **kwargs)
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\pymc4\mcmc\samplers.py", line 132, in _sample
    results, sample_stats = self._run_chains(init_state, burn_in)
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\eager\def_function.py", line 786, in __call__
    result = self._call(*args, **kwds)
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\eager\def_function.py", line 829, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\eager\def_function.py", line 716, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\eager\function.py", line 2955, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\eager\function.py", line 3351, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\eager\function.py", line 3190, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\framework\func_graph.py", line 987, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\eager\def_function.py", line 625, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\eager\function.py", line 3875, in bound_method_wrapper
    return wrapped_fn(weak_instance(), *args, **kwargs)
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\pymc4\mcmc\samplers.py", line 172, in _run_chains
    results, sample_stats = mcmc.sample_chain(
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\sample.py", line 361, in sample_chain
    (_, _, final_kernel_results), (all_states, trace) = mcmc_util.trace_scan(
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\internal\util.py", line 460, in trace_scan
    _, final_state, _, trace_arrays = tf.while_loop(
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\util\deprecation.py", line 574, in new_func
    return func(*args, **kwargs)
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2489, in while_loop_v2
    return while_loop(
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2687, in while_loop
    return while_v2.while_loop(
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\while_v2.py", line 188, in while_loop
    body_graph = func_graph_module.func_graph_from_py_func(
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\framework\func_graph.py", line 987, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\while_v2.py", line 174, in wrapped_body
    outputs = body(*_pack_sequence_as(orig_loop_vars, args))
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\internal\util.py", line 450, in _body
    state = loop_fn(state, elem)
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\sample.py", line 354, in _trace_scan_fn
    seed, next_state, current_kernel_results = mcmc_util.smart_for_loop(
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\internal\util.py", line 349, in smart_for_loop
    return tf.while_loop(
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\util\deprecation.py", line 574, in new_func
    return func(*args, **kwargs)
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2489, in while_loop_v2
    return while_loop(
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2687, in while_loop
    return while_v2.while_loop(
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\while_v2.py", line 188, in while_loop
    body_graph = func_graph_module.func_graph_from_py_func(
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\framework\func_graph.py", line 987, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\while_v2.py", line 174, in wrapped_body
    outputs = body(*_pack_sequence_as(orig_loop_vars, args))
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\internal\util.py", line 351, in <lambda>
    body=lambda i, *args: [i + 1] + list(body_fn(*args)),
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\sample.py", line 351, in _seeded_one_step
    kernel.one_step(*state_and_results, **one_step_kwargs))
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\dual_averaging_step_size_adaptation.py", line 456, in one_step
    new_state, new_inner_results = self.inner_kernel.one_step(
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\nuts.py", line 419, in one_step
    _, _, _, new_step_metastate = tf.while_loop(
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\util\deprecation.py", line 574, in new_func
    return func(*args, **kwargs)
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2489, in while_loop_v2
    return while_loop(
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2687, in while_loop
    return while_v2.while_loop(
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\while_v2.py", line 188, in while_loop
    body_graph = func_graph_module.func_graph_from_py_func(
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\framework\func_graph.py", line 987, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\while_v2.py", line 174, in wrapped_body
    outputs = body(*_pack_sequence_as(orig_loop_vars, args))
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\nuts.py", line 423, in <lambda>
    body=lambda iter_, seed, state, metastate: self._loop_tree_doubling(  # pylint: disable=g-long-lambda
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\nuts.py", line 597, in _loop_tree_doubling
    ] = self._build_sub_tree(
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\nuts.py", line 777, in _build_sub_tree
    ] = tf.while_loop(
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\util\deprecation.py", line 574, in new_func
    return func(*args, **kwargs)
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2489, in while_loop_v2
    return while_loop(
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2687, in while_loop
    return while_v2.while_loop(
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\while_v2.py", line 188, in while_loop
    body_graph = func_graph_module.func_graph_from_py_func(
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\framework\func_graph.py", line 987, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\while_v2.py", line 174, in wrapped_body
    outputs = body(*_pack_sequence_as(orig_loop_vars, args))
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\nuts.py", line 785, in <lambda>
    self._loop_build_sub_tree(
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\nuts.py", line 838, in _loop_build_sub_tree
    ] = integrator(prev_tree_state.momentum,
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow_probability\python\mcmc\internal\leapfrog_integrator.py", line 282, in __call__
    ] = tf.while_loop(
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\util\deprecation.py", line 574, in new_func
    return func(*args, **kwargs)
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2489, in while_loop_v2
    return while_loop(
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2687, in while_loop
    return while_v2.while_loop(
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\while_v2.py", line 188, in while_loop
    body_graph = func_graph_module.func_graph_from_py_func(
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\framework\func_graph.py", line 987, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\ops\while_v2.py", line 179, in wrapped_body
    nest.assert_same_structure(list(outputs), list(orig_loop_vars),
  File "C:\Users\tirth\Desktop\INTERESTS\PyMC4\env\lib\site-packages\tensorflow\python\util\nest.py", line 408, in assert_same_structure
    raise type(e)("%s\n"
ValueError: The two structures don't have the same nested structure.

First structure: type=list str=[<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/add:0' shape=() dtype=int32>, [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/add_1:0' shape=(1,) dtype=float32>], [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/add:0' shape=(1,) dtype=float32>], <tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/pfor/Tile:0' shape=(1,) dtype=float32>, [<tensorflow.python.framework.indexed_slices.IndexedSlices object at 0x000001E28DA833D0>]]

Second structure: type=list str=[<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/iter:0' shape=() dtype=int32>, [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/add:0' shape=(1,) dtype=float32>], [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/Placeholder_6:0' shape=(1,) dtype=float32>], <tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/Placeholder_7:0' shape=(1,) dtype=float32>, [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/Placeholder_8:0' shape=(1,) dtype=float32>]]

More specifically: Substructure "type=IndexedSlices str=IndexedSlices(indices=Tensor("mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/gradients/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/loop_body/GatherV2_grad/Reshape_1:0", shape=(1,), dtype=int32), values=Tensor("mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/gradients/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/loop_body/GatherV2_grad/Reshape:0", shape=(1,), dtype=float32), 
dense_shape=Tensor("mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/gradients/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/loop_build_sub_tree/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/loop_body/GatherV2_grad/Cast:0", shape=(1,), dtype=int32))" is a sequence, while substructure "type=Tensor str=Tensor("mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/NoUTurnSampler/.one_step/while/loop_tree_doubling/build_sub_tree/while/Placeholder_8:0", shape=(1,), dtype=float32)" is not
Entire first structure:
[., [.], [.], ., [.]]
Entire second structure:
[., [.], [.], ., [.]]
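The innermost message says the leapfrog loop body returned a `tf.IndexedSlices` (coming from a `GatherV2` gradient) where the matching loop variable is a dense `tf.Tensor`, so the structure check in `tf.while_loop` fails. As an illustration only, assuming TensorFlow 2.x in eager mode, the sketch below shows how the gradient of a `tf.gather` comes back as `tf.IndexedSlices` and how `tf.convert_to_tensor` turns it into a dense tensor. It is a sketch of the mechanism, not a confirmed fix for this issue.

```python
import tensorflow as tf

x = tf.constant([1.0, 2.0, 3.0])

with tf.GradientTape() as tape:
    tape.watch(x)  # constants are not watched automatically
    # Only x[0] is gathered, so the gradient w.r.t. x is sparse.
    y = tf.reduce_sum(tf.square(tf.gather(x, [0])))

grad = tape.gradient(y, x)
# The gradient of GatherV2 is an IndexedSlices, not a dense Tensor;
# nest treats IndexedSlices as a composite (a "sequence") and a Tensor
# as a leaf, which is exactly the mismatch reported above.
print(type(grad).__name__)

# Densifying yields a plain Tensor with the same shape as x.
dense_grad = tf.convert_to_tensor(grad)
print(dense_grad.numpy())  # [2. 0. 0.]
```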

Versions

  • TensorFlow nightly: 2.4.0-dev20200828
  • TensorFlow Probability nightly: 0.12.0-dev20200830
  • NumPy: 1.19.0

@rrkarim, could you take a look?

Everything seems to work fine with tf==2.4.0-dev20200731 and tfp==0.12.0-dev20200830 (the newest tfp nightly). Checking the newest tf nightly now.

[UPD] It breaks on tf==2.4.0-dev20200830.

So the issue is caused by this commit, not by the recent discrete/compound merge (just checked). Sampling with rwm also works fine, so I suspect the problem lies in the leapfrog integration.
Will return to the issue later. @tirthasheshpatel @junpenglao
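Narrowing a regression between a known-good nightly (tf 2.4.0-dev20200731) and a known-bad one (tf 2.4.0-dev20200830) can be done by bisecting over build dates. The helper below is a hypothetical sketch: `nightly_dates`, `first_bad`, and the `is_bad` callback are illustrative names, where `is_bad` would, for example, install that day's nightly and rerun the `pm.sample` reproducer. It assumes a build exists for every calendar day and that the breakage persists once introduced.

```python
from datetime import date, timedelta
from typing import Callable, List


def nightly_dates(start: date, end: date) -> List[date]:
    """All calendar dates from start to end, inclusive (hypothetical helper)."""
    return [start + timedelta(days=i) for i in range((end - start).days + 1)]


def first_bad(candidates: List[date], is_bad: Callable[[date], bool]) -> date:
    """Binary search for the first bad build.

    Assumes candidates[0] is known good, candidates[-1] is known bad, and
    every build after the first bad one is also bad.
    """
    if len(candidates) < 2:
        raise ValueError("need at least one good and one bad candidate")
    lo, hi = 0, len(candidates) - 1  # invariant: lo is good, hi is bad
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if is_bad(candidates[mid]):
            hi = mid
        else:
            lo = mid
    return candidates[hi]


# Usage with a stand-in predicate using an arbitrary cutoff date; a real
# check would install the nightly for day `d` and run the reproducer.
dates = nightly_dates(date(2020, 7, 31), date(2020, 8, 30))
culprit = first_bad(dates, lambda d: d >= date(2020, 8, 15))
print(culprit)  # 2020-08-15
```

Each step halves the remaining range, so the month-long window above needs about five builds to be tested instead of thirty.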

It has something to do with `vectorized_map` and `tf.function`; however, I am struggling to find a minimal reproducible example.
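One possible starting point for isolating the `vectorized_map`/`tf.function` interaction is to take a gradient through `tf.vectorized_map` inside a `tf.function` with one chain and with two chains, then compare. This is a hypothetical harness: the toy `log_prob` stands in for the model's log density, and it has not been confirmed to reproduce the failure on the affected nightlies. On a build without the regression it should simply return the same per-chain values for both batch sizes.

```python
import tensorflow as tf


def log_prob(x):
    # Toy stand-in for a per-chain log density: standard normal, up to a constant.
    return -0.5 * tf.reduce_sum(tf.square(x))


@tf.function
def value_and_grad(states):
    with tf.GradientTape() as tape:
        tape.watch(states)
        # One log-prob per chain; the leading axis plays the role of num_chains.
        lp = tf.vectorized_map(log_prob, states)
    grad = tape.gradient(lp, states)
    # vectorized_map slices its input with a gather, so the gradient may
    # arrive as tf.IndexedSlices; densify it before returning.
    return lp, tf.convert_to_tensor(grad)


results = {}
for num_chains in (1, 2):
    states = tf.ones([num_chains, 1])  # shape (num_chains, 1)
    lp, grad = value_and_grad(states)
    results[num_chains] = (lp.numpy(), grad.numpy())
    print(num_chains, results[num_chains])
```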