备注

本博客图片为本博主原创,未经允许不得使用。

训练和保存

#coding:utf-8
'''
简单softmax识别手写数字
'''
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

####################################################
#
# @brief 定义模型参数
#
####################################################

LEARNING_RATE   = 0.01
BATCH_NUMBER    = 1000
BATCH_SIZE      = 50


####################################################
#
# @brief 定义输入
#
####################################################

# x: 输入向量, None*784维
x = tf.placeholder("float", [None, 784])

# y: 输入标签
y = tf.placeholder("float", [None, 10])


####################################################
#
# @brief 定义求解参数
#
####################################################

# w: 权重, 784 * 10 维
w = tf.Variable(tf.zeros([784,10]), name = "weights")

# b: 截距, 10维
b = tf.Variable(tf.zeros([10]), name = "biases")


####################################################
#
# @brief 定义优化过程
#
####################################################

# predict_y: 预测的结果 x * w + b, None * 10 维
predict_y = tf.nn.softmax(tf.matmul(x, w) + b)

# 交叉熵损失函数
cross_entropy = - tf.reduce_sum(y*tf.log(predict_y))

# 梯度下降
train_op = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(cross_entropy)


####################################################
#
# @brief 定义衡量指标 op
#
####################################################

equal_op        = tf.equal(tf.argmax(y, 1), tf.argmax(predict_y, 1))
accuracy_op     = tf.reduce_mean(tf.cast(equal_op, "float"))


####################################################
#
# @brief 开始求解
#
####################################################

# 初始化 op
init_op = tf.initialize_all_variables()

#mnist数据输入
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

test_xs = mnist.test.images
test_ys = mnist.test.labels

# 创建一个保存变量的op
saver = tf.train.Saver()

with tf.Session() as sess:
    sess = tf.Session()
    sess.run(init_op)
    # 迭代
    for step in range(BATCH_NUMBER):
        train_xs, train_ys = mnist.train.next_batch(BATCH_SIZE)
        sess.run(train_op, feed_dict={x: train_xs, y: train_ys})
        train_accuracy = sess.run(accuracy_op, feed_dict={x: train_xs, y: train_ys})
        test_accuracy = sess.run(accuracy_op, feed_dict={x: test_xs, y: test_ys})

        print("step-%d accuracy: train-%f test-%f" % \
              (step, train_accuracy, test_accuracy))

    merged = tf.summary.merge_all()
    logdir = "mnist_softmax"
    writer = tf.summary.FileWriter(logdir, sess.graph)

    saver.save(sess, "mnist_softmax.ckpt")


输出

step-0 accuracy: train-0.760000 test-0.387800
step-1 accuracy: train-0.460000 test-0.246000
step-2 accuracy: train-0.660000 test-0.409300
step-3 accuracy: train-0.680000 test-0.511800
step-4 accuracy: train-0.540000 test-0.401700
step-5 accuracy: train-0.740000 test-0.608500
step-6 accuracy: train-0.780000 test-0.602300
step-7 accuracy: train-0.860000 test-0.695600
step-8 accuracy: train-0.920000 test-0.745200
step-9 accuracy: train-0.800000 test-0.795300
step-10 accuracy: train-0.860000 test-0.767400
step-11 accuracy: train-0.840000 test-0.787000
step-12 accuracy: train-0.900000 test-0.728200
step-13 accuracy: train-0.760000 test-0.755700
step-14 accuracy: train-0.880000 test-0.813600
step-15 accuracy: train-0.820000 test-0.736000
step-16 accuracy: train-0.880000 test-0.804200
step-17 accuracy: train-0.860000 test-0.811800
step-18 accuracy: train-0.820000 test-0.808700
step-19 accuracy: train-0.940000 test-0.848500
step-20 accuracy: train-0.900000 test-0.848700
...
step-980 accuracy: train-0.920000 test-0.916600
step-981 accuracy: train-0.980000 test-0.916900
step-982 accuracy: train-0.980000 test-0.916100
step-983 accuracy: train-0.960000 test-0.902900
step-984 accuracy: train-0.900000 test-0.909600
step-985 accuracy: train-0.960000 test-0.911900
step-986 accuracy: train-0.960000 test-0.915300
step-987 accuracy: train-0.920000 test-0.912800
step-988 accuracy: train-0.960000 test-0.912100
step-989 accuracy: train-0.980000 test-0.913600
step-990 accuracy: train-1.000000 test-0.914300
step-991 accuracy: train-1.000000 test-0.912900
step-992 accuracy: train-0.920000 test-0.915700
step-993 accuracy: train-0.940000 test-0.915300
step-994 accuracy: train-0.920000 test-0.917200
step-995 accuracy: train-1.000000 test-0.916600
step-996 accuracy: train-0.960000 test-0.912300
step-997 accuracy: train-0.980000 test-0.915400
step-998 accuracy: train-0.980000 test-0.912800
step-999 accuracy: train-0.960000 test-0.912100

