ageron/handson-ml3

[QUESTION] Figure 4-8

momodz16 opened this issue · 0 comments

def plot_gradient_descent(theta, eta):
    """Run gradient descent from ``theta`` and plot the first few fitted lines.

    The function relies on names defined outside of it: the data ``X`` / ``y``,
    the matrices ``X_b`` and ``X_new_b``, the plotting inputs ``X_new``, and
    the ``mpl`` / ``plt`` matplotlib aliases.
    NOTE(review): ``X_b`` and ``X_new_b`` are presumably ``X`` and ``X_new``
    with a bias column prepended -- confirm against the defining cell.

    Args:
        theta: Initial parameter vector; it is not modified in place, a new
            array is bound on every update.
        eta: Learning rate that scales each gradient step; also shown in the
            plot title.

    Returns:
        A list holding the parameter vector obtained after each of the
        ``n_epochs`` updates, in order.
    """
    m = len(X_b)

    n_epochs = 1000
    n_shown = 20  # only the first epochs are drawn to keep the plot readable
    theta_path = []
    for epoch in range(n_epochs):
        if epoch < n_shown:
            # Draw the line predicted by the current parameters, *before*
            # they are updated; the colour advances along the OrRd colormap
            # with the epoch number.
            y_predict = X_new_b @ theta
            color = mpl.colors.rgb2hex(plt.cm.OrRd(epoch / n_shown + 0.15))
            plt.plot(X_new, y_predict, linestyle="solid", color=color)
        gradients = 2 / m * X_b.T @ (X_b @ theta - y)
        theta = theta - eta * gradients
        theta_path.append(theta)

    # Plot the dataset after the loop so the points are drawn on top of the
    # prediction lines and stay visible in the figure.
    plt.plot(X, y, "b.")

    plt.xlabel("$x_1$")
    plt.axis([0, 2, 0, 15])
    plt.grid()
    plt.title(fr"$\eta = {eta}$")
    return theta_path