MorvanZhou/train-classifier-from-scratch

可视化有问题

daidai21 opened this issue · 0 comments

在可视化时，你原来的代码绘图是不动的，只会显示第一次的结果。我改成下面这样就好了，但这貌似不是最优的方法，求指导。
```python
# training
accuracies, steps = [], []

# Class labels for the bar chart, in the order of the target columns.
class_names = ["accepted", "good", "unaccepted", "very good"]
n_classes = len(class_names)
class_pos = np.arange(n_classes)

# The test targets never change, so count them once instead of on every redraw.
# NOTE(review): assumes columns 21: of test_data hold the one-hot targets
# (one column per class) -- confirm against the data-preparation code.
target_counts = np.bincount(np.argmax(test_data[:, 21:], axis=1), minlength=n_classes)

# Create the figure ONCE and enable interactive mode so the same window is
# redrawn during training. Creating a new figure inside the loop (and calling
# the blocking plt.show() there) opens a new window every time and stalls
# training until that window is closed.
plt.ion()
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))

for t in range(4000):
    # training: run train_op once on 32 randomly sampled rows (with replacement)
    batch_index = np.random.randint(len(train_data), size=32)
    sess.run(train_op, {tf_input: train_data[batch_index]})

    if t % 50 == 0:
        # testing
        acc_, pred_, loss_ = sess.run([accuracy, prediction, loss], {tf_input: test_data})
        accuracies.append(acc_)
        steps.append(t)
        print("Step: %i" % t, "| Accurate: %.2f" % acc_, "| Loss: %.2f" % loss_)

        # visualize testing: redraw both panels inside the existing figure
        pred_counts = np.bincount(np.argmax(pred_, axis=1), minlength=n_classes)
        ax1.cla()
        bp = ax1.bar(class_pos + 0.1, pred_counts[:n_classes], width=0.2, color='red')
        bt = ax1.bar(class_pos - 0.1, target_counts[:n_classes], width=0.2, color='blue')
        # Set ticks and labels separately: before matplotlib 3.5 the second
        # positional argument of set_xticks is `minor`, not the tick labels.
        ax1.set_xticks(class_pos)
        ax1.set_xticklabels(class_names)
        ax1.legend(handles=[bp, bt], labels=["prediction", "target"])
        ax1.set_ylim((0, 400))

        ax2.cla()
        ax2.plot(steps, accuracies, label="accuracy")
        ax2.set_ylim(top=1)
        ax2.set_ylabel("accuracy")

        # Draw the updated figure and let the GUI event loop process events
        # without blocking training.
        plt.pause(0.01)

# Training finished: leave interactive mode and keep the final figure open.
plt.ioff()
plt.show()

```