对测试集的正确率为91.2%。


加载和预测

#coding:utf-8

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# x: 输入向量, None*784维
x = tf.placeholder("float", [None, 784])

# y: 输入标签
y = tf.placeholder("float", [None, 10])

# w: 权重, 784 * 10 维
w = tf.Variable(tf.zeros([784,10]), name="weights")

# b: 截距, 10维
b = tf.Variable(tf.zeros([10]), name="biases")

# predict_y: 预测的结果 x * w + b, None * 10 维
predict_y = tf.nn.softmax(tf.matmul(x, w) + b)

predict_index   = tf.argmax(predict_y, 1)
actual_index    = tf.argmax(y, 1)

# 创建一个初始化变量的op
init_op = tf.initialize_all_variables()

# 创建一个恢复变量的op
restore_saver = tf.train.Saver()

#mnist数据输入
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

with tf.Session() as sess:
    restore_saver.restore(sess, "mnist_softmax.ckpt")
    for i in range(20):
        train_x, train_y = mnist.train.next_batch(1)
        predict_y_result        = sess.run(predict_y, feed_dict={x: train_x, y: train_y})
        actual_index_result     = sess.run(actual_index, feed_dict={x: train_x, y: train_y})
        predict_index_result    = sess.run(predict_index, feed_dict={x: train_x, y: train_y})
        print(predict_y_result)
        print("actual_index-predict_index: %d-%d" % (actual_index_result[0], predict_index_result[0]))
        print("")
        
        # 显示图像
        # sample_image = train_x[0].reshape([28, 28])
        # plt.imshow(sample_image, cmap='Greys')
        # plt.show()

结果:

[[  2.27737673e-09   1.61471628e-10   5.61408031e-10   4.25548336e-07
    1.07422402e-07   3.21843515e-07   3.41314060e-11   9.99893427e-01
    7.13851432e-06   9.84186699e-05]]
actual_index-predict_index: 7-7

[[  2.91705161e-04   4.31512708e-05   4.94682230e-03   1.61197092e-02
    1.95524044e-05   6.34678360e-03   4.88140467e-05   4.43770614e-06
    9.60317433e-01   1.18615944e-02]]
actual_index-predict_index: 8-8

[[  9.65593907e-04   9.07296999e-05   1.73506065e-04   9.88007545e-01
    4.58489785e-06   1.04939556e-02   1.06244761e-06   1.47739092e-06
    2.48050317e-04   1.34743759e-05]]
actual_index-predict_index: 3-3

[[  1.06552109e-08   2.25054997e-09   1.40786678e-05   9.99978065e-01
    1.11751615e-10   5.19393772e-08   2.63402120e-13   7.62957279e-06
    1.86133022e-07   3.01438590e-08]]
actual_index-predict_index: 3-3

[[  1.67198846e-06   1.75920504e-05   9.98799562e-01   2.65138922e-04
    9.97389270e-06   4.60144520e-06   1.84868681e-04   5.42581780e-04
    1.53610366e-04   2.03241852e-05]]
actual_index-predict_index: 2-2

[[  3.69739723e-07   1.74503839e-07   1.00697041e-03   9.98828828e-01
    1.56175065e-06   3.33302014e-05   7.79888438e-08   1.10812106e-08
    1.28463435e-04   2.14918870e-07]]
actual_index-predict_index: 3-3

[[  2.40125181e-03   4.56383568e-04   8.93679261e-03   4.62570973e-03
    1.18294789e-03   9.65041101e-01   6.64983469e-04   8.88758514e-05
    1.65748727e-02   2.71801273e-05]]
actual_index-predict_index: 5-5

[[  1.77966540e-05   1.59042357e-09   1.69808357e-06   7.15173201e-07
    9.34773505e-01   7.32704430e-05   3.01878637e-04   2.05124146e-04
    9.37270117e-04   6.36887103e-02]]
