Fix warning for numpy ndarray creation in BaseModel.py
Closed this issue · 1 comment
TalonCB commented
Warning: ReChorus/src/models/BaseModel.py:147: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.
Location: BaseModel.py line 147
Reason: in recent NumPy versions, np.array() requires dtype=object to be specified explicitly when the nested arrays have different lengths.
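A minimal sketch of the behaviour behind the warning, using made-up sample arrays: with dtype=object, a ragged list becomes a 1-D object array that holds each sub-array unchanged, while equal-length arrays still stack into a regular 2-D array. The built-in object is used instead of np.object, because that alias was deprecated in NumPy 1.20 and removed in 1.24. Also from NumPy 1.24 onward, omitting dtype=object for ragged input raises a ValueError instead of a warning.

```python
import numpy as np

# Two samples whose arrays have different lengths (a "ragged" batch).
ragged = [np.array([1, 2, 3]), np.array([4, 5])]

# dtype=object yields a 1-D array whose elements are the original sub-arrays.
# The built-in `object` works on every NumPy version; `np.object` does not.
obj_arr = np.array(ragged, dtype=object)
print(obj_arr.shape, obj_arr.dtype)  # (2,) object

# Equal-length samples stack into a regular 2-D numeric array, no dtype needed.
uniform = [np.array([1, 2, 3]), np.array([4, 5, 6])]
print(np.array(uniform).shape)  # (2, 3)
```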
Possible solution: replace line 147 with:
if isinstance(feed_dicts[0][key], np.ndarray):
    # Length of this field in every sample of the batch
    tmp_list = [len(d[key]) for d in feed_dicts]
    if any(l != tmp_list[0] for l in tmp_list):
        # Ragged lengths: request an object array explicitly
        # (use the built-in `object`; the `np.object` alias is deprecated)
        stack_val = np.array([d[key] for d in feed_dicts], dtype=object)
    else:
        stack_val = np.array([d[key] for d in feed_dicts])
else:
    stack_val = np.array([d[key] for d in feed_dicts])
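For trying the suggested logic outside ReChorus, here is a self-contained sketch. It wraps the per-key check in a hypothetical collate_batch function. The surrounding code of BaseModel.py is not shown in this issue, so the function name, its signature and the example keys user_id and item_id are illustrative assumptions, not the project's actual API.

```python
import numpy as np
from typing import Any, Dict, List


def collate_batch(feed_dicts: List[Dict[str, Any]]) -> Dict[str, np.ndarray]:
    """Hypothetical helper: merge per-sample feed dicts into one batch dict.

    Array-valued fields whose lengths differ across samples are stored in a
    1-D object array; everything else is stacked into a regular ndarray.
    """
    batch: Dict[str, np.ndarray] = {}
    for key in feed_dicts[0]:
        values = [d[key] for d in feed_dicts]
        if isinstance(values[0], np.ndarray):
            lengths = [len(v) for v in values]
            if any(n != lengths[0] for n in lengths):
                # Ragged field: dtype=object keeps each sample's array intact.
                batch[key] = np.array(values, dtype=object)
                continue
        # Scalars or equal-length arrays: stack into a regular array.
        batch[key] = np.array(values)
    return batch


if __name__ == "__main__":
    # Illustrative samples; the keys are assumptions, not taken from the issue.
    samples = [
        {"user_id": 1, "item_id": np.array([10, 11, 12])},
        {"user_id": 2, "item_id": np.array([20, 21])},
    ]
    batch = collate_batch(samples)
    print(batch["user_id"].shape, batch["item_id"].dtype)  # (2,) object
```

The explicit length check keeps the fast, regular stacking path for the common equal-length case. The object-array fallback applies only when a field is actually ragged.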
THUwangcy commented
Thanks for the nice suggestion! We have fixed this and will update in the next commit.