tud-phi/ics-pa-sv

Bug: Task_3 test 2 fails for any Q initialization of 100 or greater


Bug in:
https://github.com/tud-cor-sr/ics-pa-sv/blob/ca94a78cdb3c2cfdd9f3475e7d6869a6d83205ae/assignment/problem_3/task_3_rl.ipynb?short_path=28a759c#L1219-L1247

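Test 2 fails whenever Q is initialized to any value >= 100. The test hard-codes Q1[5, 5, opt_action] = 100 and assumes this makes opt_action the greedy action in state (5, 5). That only holds if every other entry of Q1[5, 5] is below 100. With a uniform initialization of exactly 100 all actions tie, and an argmax-based policy returns the first tied index (action 0). With an initialization above 100, the other actions are strictly better. In both cases evaluate_policy returns opt_action only through random exploration, so count_greedy / test_count lands far below count_greedy_ref and the assertion fails.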
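A minimal sketch of the failure, assuming evaluate_policy picks its greedy action with jnp.argmax over Q[5, 5] and that Q is initialized uniformly. GRID_SHAPE and N_ACTIONS are hypothetical stand-ins for the notebook's values (the real action count is SARSA_PARAMS["actions"]), and greedy_after_override is an illustrative helper, not part of the notebook.

```python
import jax.numpy as jnp

N_ACTIONS = 4          # hypothetical; the notebook uses SARSA_PARAMS["actions"]
GRID_SHAPE = (10, 10)  # hypothetical grid size
OPT_ACTION = 2


def greedy_after_override(init_value):
    """Return the argmax action in state (5, 5) after Test 2's hard-coded override."""
    Q = jnp.full(GRID_SHAPE + (N_ACTIONS,), init_value)  # uniform initialization
    Q1 = Q.at[5, 5, OPT_ACTION].set(100)                 # same override as the test
    # jnp.argmax returns the index of the first maximum, so ties go to action 0
    return int(jnp.argmax(Q1[5, 5]))


for init in (0.0, 100.0, 150.0):
    print(init, greedy_after_override(init))
```

With an initialization of 0.0 the override makes action 2 greedy; with 100.0 or 150.0 the greedy action is 0, so action 2 is picked only by the random branch of epsilon-greedy.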

Rendered:

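# Test 2
import jax.numpy as jnp
from jax import random
from tqdm import tqdm

# Q, evaluate_policy and SARSA_PARAMS are defined in earlier cells of the notebook
Q1 = Q

key = random.PRNGKey(758493)

opt_action = 2
Q1 = Q1.at[5, 5, opt_action].set(100)  # <<< only greedy if every other Q1[5, 5] entry is < 100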
count_greedy = 0
count_random = 0
test_count = 10000

pbar = tqdm(total=test_count)
for i in range(test_count):
    a, key = evaluate_policy(Q1, jnp.array([5, 5]), SARSA_PARAMS, key)
    if a == opt_action:
        count_greedy = count_greedy + 1
    else:
        count_random = count_random + 1
    pbar.update(1)
pbar.close()

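# Expected fraction of greedy picks under epsilon-greedy: 1 - eps + eps / |A|
count_greedy_ref = (
    1 - SARSA_PARAMS["epsilon"] + SARSA_PARAMS["epsilon"] * 1 / SARSA_PARAMS["actions"]
)

# Accept the empirical fraction if it lies within +/-10% of the reference
assert (
    count_greedy / test_count > 0.9 * count_greedy_ref
    and count_greedy / test_count < 1.1 * count_greedy_ref
)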
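One possible fix, sketched with a hypothetical helper (set_optimal_action is not in the notebook): derive the override from the current values of Q1[5, 5] instead of hard-coding 100, so opt_action is the strict argmax for any initialization. The grid size and action count used below are illustrative assumptions.

```python
import jax.numpy as jnp


def set_optimal_action(Q, state, opt_action, margin=1.0):
    """Hypothetical helper: make opt_action the strict argmax of Q at `state`.

    Writes max(Q[state]) + margin instead of a fixed constant, so the result
    no longer depends on how Q was initialized.
    """
    x, y = state
    target = jnp.max(Q[x, y]) + margin
    return Q.at[x, y, opt_action].set(target)


n_actions = 4  # hypothetical; the notebook uses SARSA_PARAMS["actions"]
for init in (-5.0, 0.0, 100.0, 150.0):
    Q = jnp.full((10, 10, n_actions), init)  # hypothetical 10x10 grid
    Q1 = set_optimal_action(Q, (5, 5), opt_action=2)
    print(init, int(jnp.argmax(Q1[5, 5])))  # greedy action is 2 for every init
```

In the test cell this would reduce to a one-line change, Q1 = Q1.at[5, 5, opt_action].set(Q1[5, 5].max() + 1), leaving count_greedy_ref and the assertion untouched.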