lightaime/TensorAgent

About adding batch normalization

zhiyuanyou opened this issue

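Hello!
Thank you very much for open-sourcing your code! It has helped me a great deal in understanding DDPG.
When I used your code to run DDPG in a new environment, the output actions were always at the action bounds. After looking into it, I believe batch normalization could solve this problem.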
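For context, one common way DDPG actions end up pinned to the bounds is tanh saturation. This is a minimal sketch that *assumes* the actor's last layer is a tanh scaled by the action bound, which is a typical DDPG design; the actor code itself is not shown in this issue, and `actor_output`, `tanh_grad` and `ACTION_BOUND` are hypothetical names. Large pre-activations push the action to ±bound, and the tanh gradient there is nearly zero, so the actor has little signal to move back.

```python
import math


# Hypothetical actor head: ASSUMES the final layer is a tanh scaled by the
# action bound (common in DDPG; the repo's actor code is not shown here).
def actor_output(pre_activation, action_bound):
    return action_bound * math.tanh(pre_activation)


# Derivative of tanh with respect to its input: 1 - tanh(z)^2.
# For large |z| this is close to 0, so gradients through the actor vanish.
def tanh_grad(pre_activation):
    return 1.0 - math.tanh(pre_activation) ** 2


ACTION_BOUND = 2.0  # hypothetical action bound for illustration
for z in [0.5, 3.0, 10.0, -10.0]:
    print(f"z={z:6.1f}  action={actor_output(z, ACTION_BOUND):+.4f}  "
          f"d(tanh)/dz={tanh_grad(z):.2e}")
```

Normalizing the inputs to the last layer (for example with batch normalization) is one way people try to keep these pre-activations in the non-saturated range.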
So I added batch normalization between the layers using the tf.contrib.layers.batch_norm function. But after this change, I got a long list of errors:

Traceback (most recent call last):
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1356, in _do_call
    return fn(*args)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1341, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1429, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.InternalError: 2 root error(s) found.
  (0) Internal: cuDNN launch failure : input shape ([1,100,1,1])
	 [[{{node critic_net/ddpg/critic_net/cond_1/batch_norm_2/FusedBatchNorm}}]]
	 [[critic_net/ddpg/critic_net/q_output/Relu/_67]]
  (1) Internal: cuDNN launch failure : input shape ([1,100,1,1])
	 [[{{node critic_net/ddpg/critic_net/cond_1/batch_norm_2/FusedBatchNorm}}]]
0 successful operations.
0 derived errors ignored.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "run_ddpg.py", line 168, in <module>
    main()
  File "run_ddpg.py", line 78, in main
    action_without_clip, q = agent.select_action(state, p)
  File "/home/youzhiyuan/Desktop/TensorAgent_single/agent/ddpg.py", line 21, in select_action
    pred_action, pred_q = self.predict_action(observation)
  File "/home/youzhiyuan/Desktop/TensorAgent_single/agent/ddpg.py", line 14, in predict_action
    return self.model.predict_action_q(observation)
  File "/home/youzhiyuan/Desktop/TensorAgent_single/model/ddpg_model.py", line 79, in predict_action_q
    q = self.critic.predict_q_source_net(observation, action, sess)
  File "/home/youzhiyuan/Desktop/TensorAgent_single/model/ddpg_critic.py", line 210, in predict_q_source_net
    self.input_action: feed_action})
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 950, in run
    run_metadata_ptr)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1173, in _run
    feed_dict_tensor, options, run_metadata)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1350, in _do_run
    run_metadata)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1370, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InternalError: 2 root error(s) found.
  (0) Internal: cuDNN launch failure : input shape ([1,100,1,1])
	 [[node critic_net/ddpg/critic_net/cond_1/batch_norm_2/FusedBatchNorm (defined at /home/youzhiyuan/Desktop/TensorAgent_single/model/ddpg_critic.py:259) ]]
	 [[critic_net/ddpg/critic_net/q_output/Relu/_67]]
  (1) Internal: cuDNN launch failure : input shape ([1,100,1,1])
	 [[node critic_net/ddpg/critic_net/cond_1/batch_norm_2/FusedBatchNorm (defined at /home/youzhiyuan/Desktop/TensorAgent_single/model/ddpg_critic.py:259) ]]
0 successful operations.
0 derived errors ignored.

Original stack trace for 'critic_net/ddpg/critic_net/cond_1/batch_norm_2/FusedBatchNorm':
  File "run_ddpg.py", line 168, in <module>
    main()
  File "run_ddpg.py", line 53, in main
    tau=TAU)
  File "/home/youzhiyuan/Desktop/TensorAgent_single/model/ddpg_model.py", line 52, in __init__
    sess=self.sess)
  File "/home/youzhiyuan/Desktop/TensorAgent_single/model/ddpg_critic.py", line 53, in __init__
    self.q_output = self.__create_critic_network()
  File "/home/youzhiyuan/Desktop/TensorAgent_single/model/ddpg_critic.py", line 99, in __create_critic_network
    activation=tf.nn.relu)
  File "/home/youzhiyuan/Desktop/TensorAgent_single/model/ddpg_critic.py", line 260, in batch_norm_layer
    lambda: batch_norm(x, activation_fn=activation, center=True, scale=True,
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 1977, in cond
    orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 1814, in BuildCondBranch
    original_result = fn()
  File "/home/youzhiyuan/Desktop/TensorAgent_single/model/ddpg_critic.py", line 259, in <lambda>
    scope=scope_bn, decay=0.9, epsilon=1e-5),
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/contrib/framework/python/ops/arg_scope.py", line 182, in func_with_args
    return func(*args, **current_args)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/contrib/layers/python/layers/layers.py", line 596, in batch_norm
    scope=scope)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/contrib/layers/python/layers/layers.py", line 383, in _fused_batch_norm
    _fused_batch_norm_inference)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/contrib/layers/python/layers/utils.py", line 214, in smart_cond
    return static_cond(pred_value, fn1, fn2)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/contrib/layers/python/layers/utils.py", line 192, in static_cond
    return fn1()
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/contrib/layers/python/layers/layers.py", line 368, in _fused_batch_norm_training
    inputs, gamma, beta, epsilon=epsilon, data_format=data_format)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/ops/nn_impl.py", line 1329, in fused_batch_norm
    name=name)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/ops/gen_nn_ops.py", line 3946, in _fused_batch_norm
    name=name)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py", line 788, in _apply_op_helper
    op_def=op_def)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 3616, in create_op
    op_def=op_def)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 2005, in __init__
    self._traceback = tf_stack.extract_stack()
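The traceback shows that the failing node, batch_norm_2/FusedBatchNorm, was built by `_fused_batch_norm_training` inside a `tf.cond` branch, and that it ran during `select_action`, where the input shape is `[1,100,1,1]`, which suggests a batch of one observation with 100 features. For reference, here is a minimal pure-Python sketch of what batch normalization computes in training mode versus inference mode. It is *not* the `tf.contrib.layers.batch_norm` implementation; the function names are hypothetical, and the `decay=0.9` / `eps=1e-5` defaults only mirror the arguments visible in the traceback.

```python
import math


def batch_norm_train(batch, gamma, beta, eps=1e-5):
    """Training-mode BN over a [batch, features] list of rows.

    Each feature is normalized with the mean and (biased) variance of the
    current mini-batch, then scaled by gamma and shifted by beta.
    """
    n = len(batch)
    num_features = len(batch[0])
    out = [[0.0] * num_features for _ in range(n)]
    for j in range(num_features):
        col = [row[j] for row in batch]
        mean = sum(col) / n
        var = sum((x - mean) ** 2 for x in col) / n
        for i in range(n):
            out[i][j] = gamma[j] * (batch[i][j] - mean) / math.sqrt(var + eps) + beta[j]
    return out


def batch_norm_infer(batch, gamma, beta, moving_mean, moving_var, eps=1e-5):
    """Inference-mode BN: uses stored moving statistics, not batch statistics."""
    return [[gamma[j] * (x - moving_mean[j]) / math.sqrt(moving_var[j] + eps) + beta[j]
             for j, x in enumerate(row)] for row in batch]


def update_moving(moving, batch_stat, decay=0.9):
    """Exponential moving average of a statistic, updated during training."""
    return [decay * m + (1 - decay) * s for m, s in zip(moving, batch_stat)]


single_obs = [[3.0, -7.0]]  # a batch of ONE observation with two features
gamma, beta = [1.0, 1.0], [0.0, 0.0]
print(batch_norm_train(single_obs, gamma, beta))  # -> [[0.0, 0.0]]
print(batch_norm_infer(single_obs, gamma, beta, [0.0, 0.0], [1.0, 1.0]))
```

With a batch of one, training-mode normalization subtracts each sample from its own mean, so every feature collapses to `beta` regardless of the input. This is why single-observation action selection is normally run with the layer in inference mode (moving statistics) rather than training mode.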

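My system is Ubuntu 16.04, and I am using TensorFlow 1.14.0.
I have been stuck on this problem for several days and still have not been able to solve it. I would appreciate it if you could help when you have time.
Thank you!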