Note
The images in this blog post are the author's original work and may not be used without permission.
Training and Saving
#coding:utf-8
'''
Simple softmax classifier for handwritten digits (MNIST)
'''
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

####################################################
#
# @brief Define model hyperparameters
#
####################################################
LEARNING_RATE = 0.01
BATCH_NUMBER = 1000
BATCH_SIZE = 50

####################################################
#
# @brief Define inputs
#
####################################################
# x: input vector, None * 784
x = tf.placeholder("float", [None, 784])
# y: input labels, None * 10
y = tf.placeholder("float", [None, 10])

####################################################
#
# @brief Define trainable parameters
#
####################################################
# w: weights, 784 * 10
w = tf.Variable(tf.zeros([784, 10]), name="weights")
# b: biases, 10
b = tf.Variable(tf.zeros([10]), name="biases")

####################################################
#
# @brief Define the optimization
#
####################################################
# predict_y: prediction softmax(x * w + b), None * 10
predict_y = tf.nn.softmax(tf.matmul(x, w) + b)
# cross-entropy loss
cross_entropy = -tf.reduce_sum(y * tf.log(predict_y))
# gradient descent
train_op = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(cross_entropy)

####################################################
#
# @brief Define evaluation ops
#
####################################################
equal_op = tf.equal(tf.argmax(y, 1), tf.argmax(predict_y, 1))
accuracy_op = tf.reduce_mean(tf.cast(equal_op, "float"))

####################################################
#
# @brief Run the training
#
####################################################
# variable initialization op
init_op = tf.initialize_all_variables()

# MNIST data input
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
test_xs = mnist.test.images
test_ys = mnist.test.labels

# create an op that saves the variables
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    # training iterations
    for step in range(BATCH_NUMBER):
        train_xs, train_ys = mnist.train.next_batch(BATCH_SIZE)
        sess.run(train_op, feed_dict={x: train_xs, y: train_ys})
        train_accuracy = sess.run(accuracy_op, feed_dict={x: train_xs, y: train_ys})
        test_accuracy = sess.run(accuracy_op, feed_dict={x: test_xs, y: test_ys})
        print("step-%d accuracy: train-%f test-%f" % \
            (step, train_accuracy, test_accuracy))
    merged = tf.summary.merge_all()
    logdir = "mnist_softmax"
    writer = tf.summary.FileWriter(logdir, sess.graph)
    saver.save(sess, "mnist_softmax.ckpt")
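For reference, the loss built above with tf.nn.softmax and -tf.reduce_sum(y * tf.log(predict_y)) is the usual softmax cross-entropy. A sketch in my own notation, with y the one-hot labels and the sum taken over the batch and the 10 classes:

\hat{y}_i = \mathrm{softmax}(x_i W + b), \qquad L = -\sum_{i} \sum_{j=0}^{9} y_{ij} \, \log \hat{y}_{ij}

Minimizing L with gradient descent is exactly what train_op does each step.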
Output
step-0 accuracy: train-0.760000 test-0.387800
step-1 accuracy: train-0.460000 test-0.246000
step-2 accuracy: train-0.660000 test-0.409300
step-3 accuracy: train-0.680000 test-0.511800
step-4 accuracy: train-0.540000 test-0.401700
step-5 accuracy: train-0.740000 test-0.608500
step-6 accuracy: train-0.780000 test-0.602300
step-7 accuracy: train-0.860000 test-0.695600
step-8 accuracy: train-0.920000 test-0.745200
step-9 accuracy: train-0.800000 test-0.795300
step-10 accuracy: train-0.860000 test-0.767400
step-11 accuracy: train-0.840000 test-0.787000
step-12 accuracy: train-0.900000 test-0.728200
step-13 accuracy: train-0.760000 test-0.755700
step-14 accuracy: train-0.880000 test-0.813600
step-15 accuracy: train-0.820000 test-0.736000
step-16 accuracy: train-0.880000 test-0.804200
step-17 accuracy: train-0.860000 test-0.811800
step-18 accuracy: train-0.820000 test-0.808700
step-19 accuracy: train-0.940000 test-0.848500
step-20 accuracy: train-0.900000 test-0.848700
...
step-980 accuracy: train-0.920000 test-0.916600
step-981 accuracy: train-0.980000 test-0.916900
step-982 accuracy: train-0.980000 test-0.916100
step-983 accuracy: train-0.960000 test-0.902900
step-984 accuracy: train-0.900000 test-0.909600
step-985 accuracy: train-0.960000 test-0.911900
step-986 accuracy: train-0.960000 test-0.915300
step-987 accuracy: train-0.920000 test-0.912800
step-988 accuracy: train-0.960000 test-0.912100
step-989 accuracy: train-0.980000 test-0.913600
step-990 accuracy: train-1.000000 test-0.914300
step-991 accuracy: train-1.000000 test-0.912900
step-992 accuracy: train-0.920000 test-0.915700
step-993 accuracy: train-0.940000 test-0.915300
step-994 accuracy: train-0.920000 test-0.917200
step-995 accuracy: train-1.000000 test-0.916600
step-996 accuracy: train-0.960000 test-0.912300
step-997 accuracy: train-0.980000 test-0.915400
step-998 accuracy: train-0.980000 test-0.912800
step-999 accuracy: train-0.960000 test-0.912100
The accuracy on the test set is about 91.2%.
Loading and Prediction
#coding:utf-8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# import matplotlib.pyplot as plt  # needed only for the commented-out image display below

# x: input vector, None * 784
x = tf.placeholder("float", [None, 784])
# y: input labels, None * 10
y = tf.placeholder("float", [None, 10])

# w: weights, 784 * 10
w = tf.Variable(tf.zeros([784, 10]), name="weights")
# b: biases, 10
b = tf.Variable(tf.zeros([10]), name="biases")

# predict_y: prediction softmax(x * w + b), None * 10
predict_y = tf.nn.softmax(tf.matmul(x, w) + b)
predict_index = tf.argmax(predict_y, 1)
actual_index = tf.argmax(y, 1)

# create an op that initializes the variables
init_op = tf.initialize_all_variables()

# create a Saver to restore the variables
restore_saver = tf.train.Saver()

# MNIST data input
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

with tf.Session() as sess:
    restore_saver.restore(sess, "mnist_softmax.ckpt")
    for i in range(20):
        train_x, train_y = mnist.train.next_batch(1)
        predict_y_result = sess.run(predict_y, feed_dict={x: train_x, y: train_y})
        actual_index_result = sess.run(actual_index, feed_dict={x: train_x, y: train_y})
        predict_index_result = sess.run(predict_index, feed_dict={x: train_x, y: train_y})
        print(predict_y_result)
        print("actual_index-predict_index: %d-%d" %
              (actual_index_result[0], predict_index_result[0]))
        print("")
        # display the image
        # sample_image = train_x[0].reshape([28, 28])
        # plt.imshow(sample_image, cmap='Greys')
        # plt.show()
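A small optional simplification, not part of the original code: the three sess.run calls inside the loop each re-run the forward pass. They can be fetched in a single call so that all three outputs come from the same pass:

        # inside the loop, replacing the three separate sess.run calls
        probs, true_idx, pred_idx = sess.run(
            [predict_y, actual_index, predict_index],
            feed_dict={x: train_x, y: train_y})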
Results:
[[ 2.27737673e-09 1.61471628e-10 5.61408031e-10 4.25548336e-07 1.07422402e-07 3.21843515e-07 3.41314060e-11 9.99893427e-01 7.13851432e-06 9.84186699e-05]]
actual_index-predict_index: 7-7

[[ 2.91705161e-04 4.31512708e-05 4.94682230e-03 1.61197092e-02 1.95524044e-05 6.34678360e-03 4.88140467e-05 4.43770614e-06 9.60317433e-01 1.18615944e-02]]
actual_index-predict_index: 8-8

[[ 9.65593907e-04 9.07296999e-05 1.73506065e-04 9.88007545e-01 4.58489785e-06 1.04939556e-02 1.06244761e-06 1.47739092e-06 2.48050317e-04 1.34743759e-05]]
actual_index-predict_index: 3-3

[[ 1.06552109e-08 2.25054997e-09 1.40786678e-05 9.99978065e-01 1.11751615e-10 5.19393772e-08 2.63402120e-13 7.62957279e-06 1.86133022e-07 3.01438590e-08]]
actual_index-predict_index: 3-3

[[ 1.67198846e-06 1.75920504e-05 9.98799562e-01 2.65138922e-04 9.97389270e-06 4.60144520e-06 1.84868681e-04 5.42581780e-04 1.53610366e-04 2.03241852e-05]]
actual_index-predict_index: 2-2

[[ 3.69739723e-07 1.74503839e-07 1.00697041e-03 9.98828828e-01 1.56175065e-06 3.33302014e-05 7.79888438e-08 1.10812106e-08 1.28463435e-04 2.14918870e-07]]
actual_index-predict_index: 3-3

[[ 2.40125181e-03 4.56383568e-04 8.93679261e-03 4.62570973e-03 1.18294789e-03 9.65041101e-01 6.64983469e-04 8.88758514e-05 1.65748727e-02 2.71801273e-05]]
actual_index-predict_index: 5-5

[[ 1.77966540e-05 1.59042357e-09 1.69808357e-06 7.15173201e-07 9.34773505e-01 7.32704430e-05 3.01878637e-04 2.05124146e-04 9.37270117e-04 6.36887103e-02]]
actual_index-predict_index: 4-4

[[ 9.51241191e-07 9.74131048e-01 3.95762821e-04 1.12173446e-02 8.90884694e-05 3.00013972e-03 3.69138026e-04 1.07145368e-03 6.07078290e-03 3.65441106e-03]]
actual_index-predict_index: 1-1

[[ 4.04238598e-08 3.96338903e-04 6.11227704e-03 8.80150183e-05 4.64237407e-02 4.99675400e-04 9.44205821e-01 6.43441163e-05 1.91634858e-03 2.93356541e-04]]
actual_index-predict_index: 6-6

[[ 5.39879807e-07 1.18812895e-03 9.73667026e-01 1.29849568e-03 4.50128391e-06 1.38689313e-06 2.31403075e-02 3.88985136e-05 6.31470699e-04 2.92315763e-05]]
actual_index-predict_index: 2-2

[[ 6.03592639e-07 4.97014215e-03 1.13697453e-04 4.27568890e-03 2.15098672e-02 1.18463486e-03 1.52472048e-05 4.04654145e-01 4.63497564e-02 5.16926169e-01]]
actual_index-predict_index: 9-9

[[ 2.35136740e-05 5.77873434e-04 1.13111427e-02 1.12604386e-04 8.26002579e-05 7.75026949e-03 6.34858952e-05 3.67183304e-07 9.80032861e-01 4.53466237e-05]]
actual_index-predict_index: 8-8

[[ 3.33027217e-10 3.93268307e-09 3.60386339e-11 4.50129755e-06 4.91710729e-04 1.53179906e-06 1.61973068e-09 9.98264849e-01 1.64548197e-04 1.07282540e-03]]
actual_index-predict_index: 7-7

[[ 1.26447651e-06 1.93439351e-04 9.77273166e-01 2.05740426e-03 1.00293578e-08 1.01640808e-05 2.70258147e-06 1.88517957e-09 2.04617772e-02 2.35834854e-08]]
actual_index-predict_index: 2-2

[[ 3.79381308e-05 6.86537874e-07 8.17690452e-04 7.74423679e-05 1.53992577e-02 7.43901008e-04 2.05788703e-04 1.32311573e-02 5.59218116e-02 9.13564324e-01]]
actual_index-predict_index: 9-9

[[ 2.73115584e-03 1.74025251e-06 6.87140755e-06 2.08953805e-02 5.36242624e-06 9.69347894e-01 1.07080101e-04 6.18056438e-05 6.35029143e-03 4.92293271e-04]]
actual_index-predict_index: 5-5

[[ 7.86116743e-07 1.04552935e-08 1.81989253e-05 7.68434547e-05 1.13810336e-06 7.77041862e-07 1.55642266e-09 9.99853492e-01 1.77232632e-05 3.09368443e-05]]
actual_index-predict_index: 7-7

[[ 1.02413753e-04 9.51294351e-06 5.58499340e-03 1.88416038e-06 1.57446731e-02 3.39030550e-04 9.76482511e-01 2.36741049e-04 8.16290150e-04 6.81902049e-04]]
actual_index-predict_index: 6-6

[[ 6.79365621e-05 3.67236702e-04 2.03259886e-04 6.43721316e-04 4.16841246e-02 4.65681590e-03 3.68818961e-04 6.46927476e-01 9.28802881e-03 2.95792609e-01]]
actual_index-predict_index: 9-7
The last example is a misprediction: the model assigns probability 0.6469 to class 7 and 0.2958 to class 9, so the probability of the true class 9 is only somewhat smaller than that of the predicted class 7.
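As a quick sanity check (my own snippet, using the probability vector printed above for the mispredicted example):

import numpy as np

# probability vector of the mispredicted example, copied from the output above
probs = np.array([6.79365621e-05, 3.67236702e-04, 2.03259886e-04, 6.43721316e-04,
                  4.16841246e-02, 4.65681590e-03, 3.68818961e-04, 6.46927476e-01,
                  9.28802881e-03, 2.95792609e-01])

print(np.argmax(probs))    # 7 -- the class tf.argmax picks
print(probs[7], probs[9])  # 0.6469... vs 0.2958...: class 9 is the runner-up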
For example:
[[ 6.56793145e-06 6.04155775e-06 2.56523769e-02 1.30315811e-05
1.28089413e-02 1.09560778e-02 9.47358847e-01 1.76747517e-06
3.17676389e-03 1.94756776e-05]]
actual_index-predict_index: 6-6
Visualization
Visualization code:
import matplotlib.pyplot as plt

# Run this inside the session in which w has been initialized or restored,
# so that w[:, i].eval() can use it as the default session.
for i in range(10):
    # column i of w holds the 784 weights of digit class i; reshape to 28x28
    data = w[:, i].eval().reshape([28, 28])
    cmap = plt.cm.winter  # a custom colormap can be used here
    im = plt.imshow(data, cmap=cmap)
    plt.colorbar(im)
    plt.show()
Visualizing the weights this way, yellow indicates positive weights and brown indicates negative weights; the images show the feature-weight visualizations for digits 9, 8, 7, ..., 2, 1, 0 in that order.
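As an optional variation (my own sketch, not the original code), the ten weight images can also be drawn in a single 2x5 grid instead of one window per digit; this also assumes it runs inside a session where w has been restored:

import matplotlib.pyplot as plt

weights = w.eval()  # shape (784, 10), one column of weights per digit class
fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for i, ax in enumerate(axes.flat):
    im = ax.imshow(weights[:, i].reshape(28, 28), cmap=plt.cm.winter)
    ax.set_title("digit %d" % i)
    ax.axis("off")
fig.colorbar(im, ax=axes.ravel().tolist())
plt.show()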