jsikyoon/dreamer-torch

Bug when setting config.precision=16

momanto opened this issue · 1 comment

There is a small bug when setting config.precision=16 in this call:

```python
env = wrappers.CollectDataset(env, callbacks)
```

This call does not pass config.precision, so the collected data is still stored at the default 32-bit precision. I think it should be:

```python
env = wrappers.CollectDataset(env, callbacks, config.precision)
```

Specifically, if config.precision is not passed, CollectDataset falls back to its default precision=32, so floating-point values are converted with dtype=np.float32 instead of the expected np.float16:

dreamer-torch/wrappers.py, lines 225 to 230 in 7c2331a:

```python
class CollectDataset:
  def __init__(self, env, callbacks=None, precision=32):
    self._env = env
    self._callbacks = callbacks or ()
    self._precision = precision
```

dreamer-torch/wrappers.py, lines 271 to 274 in 7c2331a:

```python
if np.issubdtype(value.dtype, np.floating):
  dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self._precision]
elif np.issubdtype(value.dtype, np.signedinteger):
  dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision]
```
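To illustrate the effect, here is a minimal self-contained sketch that mirrors the two snippets above. `CollectDatasetSketch` is a hypothetical stand-in, not the repo's class: it reproduces only the constructor default and the float/signed-integer branches of the dtype selection, and the fallback that leaves other dtypes unchanged is a simplification assumed here.

```python
import numpy as np


class CollectDatasetSketch:
    """Hypothetical stand-in for wrappers.CollectDataset.

    Only the precision handling is reproduced; environment wrapping and
    callbacks are not.
    """

    def __init__(self, env=None, callbacks=None, precision=32):
        self._env = env
        self._callbacks = callbacks or ()
        self._precision = precision  # defaults to 32 when the caller omits it

    def _convert(self, value):
        value = np.array(value)
        if np.issubdtype(value.dtype, np.floating):
            dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self._precision]
        elif np.issubdtype(value.dtype, np.signedinteger):
            dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision]
        else:
            # Simplification (assumption): other dtypes such as uint8 or bool
            # are kept unchanged in this sketch.
            dtype = value.dtype
        return value.astype(dtype)


obs = np.zeros(3, dtype=np.float64)

# As in the original call: precision is not passed, so the default of 32 is used.
buggy = CollectDatasetSketch(None, None)
# As in the proposed fix: the configured precision (16) is forwarded.
fixed = CollectDatasetSketch(None, None, 16)

print(buggy._convert(obs).dtype)  # float32
print(fixed._convert(obs).dtype)  # float16
```

Because precision is an optional positional argument with a default, omitting it raises no error; the data simply ends up at 32-bit precision, which is why the mismatch is easy to miss.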