在tensorflow中,变量用于存储和更新参数。


变量包括创建、初始化、保存和加载等几个操作,下面将进行详细介绍。

创建

创建变量时,需要将一个初始值传给构造函数Variable(),初始值通常是常量或者随机值。传递初始值时,需要指定初始值的shape。

import tensorflow as tf

# 创建随机值
weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="weights")

# 创建常量
biases = tf.Variable(tf.zeros([200]), name="biases")

创建变量时,可以使用另一个变量的初始化值来创建一个新的变量,创建语句如下:

new_weights = tf.Variable(weights.initialized_value() * 0.5, name="new_weights")

初始化

创建完变量后记得对变量进行初始化,tensorflow需要使用op来初始化变量。注意初始化变量的op需要在所有变量定义完成之后,在模型运行之前进行初始化,初始化语句如下:

# 创建一个初始化变量的op
init_op = tf.initialize_all_variables()

# op需要在会话中执行
with tf.Session() as sess:
  sess.run(init_op)
  print(weights)
  print(biases)

保存

在计算完变量值后,我们可以保存变量的值,这时就用到了保存变量的op。保存变量的语句如下:

# 创建一个保存变量的op
saver = tf.train.Saver()
 
# op需要在会话中执行,ckpt是checkpoints的意思
# ckpt文件是一个二进制文件,它把变量名映射到对应的tensor值
with tf.Session() as sess:
	saver.save(sess, "/tmp/model.ckpt")

恢复

保存变量后,后续模型加载可以直接从保存文件中进行恢复变量,恢复变量的语句如下:

import tensorflow as tf

restore_weights = tf.Variable(tf.random_normal([784,200], stddev=0.35), name="weights")
restore_biased = tf.Variable(tf.zeros([200]), name="biases")

# 创建一个恢复变量的op
restore_saver = tf.train.Saver()

# op需要在会话中执行
with tf.Session() as sess:
    restore_saver.restore(sess, "/tmp/model.ckpt")
    print(restore_weights)
    print(restore_biased)

在保存和恢复变量时,可以通过给Saver传递字典参数来指定保存和恢复那些变量,并给变量重命名,相应语句如下:

save_saver = tf.train.Saver({"save_weights": weights, "save_biases": biases})

restore_weights = tf.Variable(tf.random_normal([784,200], stddev=0.35), name="weights")
restore_biased = tf.Variable(tf.zeros([200]), name="biases")

# 创建一个恢复变量的op
restore_saver = tf.train.Saver({"save_weights": restore_weights, "save_biases": restore_biased})

完整示例

创建、初始化、保存

import tensorflow as tf

low_memory = False

# 创建随机值
weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="weights")
 
# 创建常量
biases = tf.Variable(tf.zeros([200]), name="biases")

# 创建一个初始化变量的op
init_op = tf.initialize_all_variables()

# 创建一个保存变量的op
saver = tf.train.Saver()
 
# op需要在会话中执行
with tf.Session() as sess:
	sess.run(init_op)
	print(weights)
	print(biases)
	saver.save(sess, "/tmp/model.ckpt")

恢复

import tensorflow as tf

restore_weights = tf.Variable(tf.random_normal([784,200], stddev=0.35), name="weights")
restore_biased = tf.Variable(tf.zeros([200]), name="biases")

# 创建一个恢复变量的op
restore_saver = tf.train.Saver()

# op需要在会话中执行
with tf.Session() as sess:
	restore_saver.restore(sess, "/tmp/model.ckpt")
	print(restore_weights)
	print(restore_biased)