Loading model fails in tutorial
velezbeltran opened this issue · 4 comments
Hello!
I have been working with the ACDC_Main_Demo.ipynb notebook, and I am currently facing an issue: when I attempt to load the model from a saved subgraph, I get an error. In particular, I take the following steps.
- I run the notebook as is, adding one line that saves the subgraph with the command below
exp.save_subgraph(
    save_path,
    return_it=True,
)
- I run all of the cells except for the cell containing the code
for i in range(args.max_num_epochs):
    exp.step(testing=False)
    show(
        exp.corr,
        f"ims/img_new_{i+1}.png",
        show_full_index=False,
    )
    if IN_COLAB or ipython is not None:
        # so long as we're not running this as a script, show the image!
        display(Image(f"ims/img_new_{i+1}.png"))
    print(i, "-" * 50)
    print(exp.count_no_edges())
    if i == 0:
        exp.save_edges("edges.pkl")
    if exp.current_node is None or SINGLE_STEP:
        break
exp.save_edges("another_final_edges.pkl")
if USING_WANDB:
    edges_fname = f"edges.pth"
    exp.save_edges(edges_fname)
    artifact = wandb.Artifact(edges_fname, type="dataset")
    artifact.add_file(edges_fname)
    wandb.log_artifact(artifact)
    os.remove(edges_fname)
    wandb.finish()
- I load the subgraph using
# load using torch
circuit = t.load(subgraph_path)
exp.load_subgraph(circuit)
If I do this I get the following assertion error:
AssertionError: Ensure that the dictionary includes exactly the correct keys... e.g missing [('blocks.1.hook_q_input', (None, None, 0), 'blocks.0.attn.hook_result', (None, None, 1))] and has excess stuff []
What could be causing this? Am I doing something wrong? Alternatively, what is the standard way of loading in circuits?
Also, if I do run the cell that contains the .step() method, I don't have this issue.
Thank you!
Nicolas
Possibly the TransformerLens version you're using is different from the one that was used to save the hypothesis, so the hook names are different. What's the list of edges from exp.corr.all_edges().keys()?
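One way to make that comparison concrete is to diff the two key sets directly. The following is only a sketch: `diff_edge_keys` is a hypothetical helper (not part of the ACDC codebase), and it assumes both the saved subgraph and the experiment's edges can be reduced to hashable tuples of the shape shown in the assertion message.

```python
def diff_edge_keys(expected_keys, saved_keys):
    """Return (missing, excess) edge keys between two collections.

    missing: keys the experiment expects but the saved file lacks.
    excess:  keys the saved file has but the experiment does not know.
    """
    expected, saved = set(expected_keys), set(saved_keys)
    # Sort by repr because the tuples contain None and cannot be ordered directly.
    missing = sorted(expected - saved, key=repr)
    excess = sorted(saved - expected, key=repr)
    return missing, excess


# Illustrative keys only; a real graph has many more edges.
expected = [
    ("blocks.1.hook_q_input", (None, None, 6), "blocks.0.hook_resid_pre", (None,)),
    ("blocks.1.hook_q_input", (None, None, 0), "blocks.0.attn.hook_result", (None, None, 1)),
]
saved = {expected[0]: True}  # stand-in for the dictionary read back with t.load

missing, excess = diff_edge_keys(expected, saved)
print("missing:", missing)
print("excess:", excess)
```

If `excess` is empty and `missing` is not, the hook names agree and the saved file simply contains fewer edges than the experiment expects.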
Thanks for your lightning-fast response!
Before running the `.step()` function block
dict_keys([('blocks.1.hook_resid_post', [:], 'blocks.1.attn.hook_result', [:, :, 7]), ('blocks.1.hook_resid_post', [:],
'blocks.1.attn.hook_result', [:, :, 6]), ('blocks.1.hook_resid_post', [:], 'blocks.1.attn.hook_result', [:, :, 5]),
('blocks.1.hook_resid_post', [:], 'blocks.1.attn.hook_result', [:, :, 4]), ('blocks.1.hook_resid_post', [:],
'blocks.1.attn.hook_result', [:, :, 3]), ('blocks.1.hook_resid_post', [:], 'blocks.1.attn.hook_result', [:, :, 2]),
('blocks.1.hook_resid_post', [:], 'blocks.1.attn.hook_result', [:, :, 1]), ('blocks.1.hook_resid_post', [:],
'blocks.1.attn.hook_result', [:, :, 0]), ('blocks.1.hook_resid_post', [:], 'blocks.0.attn.hook_result', [:, :, 7]),
('blocks.1.hook_resid_post', [:], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_resid_post', [:],
'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_resid_post', [:], 'blocks.0.attn.hook_result', [:, :, 4]),
('blocks.1.hook_resid_post', [:], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_resid_post', [:],
'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_resid_post', [:], 'blocks.0.attn.hook_result', [:, :, 1]),
('blocks.1.hook_resid_post', [:], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_resid_post', [:],
'blocks.0.hook_resid_pre', [:]), ('blocks.1.attn.hook_result', [:, :, 7], 'blocks.1.attn.hook_q', [:, :, 7]),
('blocks.1.attn.hook_result', [:, :, 7], 'blocks.1.attn.hook_k', [:, :, 7]), ('blocks.1.attn.hook_result', [:, :, 7],
'blocks.1.attn.hook_v', [:, :, 7]), ('blocks.1.attn.hook_result', [:, :, 6], 'blocks.1.attn.hook_q', [:, :, 6]),
('blocks.1.attn.hook_result', [:, :, 6], 'blocks.1.attn.hook_k', [:, :, 6]), ('blocks.1.attn.hook_result', [:, :, 6],
'blocks.1.attn.hook_v', [:, :, 6]), ('blocks.1.attn.hook_result', [:, :, 5], 'blocks.1.attn.hook_q', [:, :, 5]),
('blocks.1.attn.hook_result', [:, :, 5], 'blocks.1.attn.hook_k', [:, :, 5]), ('blocks.1.attn.hook_result', [:, :, 5],
'blocks.1.attn.hook_v', [:, :, 5]), ('blocks.1.attn.hook_result', [:, :, 4], 'blocks.1.attn.hook_q', [:, :, 4]),
('blocks.1.attn.hook_result', [:, :, 4], 'blocks.1.attn.hook_k', [:, :, 4]), ('blocks.1.attn.hook_result', [:, :, 4],
'blocks.1.attn.hook_v', [:, :, 4]), ('blocks.1.attn.hook_result', [:, :, 3], 'blocks.1.attn.hook_q', [:, :, 3]),
('blocks.1.attn.hook_result', [:, :, 3], 'blocks.1.attn.hook_k', [:, :, 3]), ('blocks.1.attn.hook_result', [:, :, 3],
'blocks.1.attn.hook_v', [:, :, 3]), ('blocks.1.attn.hook_result', [:, :, 2], 'blocks.1.attn.hook_q', [:, :, 2]),
('blocks.1.attn.hook_result', [:, :, 2], 'blocks.1.attn.hook_k', [:, :, 2]), ('blocks.1.attn.hook_result', [:, :, 2],
'blocks.1.attn.hook_v', [:, :, 2]), ('blocks.1.attn.hook_result', [:, :, 1], 'blocks.1.attn.hook_q', [:, :, 1]),
('blocks.1.attn.hook_result', [:, :, 1], 'blocks.1.attn.hook_k', [:, :, 1]), ('blocks.1.attn.hook_result', [:, :, 1],
'blocks.1.attn.hook_v', [:, :, 1]), ('blocks.1.attn.hook_result', [:, :, 0], 'blocks.1.attn.hook_q', [:, :, 0]),
('blocks.1.attn.hook_result', [:, :, 0], 'blocks.1.attn.hook_k', [:, :, 0]), ('blocks.1.attn.hook_result', [:, :, 0],
'blocks.1.attn.hook_v', [:, :, 0]), ('blocks.1.attn.hook_q', [:, :, 7], 'blocks.1.hook_q_input', [:, :, 7]),
('blocks.1.attn.hook_q', [:, :, 6], 'blocks.1.hook_q_input', [:, :, 6]),
('blocks.1.attn.hook_q', [:, :, 5], 'blocks.1.hook_q_input', [:, :, 5]),
('blocks.1.attn.hook_q', [:, :, 4], 'blocks.1.hook_q_input', [:, :, 4]),
('blocks.1.attn.hook_q', [:, :, 3], 'blocks.1.hook_q_input', [:, :, 3]),
('blocks.1.attn.hook_q', [:, :, 2], 'blocks.1.hook_q_input', [:, :, 2]),
('blocks.1.attn.hook_q', [:, :, 1], 'blocks.1.hook_q_input', [:, :, 1]),
('blocks.1.attn.hook_q', [:, :, 0], 'blocks.1.hook_q_input', [:, :, 0]),
('blocks.1.attn.hook_k', [:, :, 7], 'blocks.1.hook_k_input', [:, :, 7]),
('blocks.1.attn.hook_k', [:, :, 6], 'blocks.1.hook_k_input', [:, :, 6]),
('blocks.1.attn.hook_k', [:, :, 5], 'blocks.1.hook_k_input', [:, :, 5]),
('blocks.1.attn.hook_k', [:, :, 4], 'blocks.1.hook_k_input', [:, :, 4]), ('blocks.1.attn.hook_k', [:, :, 3], 'blocks.1.hook_k_input', [:, :, 3]), ('blocks.1.attn.hook_k', [:, :, 2], 'blocks.1.hook_k_input', [:, :, 2]), ('blocks.1.attn.hook_k', [:, :, 1], 'blocks.1.hook_k_input', [:, :, 1]), ('blocks.1.attn.hook_k', [:, :, 0], 'blocks.1.hook_k_input', [:, :, 0]), ('blocks.1.attn.hook_v', [:, :, 7], 'blocks.1.hook_v_input', [:, :, 7]), ('blocks.1.attn.hook_v', [:, :, 6], 'blocks.1.hook_v_input', [:, :, 6]), ('blocks.1.attn.hook_v', [:, :, 5], 'blocks.1.hook_v_input', [:, :, 5]), ('blocks.1.attn.hook_v', [:, :, 4], 'blocks.1.hook_v_input', [:, :, 4]), ('blocks.1.attn.hook_v', [:, :, 3], 'blocks.1.hook_v_input', [:, :, 3]), ('blocks.1.attn.hook_v', [:, :, 2], 'blocks.1.hook_v_input', [:, :, 2]), ('blocks.1.attn.hook_v', [:, :, 1], 'blocks.1.hook_v_input', [:, :, 1]), ('blocks.1.attn.hook_v', [:, :, 0], 'blocks.1.hook_v_input', [:, :, 0]), ('blocks.1.hook_q_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_q_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_q_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_q_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_q_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_q_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_q_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_q_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_q_input', [:, :, 7], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_q_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_q_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_q_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_q_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_q_input', [:, :, 6], 
'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_q_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_q_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_q_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_q_input', [:, :, 6], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_q_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_q_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_q_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_q_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_q_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_q_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_q_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_q_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_q_input', [:, :, 5], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_q_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_q_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_q_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_q_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_q_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_q_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_q_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_q_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_q_input', [:, :, 4], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_q_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_q_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_q_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_q_input', [:, :, 
3], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_q_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_q_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_q_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_q_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_q_input', [:, :, 3], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_q_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_q_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_q_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_q_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_q_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_q_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_q_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_q_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_q_input', [:, :, 2], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_q_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_q_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_q_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_q_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_q_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_q_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_q_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_q_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_q_input', [:, :, 1], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_q_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_q_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_q_input', 
[:, :, 0], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_q_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_q_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_q_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_q_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_q_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_q_input', [:, :, 0], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_k_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_k_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_k_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_k_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_k_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_k_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_k_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_k_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_k_input', [:, :, 7], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_k_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 7]), 
('blocks.1.hook_k_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_k_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_k_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_k_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_k_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_k_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_k_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_k_input', [:, :, 5], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_k_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_k_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_k_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_k_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_k_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_k_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_k_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_k_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_k_input', [:, :, 4], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_k_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_k_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_k_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_k_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_k_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_k_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_k_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_k_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_k_input', [:, :, 3], 'blocks.0.hook_resid_pre', 
[:]), ('blocks.1.hook_k_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_k_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_k_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_k_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_k_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_k_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_k_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_k_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_k_input', [:, :, 2], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_k_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_k_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_k_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_k_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_k_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_k_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_k_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_k_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_k_input', [:, :, 1], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_k_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_k_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_k_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_k_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_k_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_k_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_k_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_k_input', [:, :, 0], 
'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_k_input', [:, :, 0], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_v_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_v_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_v_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_v_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_v_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_v_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_v_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_v_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_v_input', [:, :, 7], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_v_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_v_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_v_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_v_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_v_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_v_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_v_input', [:, :, 
5], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_v_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_v_input', [:, :, 5], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_v_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_v_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_v_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_v_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_v_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_v_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_v_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_v_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_v_input', [:, :, 4], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_v_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_v_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_v_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_v_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_v_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_v_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_v_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_v_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_v_input', [:, :, 3], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_v_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_v_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_v_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_v_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_v_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_v_input', 
[:, :, 2], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_v_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_v_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_v_input', [:, :, 2], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_v_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_v_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_v_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_v_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_v_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_v_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_v_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_v_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_v_input', [:, :, 1], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_v_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_v_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_v_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_v_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_v_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_v_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_v_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_v_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_v_input', [:, :, 0], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.attn.hook_result', [:, :, 7], 'blocks.0.attn.hook_q', [:, :, 7]), ('blocks.0.attn.hook_result', [:, :, 7], 'blocks.0.attn.hook_k', [:, :, 7]), ('blocks.0.attn.hook_result', [:, :, 7], 'blocks.0.attn.hook_v', [:, :, 7]), ('blocks.0.attn.hook_result', [:, :, 6], 'blocks.0.attn.hook_q', [:, :, 6]), 
('blocks.0.attn.hook_result', [:, :, 6], 'blocks.0.attn.hook_k', [:, :, 6]), ('blocks.0.attn.hook_result', [:, :, 6], 'blocks.0.attn.hook_v', [:, :, 6]), ('blocks.0.attn.hook_result', [:, :, 5], 'blocks.0.attn.hook_q', [:, :, 5]), ('blocks.0.attn.hook_result', [:, :, 5], 'blocks.0.attn.hook_k', [:, :, 5]), ('blocks.0.attn.hook_result', [:, :, 5], 'blocks.0.attn.hook_v', [:, :, 5]), ('blocks.0.attn.hook_result', [:, :, 4], 'blocks.0.attn.hook_q', [:, :, 4]), ('blocks.0.attn.hook_result', [:, :, 4], 'blocks.0.attn.hook_k', [:, :, 4]), ('blocks.0.attn.hook_result', [:, :, 4], 'blocks.0.attn.hook_v', [:, :, 4]), ('blocks.0.attn.hook_result', [:, :, 3], 'blocks.0.attn.hook_q', [:, :, 3]), ('blocks.0.attn.hook_result', [:, :, 3], 'blocks.0.attn.hook_k', [:, :, 3]), ('blocks.0.attn.hook_result', [:, :, 3], 'blocks.0.attn.hook_v', [:, :, 3]), ('blocks.0.attn.hook_result', [:, :, 2], 'blocks.0.attn.hook_q', [:, :, 2]), ('blocks.0.attn.hook_result', [:, :, 2], 'blocks.0.attn.hook_k', [:, :, 2]), ('blocks.0.attn.hook_result', [:, :, 2], 'blocks.0.attn.hook_v', [:, :, 2]), ('blocks.0.attn.hook_result', [:, :, 1], 'blocks.0.attn.hook_q', [:, :, 1]), ('blocks.0.attn.hook_result', [:, :, 1], 'blocks.0.attn.hook_k', [:, :, 1]), ('blocks.0.attn.hook_result', [:, :, 1], 'blocks.0.attn.hook_v', [:, :, 1]), ('blocks.0.attn.hook_result', [:, :, 0], 'blocks.0.attn.hook_q', [:, :, 0]), ('blocks.0.attn.hook_result', [:, :, 0], 'blocks.0.attn.hook_k', [:, :, 0]), ('blocks.0.attn.hook_result', [:, :, 0], 'blocks.0.attn.hook_v', [:, :, 0]), ('blocks.0.attn.hook_q', [:, :, 7], 'blocks.0.hook_q_input', [:, :, 7]), ('blocks.0.attn.hook_q', [:, :, 6], 'blocks.0.hook_q_input', [:, :, 6]), ('blocks.0.attn.hook_q', [:, :, 5], 'blocks.0.hook_q_input', [:, :, 5]), ('blocks.0.attn.hook_q', [:, :, 4], 'blocks.0.hook_q_input', [:, :, 4]), ('blocks.0.attn.hook_q', [:, :, 3], 'blocks.0.hook_q_input', [:, :, 3]), ('blocks.0.attn.hook_q', [:, :, 2], 'blocks.0.hook_q_input', [:, :, 2]), 
('blocks.0.attn.hook_q', [:, :, 1], 'blocks.0.hook_q_input', [:, :, 1]), ('blocks.0.attn.hook_q', [:, :, 0], 'blocks.0.hook_q_input', [:, :, 0]), ('blocks.0.attn.hook_k', [:, :, 7], 'blocks.0.hook_k_input', [:, :, 7]), ('blocks.0.attn.hook_k', [:, :, 6], 'blocks.0.hook_k_input', [:, :, 6]), ('blocks.0.attn.hook_k', [:, :, 5], 'blocks.0.hook_k_input', [:, :, 5]), ('blocks.0.attn.hook_k', [:, :, 4], 'blocks.0.hook_k_input', [:, :, 4]), ('blocks.0.attn.hook_k', [:, :, 3], 'blocks.0.hook_k_input', [:, :, 3]), ('blocks.0.attn.hook_k', [:, :, 2], 'blocks.0.hook_k_input', [:, :, 2]), ('blocks.0.attn.hook_k', [:, :, 1], 'blocks.0.hook_k_input', [:, :, 1]), ('blocks.0.attn.hook_k', [:, :, 0], 'blocks.0.hook_k_input', [:, :, 0]), ('blocks.0.attn.hook_v', [:, :, 7], 'blocks.0.hook_v_input', [:, :, 7]), ('blocks.0.attn.hook_v', [:, :, 6], 'blocks.0.hook_v_input', [:, :, 6]), ('blocks.0.attn.hook_v', [:, :, 5], 'blocks.0.hook_v_input', [:, :, 5]), ('blocks.0.attn.hook_v', [:, :, 4], 'blocks.0.hook_v_input', [:, :, 4]), ('blocks.0.attn.hook_v', [:, :, 3], 'blocks.0.hook_v_input', [:, :, 3]), ('blocks.0.attn.hook_v', [:, :, 2], 'blocks.0.hook_v_input', [:, :, 2]), ('blocks.0.attn.hook_v', [:, :, 1], 'blocks.0.hook_v_input', [:, :, 1]), ('blocks.0.attn.hook_v', [:, :, 0], 'blocks.0.hook_v_input', [:, :, 0]), ('blocks.0.hook_q_input', [:, :, 7], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_q_input', [:, :, 6], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_q_input', [:, :, 5], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_q_input', [:, :, 4], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_q_input', [:, :, 3], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_q_input', [:, :, 2], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_q_input', [:, :, 1], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_q_input', [:, :, 0], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_k_input', [:, :, 7], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_k_input', [:, :, 6], 
'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_k_input', [:, :, 5], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_k_input', [:, :, 4], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_k_input', [:, :, 3], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_k_input', [:, :, 2], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_k_input', [:, :, 1], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_k_input', [:, :, 0], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_v_input', [:, :, 7], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_v_input', [:, :, 6], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_v_input', [:, :, 5], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_v_input', [:, :, 4], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_v_input', [:, :, 3], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_v_input', [:, :, 2], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_v_input', [:, :, 1], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_v_input', [:, :, 0], 'blocks.0.hook_resid_pre', [:])])
After running the `.step()` function block
dict_keys([('blocks.1.hook_resid_post', [:], 'blocks.1.attn.hook_result', [:, :, 6]), ('blocks.1.hook_resid_post', [:], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.attn.hook_result', [:, :, 6], 'blocks.1.attn.hook_q', [:, :, 6]), ('blocks.1.attn.hook_result', [:, :, 6], 'blocks.1.attn.hook_k', [:, :, 6]), ('blocks.1.attn.hook_result', [:, :, 6], 'blocks.1.attn.hook_v', [:, :, 6]), ('blocks.1.attn.hook_q', [:, :, 6], 'blocks.1.hook_q_input', [:, :, 6]), ('blocks.1.attn.hook_k', [:, :, 6], 'blocks.1.hook_k_input', [:, :, 6]), ('blocks.1.attn.hook_v', [:, :, 6], 'blocks.1.hook_v_input', [:, :, 6]), ('blocks.1.hook_q_input', [:, :, 6], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.attn.hook_result', [:, :, 0], 'blocks.0.attn.hook_q', [:, :, 0]), ('blocks.0.attn.hook_result', [:, :, 0], 'blocks.0.attn.hook_k', [:, :, 0]), ('blocks.0.attn.hook_result', [:, :, 0], 'blocks.0.attn.hook_v', [:, :, 0]), ('blocks.0.attn.hook_v', [:, :, 0], 'blocks.0.hook_v_input', [:, :, 0]), ('blocks.0.hook_v_input', [:, :, 0], 'blocks.0.hook_resid_pre', [:])])
I don't think the issue is a TransformerLens version mismatch, because I can reproduce all of this from the same notebook in Colab.
Thank you
It turns out the explanation is that the ACDC algorithm literally removes edges (i.e. removes them from the correspondence dictionaries) rather than setting edge.present = False. That makes loading fail, because the loader expects a key for every edge.
The loading code should be changed to fix this.
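A minimal sketch of what such a change could look like, assuming the saved subgraph is a plain dictionary from edge keys to booleans and that an edge absent from the file was pruned. `fill_missing_edges` is a hypothetical helper written for illustration, not the actual ACDC loading code, and the edge keys below are made up to match the format in the error message.

```python
from typing import Dict, Hashable, Iterable


def fill_missing_edges(
    all_edge_keys: Iterable[Hashable],
    saved: Dict[Hashable, bool],
) -> Dict[Hashable, bool]:
    """Return a present/absent flag for every edge of the full graph.

    Edges missing from `saved` are treated as pruned (False) instead of
    triggering an assertion. Unknown keys in `saved` still raise, since
    they point to a genuine mismatch such as different hook names.
    """
    all_keys = set(all_edge_keys)
    excess = set(saved) - all_keys
    if excess:
        example = sorted(excess, key=repr)[:1]
        raise KeyError(f"saved subgraph has unknown edges, e.g. {example}")
    return {key: saved.get(key, False) for key in all_keys}


# Illustrative keys only; a real graph has many more edges.
full_graph = [
    ("blocks.1.hook_q_input", (None, None, 6), "blocks.0.hook_resid_pre", (None,)),
    ("blocks.1.hook_q_input", (None, None, 0), "blocks.0.attn.hook_result", (None, None, 1)),
]
saved_after_pruning = {full_graph[0]: True}  # the second edge was removed by pruning

flags = fill_missing_edges(full_graph, saved_after_pruning)
print(flags[full_graph[0]], flags[full_graph[1]])
```

The resulting flags could then be copied onto the corresponding edges of a freshly built experiment. Whether a missing key should always mean "pruned" is a design choice: it is convenient here, but it also hides files that are incomplete for other reasons.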
@velezbeltran Would you be so kind as to share the working code for loading the subgraph edges from edges.pth for inference? I did not quite catch from @rhaps0dy what the modification should be, or where it should go. Thanks!