目录
简介
《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》
机器学习领域有个很重要的假设:IID独立同分布假设,就是假设训练数据和测试数据是满足相同分布的,这是通过训练数据获得的模型能够在测试集获得好的效果的一个基本保障。那BatchNorm的作用是什么呢?
BatchNorm就是在深度神经网络训练过程中使得每一层神经网络的输入保持相同分布的。
白化
之前的研究表明如果在图像处理中对输入图像进行白化(Whiten)操作的话——所谓白化,就是对输入数据分布变换到0均值,单位方差的正态分布。
实现时需要注意的问题
https://blog.csdn.net/huitailangyz/article/details/85015611
实现细节
Note: when training, the moving_mean and moving_variance need to be updated.
By default the update ops are placed in tf.GraphKeys.UPDATE_OPS
, so they
need to be added as a dependency to the train_op
. For example:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(loss)
参考
在使用 tf.nn.batch_normalization
时,如果训练阶段(批式,batch size 较大)和推理阶段(流式,batch size 较小甚至为1)的输入数据量不同,需特别注意 统计量的计算与使用方式。以下是具体处理方法:
关键原则
• 训练阶段:使用当前批次的均值和方差(batch_mean
和 batch_variance
),并更新全局的移动平均统计量。
• 推理阶段:使用训练阶段累积的移动平均统计量(moving_mean
和 moving_variance
),与当前推理的 batch size 无关。
处理步骤
1. 定义移动平均变量
在训练时维护两个额外变量(moving_mean
和 moving_variance
),用于保存全局统计量:
# 初始化移动平均变量(与输入张量的维度一致)
moving_mean = tf.Variable(initial_value=tf.zeros(shape=[num_features]), trainable=False)
moving_variance = tf.Variable(initial_value=tf.ones(shape=[num_features]), trainable=False)
2. 训练阶段
在每次训练时:
• 计算当前批次的 batch_mean
和 batch_variance
。
• 更新移动平均变量(使用指数平滑法,如 decay=0.99
)。
• 调用 tf.nn.batch_normalization
时传入当前批次的统计量。
# 训练模式
def train_step(x, training=True):
# 计算当前批次的统计量
batch_mean, batch_variance = tf.nn.moments(x, axes=[0]) # 假设输入 x 的 shape 为 [batch_size, features]
# 更新移动平均
decay = 0.99
update_moving_mean = moving_mean.assign(moving_mean * decay + batch_mean * (1 - decay))
update_moving_variance = moving_variance.assign(moving_variance * decay + batch_variance * (1 - decay))
# 在训练时使用当前批次的统计量
with tf.control_dependencies([update_moving_mean, update_moving_variance]):
normalized_output = tf.nn.batch_normalization(
x, batch_mean, batch_variance, offset, scale, variance_epsilon
)
return normalized_output
3. 推理阶段 在推理时直接使用移动平均的统计量:
# 推理模式(无论 batch size 是否为1)
def inference_step(x):
normalized_output = tf.nn.batch_normalization(
x, moving_mean, moving_variance, offset, scale, variance_epsilon
)
return normalized_output
注意事项
1. 避免推理时误用当前批次的统计量:
• 如果推理时传入当前批次的 mean
/variance
(尤其是 batch size=1 时),方差可能接近0,导致数值不稳定。
• 务必确保推理时传入的是 moving_mean
和 moving_variance
。
- Batch Size=1 的特殊情况: • 当推理时 batch size=1 时,直接使用移动平均统计量,避免因单样本方差为0导致归一化失效。
• variance_epsilon
(如 1e-3
)需添加到方差中,防止除以零。
- 与高阶 API 的兼容性:
• 手动实现需谨慎维护移动平均变量。推荐优先使用
tf.keras.layers.BatchNormalization
,它会自动处理训练/推理模式切换和统计量更新。
• 使用 Keras 层的示例:
```python
bn_layer = tf.keras.layers.BatchNormalization()
# 训练时自动更新统计量
normalized_output = bn_layer(x, training=True)
# 推理时自动使用移动平均
normalized_output = bn_layer(x, training=False)
```
代码示例(完整流程)
import tensorflow as tf
# 定义可学习参数和移动平均变量
num_features = 64
offset = tf.Variable(tf.zeros([num_features]))
scale = tf.Variable(tf.ones([num_features]))
moving_mean = tf.Variable(tf.zeros([num_features]), trainable=False)
moving_variance = tf.Variable(tf.ones([num_features]), trainable=False)
variance_epsilon = 1e-3
def batch_norm(x, training):
if training:
# 训练模式:计算当前 batch 的统计量,并更新移动平均
batch_mean, batch_variance = tf.nn.moments(x, axes=[0])
decay = 0.99
update_moving_mean = moving_mean.assign(moving_mean * decay + batch_mean * (1 - decay))
update_moving_variance = moving_variance.assign(moving_variance * decay + batch_variance * (1 - decay))
with tf.control_dependencies([update_moving_mean, update_moving_variance]):
return tf.nn.batch_normalization(x, batch_mean, batch_variance, offset, scale, variance_epsilon)
else:
# 推理模式:使用移动平均统计量
return tf.nn.batch_normalization(x, moving_mean, moving_variance, offset, scale, variance_epsilon)
# 模拟训练(batch_size=32)
x_train = tf.random.normal(shape=[32, num_features])
output_train = batch_norm(x_train, training=True)
# 模拟推理(batch_size=1)
x_inference = tf.random.normal(shape=[1, num_features])
output_inference = batch_norm(x_inference, training=False)
总结 • 训练阶段:动态更新移动平均统计量,不受 batch size 影响。
• 推理阶段:无论输入 batch size 多大(包括1),均使用训练时累积的 moving_mean
和 moving_variance
。
• 手动实现时需严格区分训练/推理逻辑,推荐使用 tf.keras.layers.BatchNormalization
简化流程。