Error while fine-tuning wav2vec2 with BART
arvindmn01 opened this issue · 1 comment
arvindmn01 commented
I tried to fine-tune a wav2vec2 model together with a BART model on my custom dataset using the following command:
python run_flax_speech_recognition_seq2seq.py ...
However, I got the following error:
main()
File "run_flax_speech_recognition_seq2seq.py", line 1189, in main
state, train_metric = p_train_step(state, batch)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/api.py", line 1779, in cache_miss
execute = pxla.xla_pmap_impl_lazy(fun_, *tracers, **params)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 411, in xla_pmap_impl_lazy
compiled_fun, fingerprint = parallel_callable(
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
ans = call(fun, *args)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 678, in parallel_callable
pmap_computation = lower_parallel_callable(
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 825, in lower_parallel_callable
jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 748, in stage_parallel_callable
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2233, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "run_flax_speech_recognition_seq2seq.py", line 1051, in train_step
"encoder_grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm["encoder"])),
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/pjit.py", line 250, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/pjit.py", line 158, in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/api.py", line 306, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/pjit.py", line 505, in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/pjit.py", line 971, in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
ans = call(fun, *args)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/pjit.py", line 924, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/numpy/linalg.py", line 486, in norm
check_arraylike("jnp.linalg.norm", x)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/numpy/util.py", line 328, in check_arraylike
raise TypeError(msg.format(fun_name, type(arg), pos))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: jnp.linalg.norm requires ndarray or scalar arguments, got <class 'list'> at position 0.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "run_flax_speech_recognition_seq2seq.py", line 1266, in <module>
main()
File "run_flax_speech_recognition_seq2seq.py", line 1189, in main
state, train_metric = p_train_step(state, batch)
File "run_flax_speech_recognition_seq2seq.py", line 1051, in train_step
"encoder_grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm["encoder"])),
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/numpy/linalg.py", line 486, in norm
check_arraylike("jnp.linalg.norm", x)
File "/home/arvind/wav2vec2_env/lib/python3.8/site-packages/jax/_src/numpy/util.py", line 328, in check_arraylike
raise TypeError(msg.format(fun_name, type(arg), pos))
TypeError: jnp.linalg.norm requires ndarray or scalar arguments, got <class 'list'> at position 0.
The installed jax version is jax==0.4.13.
The installed jaxlib version is jaxlib==0.4.13.
The installed flax version is flax==0.7.2.
sanchit-gandhi commented
Hey @arvindmn01 - could you provide a reproducible code snippet for this error?