THUwangcy/ReChorus

Fix warning for numpy ndarray creation in BaseModel.py

Closed · 1 comment

Warning: ReChorus/src/models/BaseModel.py:147: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.

Location: BaseModel.py line 147

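Reason: creating an array from ragged nested sequences has been deprecated since NumPy 1.19 and raises a ValueError starting with NumPy 1.24. When the sub-arrays have different lengths, np.array() must therefore be given dtype=object explicitly. The np.object alias is itself deprecated and was removed in NumPy 1.24, so the builtin object should be used instead.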
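The difference can be reproduced in isolation. The following minimal sketch (with made-up toy arrays, not ReChorus data) shows that passing dtype=object turns a ragged batch into a 1-D object array, while equal-length sub-arrays stack into a regular 2-D array.

```python
import numpy as np

# Two sub-arrays of different lengths: a "ragged" nested sequence.
ragged = [np.array([1, 2, 3]), np.array([4, 5])]

# Explicit dtype=object: NumPy builds a 1-D array of length 2 whose
# elements are the original sub-arrays.
obj = np.array(ragged, dtype=object)
print(obj.shape, obj.dtype)  # (2,) object

# Equal lengths: the default call stacks into a regular 2-D integer array.
regular = np.array([np.array([1, 2]), np.array([3, 4])])
print(regular.shape)  # (2, 2)

# Without dtype=object, the ragged case raises ValueError on NumPy >= 1.24;
# NumPy 1.19-1.23 only emit VisibleDeprecationWarning and return an object array.
try:
    np.array(ragged)
except ValueError as err:
    print("ValueError:", err)
```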

Possible Solution: replace line 147 with

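if isinstance(feed_dicts[0][key], np.ndarray):
    # Lengths of the per-sample arrays stored under this key
    tmp_list = [len(d[key]) for d in feed_dicts]
    if any(l != tmp_list[0] for l in tmp_list):
        # Ragged batch: NumPy requires an explicit object dtype
        # (use the builtin `object`; the `np.object` alias no longer exists)
        stack_val = np.array([d[key] for d in feed_dicts], dtype=object)
    else:
        # Equal lengths: stack into a regular numeric array
        stack_val = np.array([d[key] for d in feed_dicts])
else:
    # Scalar values (e.g. IDs): plain stacking
    stack_val = np.array([d[key] for d in feed_dicts])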
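The same logic can be pulled into a standalone function for testing outside BaseModel.py. This is a minimal sketch under assumptions: the name stack_values and the toy keys item_id and user_id are hypothetical, not part of ReChorus. As a design variant, it compares full array shapes rather than only len(), and fills the object array element by element, because np.array(..., dtype=object) can still fail when sub-arrays share their first dimension but differ in later ones.

```python
import numpy as np
from typing import Any, Dict, List


def stack_values(feed_dicts: List[Dict[str, Any]], key: str) -> np.ndarray:
    """Hypothetical helper: stack feed_dicts[i][key] into one array."""
    values = [d[key] for d in feed_dicts]
    if isinstance(values[0], np.ndarray):
        shapes = {v.shape for v in values}
        if len(shapes) > 1:
            # Ragged batch: keep each sub-array as one element of a 1-D
            # object array. Element-wise assignment avoids NumPy trying to
            # broadcast sub-arrays that share a leading dimension.
            out = np.empty(len(values), dtype=object)
            for i, v in enumerate(values):
                out[i] = v
            return out
    # Uniform shapes (or non-array values such as ints): regular stacking.
    return np.array(values)


# Usage with a toy batch (hypothetical keys).
batch = [
    {"item_id": np.array([1, 2, 3]), "user_id": 7},
    {"item_id": np.array([4, 5]), "user_id": 8},
]
print(stack_values(batch, "item_id").shape)  # (2,)
print(stack_values(batch, "user_id"))        # [7 8]
```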

Thanks for the nice suggestion! We have fixed this and will update in the next commit.