Random `TypeError` with NumPy array with all values -100
versae opened this issue · 3 comments
I'm hitting this error message now and then. It does not seem to be affecting training, but I only see it when training on TPU. The same dataset was used on GPU with no errors. Just posting here in case there is something else going on that I am missing.
```
Step... (75000/759160 | Eval Loss: 0.11205478757619858 | Eval wer: 0.09877239458498131 | Eval cer: 0.02933955305671511 |): 10%|████▏ | 4/40 [23:39:27<205:34:14, 20557.06s/it]
/data/flax/lib/python3.8/site-packages/transformers/tokenization_utils_base.py:719: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.
  tensor = as_tensor(value)
--- Logging error ---
Traceback (most recent call last):
File "run_flax_speech_recognition_ctc.py", line 1631, in <module>
main()
File "run_flax_speech_recognition_ctc.py", line 1544, in main
state, train_metric = p_train_step(state, batch)
File "/data/flax/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/data/flax/lib/python3.8/site-packages/jax/_src/api.py", line 2158, in cache_miss
out_tree, out_flat = f_pmapped_(*args, **kwargs)
File "/data/flax/lib/python3.8/site-packages/jax/_src/api.py", line 2031, in pmap_f
p = _prepare_pmap(
File "/data/flax/lib/python3.8/site-packages/jax/_src/api.py", line 1969, in _prepare_pmap
_check_arg(arg)
File "/data/flax/lib/python3.8/site-packages/jax/_src/api.py", line 2994, in _check_arg
raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type.")
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Argument '[[-100 -100]
[-100 -100]
[-100 -100]
[-100 -100]
[-100 -100]
[-100 -100]
[-100 -100]
[-100 -100]]' of type <class 'numpy.ndarray'> is not a valid JAX type.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "run_flax_speech_recognition_ctc.py", line 1544, in main
state, train_metric = p_train_step(state, batch)
TypeError: Argument '[[-100 -100]
[-100 -100]
[-100 -100]
[-100 -100]
[-100 -100]
[-100 -100]
[-100 -100]
[-100 -100]]' of type <class 'numpy.ndarray'> is not a valid JAX type.
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/lib/python3.8/logging/__init__.py", line 1085, in emit
msg = self.format(record)
File "/usr/lib/python3.8/logging/__init__.py", line 929, in format
return fmt.format(record)
File "/usr/lib/python3.8/logging/__init__.py", line 668, in format
record.message = record.getMessage()
File "/usr/lib/python3.8/logging/__init__.py", line 373, in getMessage
msg = msg % self.args
TypeError: not all arguments converted during string formatting
Call stack:
File "run_flax_speech_recognition_ctc.py", line 1631, in <module>
main()
File "run_flax_speech_recognition_ctc.py", line 1546, in main
logger.warning("Encountered following error: \n", e)
Message: 'Encountered following error: \n'
Arguments: (TypeError("Argument '[[-100 -100]\n [-100 -100]\n [-100 -100]\n [-100 -100]\n [-100 -100]\n [-100 -100]\n [-100 -100]\n [-100 -100]]' of type <class 'numpy.ndarray'> is not a valid JAX type."),)
```
It looks as though there are no valid training labels in the batch (all labels are equal to the padding mask idx and overridden to -100 in the data collator). The fact that this only occurs for this batch, and only on TPU, suggests it's a JAX bug! I'll try to reproduce by saving the numpy array to disk and forcing it through a jit/pmap.
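For anyone following along, here is a minimal sketch of that check. The object-dtype variant is my assumption, prompted by the `VisibleDeprecationWarning` in the log; it is not the exact batch from this run:

```python
import numpy as np
import jax

# An all -100 label block like the one in the traceback: a plain int array
# of this shape is a perfectly valid JAX argument on its own.
labels = np.full((8, 2), -100, dtype=np.int32)

# Hypothetical failure mode: if the collator ever hands over an object-dtype
# (ragged) array instead, JAX rejects it at the jit/pmap boundary.
labels_obj = labels.astype(object)

identity = jax.jit(lambda x: x)
identity(labels)      # fine
identity(labels_obj)  # raises a TypeError ("... is not a valid JAX type")
```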
I see. Could it be a tokenization issue then? I might have used `do_lower_case` in this training, as pointed out in #23.
For CTC, you can set `max_labels_length=1024` and this should bypass the error. The error is (likely) occurring because the target sequence is longer than `max_labels_length` and is thus being truncated.
Let me know if this doesn't work and we can dig into this further.
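To make the truncation point concrete, here is a toy sketch of the label pipeline as described above. The helper name, pad id, and token ids are illustrative, not the training script's actual code (which masks via the attention mask rather than the pad id):

```python
import numpy as np

def collate_labels(label_ids, max_labels_length, pad_token_id=0):
    """Truncate/pad each label sequence, then mask padding with -100 (the
    convention that produced the all -100 rows in the traceback above)."""
    batch = np.full((len(label_ids), max_labels_length), pad_token_id, dtype=np.int32)
    for i, ids in enumerate(label_ids):
        kept = ids[:max_labels_length]   # targets longer than the cap are cut here
        batch[i, : len(kept)] = kept
    return np.where(batch == pad_token_id, -100, batch)

# A generous cap (e.g. 1024) keeps long transcriptions intact; a small cap
# silently drops their tail.
print(collate_labels([[11, 12, 13], [21]], max_labels_length=4))
# [[  11   12   13 -100]
#  [  21 -100 -100 -100]]
```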