在 TensorFlow 中,tf.keras.losses.binary_crossentropy
是一个用于计算二元交叉熵损失的函数。它通常用于二元分类任务,其中每个样本的标签是0或1。
参数
y_true
: 真实标签的张量,形状与预测标签相同。y_pred
: 模型预测的标签,形状与真实标签相同,值在0到1之间。from_logits
: 布尔值,如果为True,则假定y_pred
未经sigmoid激活函数处理,即直接使用logits。如果为False(默认值),则假定y_pred
已经通过sigmoid函数转换为概率值。
使用方法
作为模型编译的损失函数
在构建模型并编译时,可以直接使用binary_crossentropy
作为损失函数:
import tensorflow as tf
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(10, activation='relu'),
tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam',
loss=tf.keras.losses.binary_crossentropy,
metrics=['accuracy'])
在这个例子中,模型的最后一层使用sigmoid
激活函数,输出预测为正类的概率。binary_crossentropy
损失函数用于计算预测概率和真实标签之间的损失。
直接计算损失
你也可以直接使用binary_crossentropy
函数来计算两个张量之间的损失:
import tensorflow as tf
# 真实标签
y_true = [[0], [1], [1], [0]]
# 预测概率
y_pred = [[0.2], [0.9], [0.8], [0.1]]
# 计算损失
loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
print(loss.numpy()) # 输出损失值
注意事项
- 当
from_logits
参数设置为True时,函数内部会先对预测值应用sigmoid函数,然后再计算交叉熵损失。这在模型的最后一层未使用激活函数或使用线性激活函数时非常有用。 - 确保预测输出(
y_pred
)在[0, 1]范围内,因为log
函数在0或1处是未定义的。 - 如果你的模型输出层没有使用
sigmoid
激活函数,而是直接输出logits,你应该将from_logits
参数设置为True。
这个损失函数非常适合处理二元分类问题,如垃圾邮件检测、疾病诊断等。