注意传入的是logits
@tf_export("nn.weighted_cross_entropy_with_logits")
def weighted_cross_entropy_with_logits(targets, logits, pos_weight, name=None):
    """Computes sigmoid cross entropy with a weight on the positive term.

    This behaves like `sigmoid_cross_entropy_with_logits()`, except that
    `pos_weight` scales the cost of an error on a positive target relative
    to an error on a negative target, which lets the caller trade recall
    against precision.

    The unweighted cross-entropy cost is:

        targets * -log(sigmoid(logits)) +
            (1 - targets) * -log(1 - sigmoid(logits))

    `pos_weight` is introduced as a multiplicative coefficient on the
    positive-targets term:

        targets * -log(sigmoid(logits)) * pos_weight +
            (1 - targets) * -log(1 - sigmoid(logits))

    A value `pos_weight > 1` therefore decreases the false negative count
    and increases recall, while `pos_weight < 1` decreases the false
    positive count and increases precision.

    For brevity, let `x = logits`, `z = targets`, `q = pos_weight`.
    The loss is:

          qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
        = qz * -log(1 / (1 + exp(-x)))
              + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
        = qz * log(1 + exp(-x))
              + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
        = qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
        = (1 - z) * x + (qz + 1 - z) * log(1 + exp(-x))
        = (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))

    Setting `l = (1 + (q - 1) * z)`, to ensure stability and avoid
    overflow, the implementation uses:

        (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))

    `logits` and `targets` must have the same type and shape.

    Args:
        targets: A `Tensor` of the same type and shape as `logits`.
        logits: A `Tensor` of type `float32` or `float64`. These are raw
            logits; the sigmoid is applied as part of the loss formula.
        pos_weight: A coefficient to use on the positive examples.
        name: A name for the operation (optional).

    Returns:
        A `Tensor` of the same shape as `logits` with the componentwise
        weighted logistic losses.

    Raises:
        ValueError: If `logits` and `targets` do not have the same shape.
    """
    with ops.name_scope(name, "logistic_loss", [logits, targets]) as scope:
        logits = ops.convert_to_tensor(logits, name="logits")
        targets = ops.convert_to_tensor(targets, name="targets")

        # Reject mismatched static shapes up front with a readable message.
        try:
            targets.get_shape().merge_with(logits.get_shape())
        except ValueError:
            shapes = (logits.get_shape(), targets.get_shape())
            message = "logits and targets must have the same shape (%s vs %s)"
            raise ValueError(message % shapes)

        # With l = 1 + (q - 1) * z, the loss derived in the docstring is
        #     (1 - z) * x + l * log(1 + exp(-x))
        # For x < 0 the exponential can overflow, so that case is better
        # written as
        #     (1 - z) * x + l * (-x + log(1 + exp(x)))
        # Both cases are covered without branching by
        #     (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))

        # l: weight applied to the log term; equals q where z == 1 and
        # 1 where z == 0.
        positive_scale = 1 + (pos_weight - 1) * targets

        # (1 - z) * x
        linear_term = (1 - targets) * logits

        # log(1 + exp(-abs(x))) + max(-x, 0)
        stable_log_term = (
            math_ops.log1p(math_ops.exp(-math_ops.abs(logits)))
            + nn_ops.relu(-logits))

        return math_ops.add(
            linear_term, positive_scale * stable_log_term, name=scope)