在tensorflow中,变量用于存储和更新参数。
变量包括创建、初始化、保存和加载等几个操作,下面将进行详细介绍。
创建
创建变量时,需要将一个初始值传给构造函数Variable(),初始值通常是常量或者随机值。传递初始值时,需要指定初始值的shape。
import tensorflow as tf # 创建随机值 weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="weights") # 创建常量 biases = tf.Variable(tf.zeros([200]), name="biases")
创建变量时,可以使用另一个变量的初始化值来创建一个新的变量,创建语句如下:
new_weights = tf.Variable(weights.initialized_value() * 0.5, name="new_weights")
初始化
创建完变量后记得对变量进行初始化,tensorflow需要使用op来初始化变量。注意初始化变量的op需要在所有变量定义完成之后,在模型运行之前进行初始化,初始化语句如下:
# 创建一个初始化变量的op init_op = tf.initialize_all_variables() # op需要在会话中执行 with tf.Session() as sess: sess.run(init_op) print(weights) print(biases)
保存
在计算完变量值后,我们可以保存变量的值,这时就用到了保存变量的op。保存变量的语句如下:
# 创建一个保存变量的op saver = tf.train.Saver() # op需要在会话中执行,ckpt是checkpoints的意思 # ckpt文件是一个二进制文件,它把变量名映射到对应的tensor值 with tf.Session() as sess: saver.save(sess, "/tmp/model.ckpt")
恢复
保存变量后,后续模型加载可以直接从保存文件中进行恢复变量,恢复变量的语句如下:
import tensorflow as tf restore_weights = tf.Variable(tf.random_normal([784,200], stddev=0.35), name="weights") restore_biased = tf.Variable(tf.zeros([200]), name="biases") # 创建一个恢复变量的op restore_saver = tf.train.Saver() # op需要在会话中执行 with tf.Session() as sess: restore_saver.restore(sess, "/tmp/model.ckpt") print(restore_weights) print(restore_biased)
在保存和恢复变量时,可以通过给Saver传递字典参数来指定保存和恢复那些变量,并给变量重命名,相应语句如下:
save_saver = tf.train.Saver({"save_weights": weights, "save_biases": biases}) restore_weights = tf.Variable(tf.random_normal([784,200], stddev=0.35), name="weights") restore_biased = tf.Variable(tf.zeros([200]), name="biases") # 创建一个恢复变量的op restore_saver = tf.train.Saver({"save_weights": restore_weights, "save_biases": restore_biased})
完整示例
创建、初始化、保存
import tensorflow as tf low_memory = False # 创建随机值 weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="weights") # 创建常量 biases = tf.Variable(tf.zeros([200]), name="biases") # 创建一个初始化变量的op init_op = tf.initialize_all_variables() # 创建一个保存变量的op saver = tf.train.Saver() # op需要在会话中执行 with tf.Session() as sess: sess.run(init_op) print(weights) print(biases) saver.save(sess, "/tmp/model.ckpt")
恢复
import tensorflow as tf restore_weights = tf.Variable(tf.random_normal([784,200], stddev=0.35), name="weights") restore_biased = tf.Variable(tf.zeros([200]), name="biases") # 创建一个恢复变量的op restore_saver = tf.train.Saver() # op需要在会话中执行 with tf.Session() as sess: restore_saver.restore(sess, "/tmp/model.ckpt") print(restore_weights) print(restore_biased)