Bug: Task_3 test 2 fails for any Q initialization greater than or equal to 100
Opened this issue · 0 comments
Jesperoka commented
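Test 2 fails if Q is initialized to any value >= 100, because the value 100 written for the optimal action is then not the strict maximum among the Q-values already present for that state.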
Rendered:
# Test 2
# Checks how often evaluate_policy returns a designated optimal action in
# state (5, 5) over many draws, against the reference frequency
# 1 - epsilon + epsilon / actions computed below.
Q1 = Q
key = random.PRNGKey(758493)
opt_action = 2
# Zero the action values for state (5, 5) before marking the optimal action,
# so that 100 is the strict maximum there no matter how Q was initialized
# (previously the test failed whenever Q was initialized to >= 100).
Q1 = Q1.at[5, 5].set(0)
Q1 = Q1.at[5, 5, opt_action].set(100)
count_greedy = 0
count_random = 0
test_count = 10000
pbar = tqdm(total=test_count)
for _ in range(test_count):
    # The key returned by evaluate_policy is fed into the next call.
    a, key = evaluate_policy(Q1, jnp.array([5, 5]), SARSA_PARAMS, key)
    if a == opt_action:
        count_greedy += 1
    else:
        count_random += 1
    pbar.update(1)
pbar.close()
# Reference frequency for the optimal action: 1 - epsilon, plus the
# epsilon / actions share of the remaining draws.
count_greedy_ref = (
    1 - SARSA_PARAMS["epsilon"] + SARSA_PARAMS["epsilon"] * 1 / SARSA_PARAMS["actions"]
)
# Accept a +/-10% relative deviation from the reference frequency.
assert (
    0.9 * count_greedy_ref
    < count_greedy / test_count
    < 1.1 * count_greedy_ref
)