[Figure 3] - Linear Probe Accuracy
kalunho17 opened this issue · 1 comment
Hi,
First of all, great work and thanks for sharing the code!
I'd like to know how the linear probe accuracy in Figure 3 was obtained.
I tried to run the following command, as described in the README.md file:
```
python src/run.py --eval --n_embd 512 --n_head 8 --n_layer 24
```
I got around 10% accuracy, instead of the >90% reported in Figure 3 of the paper.
Thanks!
Thanks for the kind words!
During pre-training, the model minimizes only `gen_loss`, not `clf_loss`. Therefore, `wclf` is still the same random matrix it was at initialization, and the classification logits are meaningless, which explains the ~10% accuracy.
To get the linear probe numbers, you must first extract features (the variable `h` at a middle layer) for each example, and then train a logistic regression classifier on the training set that maps those features to the labels. I personally used scikit-learn's lbfgs solver.
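For reference, here is a minimal sketch of the probe-fitting step. It assumes you have already dumped the middle-layer features (e.g. `h` pooled over the sequence dimension) and the labels to `.npy` files; the file names and the pooling choice are placeholders, not something the repo provides out of the box:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Placeholder files: pre-extracted middle-layer features and labels.
# features: (num_examples, n_embd), labels: (num_examples,)
X_train = np.load("features_train.npy")
y_train = np.load("labels_train.npy")
X_test = np.load("features_test.npy")
y_test = np.load("labels_test.npy")

# Linear probe: multinomial logistic regression on frozen features.
probe = LogisticRegression(
    solver="lbfgs",   # the solver mentioned above
    max_iter=1000,    # lbfgs often needs extra iterations on high-dim features
    C=1.0,            # inverse regularization strength; worth tuning
)
probe.fit(X_train, y_train)
print("linear probe accuracy:", probe.score(X_test, y_test))
```

The main knob worth sweeping is the regularization strength `C`; which layer you take `h` from also matters, since middle layers tend to probe better than the last one.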