actual_index-predict_index: 4-4

[[  9.51241191e-07   9.74131048e-01   3.95762821e-04   1.12173446e-02
    8.90884694e-05   3.00013972e-03   3.69138026e-04   1.07145368e-03
    6.07078290e-03   3.65441106e-03]]
actual_index-predict_index: 1-1

[[  4.04238598e-08   3.96338903e-04   6.11227704e-03   8.80150183e-05
    4.64237407e-02   4.99675400e-04   9.44205821e-01   6.43441163e-05
    1.91634858e-03   2.93356541e-04]]
actual_index-predict_index: 6-6

[[  5.39879807e-07   1.18812895e-03   9.73667026e-01   1.29849568e-03
    4.50128391e-06   1.38689313e-06   2.31403075e-02   3.88985136e-05
    6.31470699e-04   2.92315763e-05]]
actual_index-predict_index: 2-2

[[  6.03592639e-07   4.97014215e-03   1.13697453e-04   4.27568890e-03
    2.15098672e-02   1.18463486e-03   1.52472048e-05   4.04654145e-01
    4.63497564e-02   5.16926169e-01]]
actual_index-predict_index: 9-9

[[  2.35136740e-05   5.77873434e-04   1.13111427e-02   1.12604386e-04
    8.26002579e-05   7.75026949e-03   6.34858952e-05   3.67183304e-07
    9.80032861e-01   4.53466237e-05]]
actual_index-predict_index: 8-8

[[  3.33027217e-10   3.93268307e-09   3.60386339e-11   4.50129755e-06
    4.91710729e-04   1.53179906e-06   1.61973068e-09   9.98264849e-01
    1.64548197e-04   1.07282540e-03]]
actual_index-predict_index: 7-7

[[  1.26447651e-06   1.93439351e-04   9.77273166e-01   2.05740426e-03
    1.00293578e-08   1.01640808e-05   2.70258147e-06   1.88517957e-09
    2.04617772e-02   2.35834854e-08]]
actual_index-predict_index: 2-2

[[  3.79381308e-05   6.86537874e-07   8.17690452e-04   7.74423679e-05
    1.53992577e-02   7.43901008e-04   2.05788703e-04   1.32311573e-02
    5.59218116e-02   9.13564324e-01]]
actual_index-predict_index: 9-9

[[  2.73115584e-03   1.74025251e-06   6.87140755e-06   2.08953805e-02
    5.36242624e-06   9.69347894e-01   1.07080101e-04   6.18056438e-05
    6.35029143e-03   4.92293271e-04]]
actual_index-predict_index: 5-5

[[  7.86116743e-07   1.04552935e-08   1.81989253e-05   7.68434547e-05
    1.13810336e-06   7.77041862e-07   1.55642266e-09   9.99853492e-01
    1.77232632e-05   3.09368443e-05]]
actual_index-predict_index: 7-7

[[  1.02413753e-04   9.51294351e-06   5.58499340e-03   1.88416038e-06
    1.57446731e-02   3.39030550e-04   9.76482511e-01   2.36741049e-04
    8.16290150e-04   6.81902049e-04]]
actual_index-predict_index: 6-6

[[  6.79365621e-05   3.67236702e-04   2.03259886e-04   6.43721316e-04
    4.16841246e-02   4.65681590e-03   3.68818961e-04   6.46927476e-01
    9.28802881e-03   2.95792609e-01]]
actual_index-predict_index: 9-7

最后一个预测错误的例子,其实预测为9的概率是0.2958,预测为7的概率为0.6469,预测为9的概率就比预测为7的概率小点。


举个例子:

image.png

[[  6.56793145e-06   6.04155775e-06   2.56523769e-02   1.30315811e-05

    1.28089413e-02   1.09560778e-02   9.47358847e-01   1.76747517e-06

    3.17676389e-03   1.94756776e-05]]

actual_index-predict_index: 6-6


可视化表示

可视化表示代码:

for i in range(10):
    data = w[:,i].eval().reshape([28, 28])
    cmap = plt.cm.winter  # 可以使用自定义的colormap
    im = plt.imshow(data, cmap=cmap)
    plt.colorbar(im)
    plt.show()

将权重可视化表示,黄色代表权重为正,褐色代表权重为负,图中图片依次是9、8、7...、2、1、0的特征权重可视化的图片。

选区_042.png