apachecn/hands-on-ml-zh

第十章使用tensorflow高级api报错

jwc19890114 opened this issue · 5 comments

按照教程中的代码录入

import tensorflow as tf
import numpy as np
import os
from sklearn.metrics import accuracy_score
from tensorflow.examples.tutorials.mnist import input_data

### tensorflow警告记录,可以避免在运行文件时出现红色警告
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
old_v = tf.logging.get_verbosity()
tf.logging.set_verbosity(tf.logging.ERROR)

(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
X_train = X_train.astype(np.float32).reshape(-1, 28 * 28) / 255.0
X_test = X_test.astype(np.float32).reshape(-1, 28 * 28) / 255.0
y_train = y_train.astype(np.int32)
y_test = y_test.astype(np.int32)

X_valid, X_train = X_train[:5000], X_train[5000:]
y_valid, y_train = y_train[:5000], y_train[5000:]

feature_cols = [tf.feature_column.numeric_column("X", shape=[28 * 28])]
# 下面的代码训练两个隐藏层的 DNN(一个具有 300 个神经元,另一个具有 100 个神经元)和一个具有 10 个神经元的 SOFTMax 输出层
dnn_clf = tf.estimator.DNNClassifier(hidden_units=[300, 100], n_classes=10,
                                     feature_columns=feature_cols)

input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"X": X_train}, y=y_train, num_epochs=40, batch_size=50, shuffle=True)
dnn_clf.train(input_fn=input_fn)

y_pred = list(dnn_clf.predict(X_test))
accuracy=accuracy_score(y_test, y_pred)
print(accuracy)

报错

Traceback (most recent call last):
  File "C:\ProgramData\Anaconda3\envs\python36\lib\inspect.py", line 1119, in getfullargspec
    sigcls=Signature)
  File "C:\ProgramData\Anaconda3\envs\python36\lib\inspect.py", line 2186, in _signature_from_callable
    raise TypeError('{!r} is not a callable object'.format(obj))
TypeError: array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32) is not a callable object

The above exception was the direct cause of the following exception:

!!! list(dnn_clf.predict(X_test))

传入的应该是一个函数

传入的应该是一个函数

我试试看,好像还是搞错了

传入的应该是一个函数

我试试看,好像还是搞错了

我看了下文档 是个 传入的是一个fn 函数 类似 input_fn这样的东西

传入的应该是一个函数

我试试看,好像还是搞错了

我看了下文档 是个 传入的是一个fn 函数 类似 input_fn这样的东西

是我的问题,从新搞了一遍