TypeError: sub got incompatible shapes for broadcasting
CCOSerika opened this issue · 3 comments
CCOSerika commented
Hi! I ran the code with the synthetic dataset successfully, and the results look perfect. But when it comes to my own real360 dataset, it cannot run and raises the following exception:
jax._src.traceback_util.UnfilteredStackTrace: TypeError: sub got incompatible shapes for broadcasting: (256, 3), (256, 4).
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
Do you know how to fix it? Thanks in advance.
CCOSerika commented
The complete error messages are as follows.
Traceback (most recent call last):
File "stage1.py", line 1538, in <module>
threshold
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/_src/api.py", line 1676, in f_pmapped
global_arg_shapes=tuple(global_arg_shapes_flat))
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/core.py", line 1620, in bind
return call_bind(self, fun, *args, **params)
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/core.py", line 1551, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/core.py", line 1623, in process
return trace.process_map(self, fun, tracers, params)
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/core.py", line 606, in process_call
return primitive.impl(f, *tracers, **params)
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/interpreters/pxla.py", line 628, in xla_pmap_impl
*abstract_args)
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/linear_util.py", line 262, in memoized_fun
ans = call(fun, *args)
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/interpreters/pxla.py", line 713, in parallel_callable
fun, global_sharded_avals, pe.debug_info_final(fun, "pmap"))
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1284, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1262, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "stage1.py", line 1467, in train_step
(total_loss, color_loss_l2), grad = grad_fn(state.target)
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/_src/api.py", line 904, in value_and_grad_f
f_partial, *dyn_args, has_aux=True, reduce_axes=reduce_axes)
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/_src/api.py", line 2002, in _vjp
flat_fun, primals_flat, has_aux=True, reduce_axes=reduce_axes)
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/interpreters/ad.py", line 117, in vjp
out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/interpreters/ad.py", line 102, in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 505, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "stage1.py", line 1449, in loss_fn
loss_color_l2 = np.mean(np.square(rgb_est - pixels))
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/core.py", line 504, in __sub__
def __sub__(self, other): return self.aval._sub(self, other)
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/_src/numpy/lax_numpy.py", line 5869, in deferring_binary_op
return binary_op(self, other)
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/_src/numpy/lax_numpy.py", line 421, in <lambda>
fn = lambda x1, x2: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2))
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/_src/lax/lax.py", line 344, in sub
return sub_p.bind(x, y)
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/core.py", line 264, in bind
out = top_trace.process_primitive(self, tracers, params)
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/interpreters/ad.py", line 283, in process_primitive
primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/_src/lax/lax.py", line 2713, in _sub_jvp
primal_out = sub(x, y)
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/_src/lax/lax.py", line 344, in sub
return sub_p.bind(x, y)
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/core.py", line 264, in bind
out = top_trace.process_primitive(self, tracers, params)
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1059, in process_primitive
out_avals = primitive.abstract_eval(*avals, **params)
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/_src/lax/lax.py", line 2125, in standard_abstract_eval
return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/_src/lax/lax.py", line 2221, in _broadcasting_shape_rule
raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: sub got incompatible shapes for broadcasting: (256, 3), (256, 4).
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "stage1.py", line 1538, in <module>
threshold
File "stage1.py", line 1467, in train_step
(total_loss, color_loss_l2), grad = grad_fn(state.target)
File "stage1.py", line 1449, in loss_fn
loss_color_l2 = np.mean(np.square(rgb_est - pixels))
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/_src/numpy/lax_numpy.py", line 5869, in deferring_binary_op
return binary_op(self, other)
File "/home/vt/anaconda3/envs/mobilenerf/lib/python3.6/site-packages/jax/_src/numpy/lax_numpy.py", line 421, in <lambda>
fn = lambda x1, x2: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2))
TypeError: sub got incompatible shapes for broadcasting: (256, 3), (256, 4).
CCOSerika commented
jax version: 0.2.17
jaxlib version: 0.1.65_cu111