tf.trainable_variables()
顾名思义,这个函数可以也仅可以查看可训练的变量,在我们生成变量时,无论是使用tf.Variable()还是tf.get_variable()生成变量,都会涉及一个参数trainable,其默认为True。
tf.all_variables()
tf.global_variables()
版本更新
在早期版本的 TensorFlow(1.x 版本)中,tf.all_variables()
是一个常用的函数,不过在 TensorFlow 2.x 版本中,该函数已被弃用。下面为你详细介绍它的相关信息。
1. tf.all_variables()
在 TensorFlow 1.x 中的作用
在 TensorFlow 1.x 里,tf.all_variables()
函数用于返回当前图中所有的变量(tf.Variable
对象)列表。这些变量可以是可训练的变量(如神经网络中的权重和偏置),也可以是不可训练的变量(例如用于计数的变量)。
示例代码
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
# 创建一些变量
var1 = tf.Variable(1.0, name='var1')
var2 = tf.Variable(2.0, name='var2')
# 获取所有变量
all_vars = tf.all_variables()
# 初始化所有变量
init = tf.global_variables_initializer()
# 创建会话并运行初始化操作
with tf.Session() as sess:
sess.run(init)
# 打印所有变量
for var in all_vars:
print(f"Variable name: {var.name}, Value: {sess.run(var)}")
代码解释
- 首先,我们导入
tensorflow.compat.v1
并禁用 TensorFlow 2.x 的行为,以模拟 TensorFlow 1.x 的环境。 - 接着,创建了两个变量
var1
和var2
。 - 然后,使用
tf.all_variables()
函数获取当前图中的所有变量。 - 之后,创建一个初始化操作
init
来初始化所有变量。 - 最后,在会话中运行初始化操作,并打印每个变量的名称和值。
2. TensorFlow 2.x 中的替代方案
在 TensorFlow 2.x 中,不再使用静态图和会话的概念,因此 tf.all_variables()
被弃用。如果需要获取所有可训练的变量,可以使用 model.trainable_variables
;如果需要获取所有变量(包括可训练和不可训练的),可以使用 model.variables
。
示例代码
import tensorflow as tf
# 创建一个简单的模型
class SimpleModel(tf.keras.Model):
def __init__(self):
super(SimpleModel, self).__init__()
self.dense = tf.keras.layers.Dense(1, input_shape=(1,))
def call(self, inputs):
return self.dense(inputs)
# 实例化模型
model = SimpleModel()
# 获取所有可训练变量
trainable_vars = model.trainable_variables
# 获取所有变量
all_vars = model.variables
# 打印变量信息
print("Trainable variables:")
for var in trainable_vars:
print(f"Variable name: {var.name}, Shape: {var.shape}")
print("\nAll variables:")
for var in all_vars:
print(f"Variable name: {var.name}, Shape: {var.shape}")
代码解释
- 首先,定义了一个简单的 Keras 模型
SimpleModel
,其中包含一个全连接层。 - 然后,实例化该模型。
- 接着,使用
model.trainable_variables
获取所有可训练的变量,使用model.variables
获取所有变量。 - 最后,打印每个变量的名称和形状。
参考
- https://blog.csdn.net/Cerisier/article/details/86523446