rasbt/machine-learning-book

detach error

MoroderMind opened this issue · 1 comment

Regarding this file:
https://github.com/rasbt/machine-learning-book/blob/main/ch12/ch12_part2.ipynb
Section: Model training via the torch.nn and torch.optim modules

--CODE BELOW--
1 y_pred = model(X_test_norm).detach().numpy()
2 fig = plt.figure(figsize=(13, 5))
3 ax = fig.add_subplot(1, 2, 1)
4 plt.plot(X_train_norm.detach().numpy(), y_train.detach().numpy(), 'o', markersize=10)
5 plt.plot(X_test_norm.detach().numpy(), y_pred.detach().numpy(), '--', lw=3)

--ERROR--
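Line 5: y_pred.detach().numpy() should just be y_pred, since we already detached it and converted it to a NumPy array on line 1. Calling .detach() on a NumPy array raises AttributeError: 'numpy.ndarray' object has no attribute 'detach'.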

Thanks for the note. What you are describing sounds correct. I remember changing this, though, because some readers had issues otherwise; maybe it was OS- or PyTorch-version related. I have to look into that again. Thanks for reporting!
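A minimal sketch of the fix, assuming a CPU setup and stand-in data: the tensors, the nn.Linear model, and the to_numpy helper below are hypothetical and not taken from the notebook. It shows that y_pred is already an ndarray after line 1, so y_pred.detach() fails. It also shows one defensive option for the version-dependent issues mentioned in the reply: a hypothetical to_numpy helper that detaches only when it receives a tensor and otherwise passes arrays through unchanged.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import torch


def to_numpy(x):
    """Hypothetical helper: return x as a NumPy array, tensor or not."""
    if isinstance(x, torch.Tensor):
        # detach from the autograd graph; .cpu() is a no-op for CPU tensors
        return x.detach().cpu().numpy()
    # np.asarray returns an existing ndarray unchanged (no copy)
    return np.asarray(x)


# Stand-in data and model (assumptions, not the notebook's actual values)
torch.manual_seed(1)
X_train_norm = torch.linspace(-1.0, 1.0, 10).reshape(-1, 1)
y_train = 2.0 * X_train_norm.squeeze() + 1.0
X_test_norm = torch.linspace(-1.0, 1.0, 100).reshape(-1, 1)
model = torch.nn.Linear(1, 1)

# Line 1 of the snippet: after this, y_pred is a NumPy array, not a tensor
y_pred = model(X_test_norm).detach().numpy()

# Line 5 as originally written fails: ndarray has no .detach() method
try:
    y_pred.detach()
    detach_failed = False
except AttributeError:
    detach_failed = True

fig = plt.figure(figsize=(13, 5))
ax = fig.add_subplot(1, 2, 1)
plt.plot(to_numpy(X_train_norm), to_numpy(y_train), 'o', markersize=10)
# y_pred is passed through unchanged (equivalent to just using y_pred)
plt.plot(to_numpy(X_test_norm), to_numpy(y_pred), '--', lw=3)
```

The helper is a design choice, not the book's code: plotting y_pred directly is the simplest fix, while to_numpy keeps each plotting line correct whether a variable holds a tensor or an array.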