[QUESTION] Figure 4-8
momodz16 opened this issue · 0 comments
momodz16 commented
def plot_gradient_descent(theta, eta):
    """Run gradient descent from ``theta`` and draw the early fits on the current axes.

    Relies on names defined outside this function: ``X``, ``y``, ``X_b``,
    ``X_new``, ``X_new_b``, ``plt`` and ``mpl``.
    NOTE(review): presumably ``X_b`` / ``X_new_b`` are ``X`` / ``X_new`` with
    a bias column prepended -- confirm against the defining cell.

    Args:
        theta: Starting parameter vector; it is rebound locally, never
            mutated in place.
        eta: Learning rate that scales each gradient step; also shown in
            the plot title.

    Returns:
        A list with one entry per epoch (``n_epochs`` entries), each being
        the value of ``theta`` right after that epoch's update.
    """
    m = len(X_b)
    n_epochs = 1000
    n_shown = 20  # only the first n_shown epochs get a prediction line
    theta_path = []
    for epoch in range(n_epochs):
        if epoch < n_shown:
            # Draw the model's current predictions before this epoch's step.
            y_predict = X_new_b @ theta
            # Colormap position grows with the epoch; the 0.15 offset skips
            # the very start of the OrRd map.
            color = mpl.colors.rgb2hex(plt.cm.OrRd(epoch / n_shown + 0.15))
            plt.plot(X_new, y_predict, linestyle="solid", color=color)
        # Full-batch gradient step: 2/m * X_b^T (X_b theta - y).
        gradients = 2 / m * X_b.T @ (X_b @ theta - y)
        theta = theta - eta * gradients
        theta_path.append(theta)
    # The data points are plotted after the loop so they are drawn on top of
    # the prediction lines and stay visible.
    plt.plot(X, y, "b.")
    plt.xlabel("$x_1$")
    plt.axis([0, 2, 0, 15])
    plt.grid()
    plt.title(fr"$\eta = {eta}$")
    return theta_path