AssertionError when visualizing trained models
Closed this issue · 1 comment
gcruiser commented
Hi! Thanks for the excellent work!
After training a model with train_all_copo_dist.py and trying to visualize it, I found that the dimensions of the observations (obs) and the weights do not match. The checkpoints in <new_best_checkpoints> have the same problem. This happens in every environment I trained, and the number of columns in obs (obs.shape[1] == 96) is always one less than the number of rows in the weight matrix (weights[s].shape[0] == 97).
However, I can visualize the <new_best_checkpoints> with new_vis.py, which is weird.
Traceback (most recent call last):
  File "/home/CoPO_tf/CoPO/copo_code/copo/vis_from_checkpoint.py", line 91, in <module>
    action = policy_function(o, d)
  File "/home/CoPO_tf/CoPO/copo_code/copo/eval/get_policy_function.py", line 170, in __call__
    actions = self.policy(obs_to_be_eval)
  File "/home/CoPO_tf/CoPO/copo_code/copo/vis_from_checkpoint.py", line 72, in policy
    ret = policy_class(weights, obs, policy_name=policy_name, layer_name_suffix="_1", deterministic=deterministic)
  File "/home/CoPO_tf/CoPO/copo_code/copo/eval/get_policy_function.py", line 61, in _compute_actions_for_tf_policy
    assert obs.shape[1] == weights[s].shape[0], (obs.shape, weights[s].shape)
AssertionError: ((20, 96), (97, 256))
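The failing assertion compares the width of the observation batch with the input width of a dense layer's kernel. The minimal NumPy sketch below uses the shapes from the traceback to show why a (20, 96) batch cannot be multiplied by a (97, 256) kernel. The helper `check_input_width` is hypothetical and is not the function in get_policy_function.py. It only mirrors the condition in the assert.

```python
import numpy as np


def check_input_width(obs: np.ndarray, kernel: np.ndarray) -> bool:
    """Return True if a batch of observations can be fed to a dense layer.

    obs has shape (batch_size, obs_dim); kernel has shape (in_dim, out_dim).
    obs @ kernel is only defined when obs_dim == in_dim, which is the same
    condition as `obs.shape[1] == weights[s].shape[0]` in the traceback.
    """
    return obs.shape[1] == kernel.shape[0]


obs = np.zeros((20, 96), dtype=np.float32)      # 20 observation rows, 96 features each
kernel = np.zeros((97, 256), dtype=np.float32)  # layer with 97 inputs -> 256 units

print(check_input_width(obs, kernel))  # False: 96 features vs. 97 expected inputs

try:
    obs @ kernel  # NumPy refuses the product because the inner dimensions differ
except ValueError as err:
    print("matmul fails:", err)
```

The "off by one" is therefore a missing input feature. It is not a transposed matrix: obs needs one more column, not one fewer row.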
pengzhenghao commented
The problem must come from the missing LCF (local coordination factor) dimension. In CoPO, we add one extra dimension to the obs to include the social orientation.
Best,
Zhenghao
(Email reply to gcruiser's post of Mar 8, 2024, 04:06.)
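Following this reply, one way to make a 96-column observation batch fit the 97-input network is to append the LCF value as an extra feature before the batch reaches the policy. This is a minimal sketch under stated assumptions: it assumes the LCF is appended as the *last* column and is already encoded exactly as during training (for example, degrees vs. a normalized value). Both assumptions need to be checked against the CoPO training code. `append_lcf` is a hypothetical helper and is not part of the repository.

```python
from typing import Union

import numpy as np


def append_lcf(obs: np.ndarray, lcf: Union[float, np.ndarray]) -> np.ndarray:
    """Append one LCF feature to each observation row (hypothetical helper).

    ASSUMPTION: the trained network expects the LCF as the last input feature,
    encoded the same way as during training. Verify this before relying on it.

    obs: array of shape (batch_size, obs_dim).
    lcf: a scalar shared by all rows, or an array of shape (batch_size,).
    Returns an array of shape (batch_size, obs_dim + 1).
    """
    obs = np.asarray(obs, dtype=np.float32)
    # Broadcast a scalar (or per-row values) to one value per observation row.
    lcf_col = np.broadcast_to(np.asarray(lcf, dtype=np.float32), (obs.shape[0],))
    # Add the LCF as a new trailing column.
    return np.concatenate([obs, lcf_col[:, None]], axis=1)


obs = np.zeros((20, 96), dtype=np.float32)
obs_with_lcf = append_lcf(obs, 0.0)  # placeholder LCF value, not a recommended setting
print(obs_with_lcf.shape)  # (20, 97)
```

Broadcasting lets the same helper accept either one shared LCF or a separate value per agent. Which one matches the trained policy is again something to confirm in the training code.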