def safe_loss(logits):
    """Numerically stable evaluation of ``log(1 + exp(-logits))``.

    Uses the identity
    ``log(1 + exp(-x)) == max(-x, 0) + log(1 + exp(-|x|))`` so that the
    exponential is only ever taken of a non-positive value.

    Args:
        logits: Tensor of (pairwise) logits.

    Returns:
        Tensor of the same shape as ``logits`` holding the element-wise loss.
    """
    # Linear part: -logits where logits is negative, 0 elsewhere.
    linear_part = tf.nn.relu(-logits)
    # Bounded correction term; its exp() argument is always <= 0.
    magnitude = tf.abs(logits)
    correction = tf.math.log1p(tf.exp(-magnitude))
    return linear_part + correction


def new_pw_loss(qs_logits, qs_label):
    """Pairwise logistic ranking loss averaged over valid ordered pairs.

    For every pair of positions ``(i, j)`` along axis 1 where both labels
    are non-negative and ``label[i] > label[j]``, the term
    ``log(1 + exp(-(logit[i] - logit[j])))`` is accumulated. The total is
    divided by the number of such pairs.

    Args:
        qs_logits: Float tensor of scores. Must have rank >= 2, because pairs
            are formed by broadcasting an axis-2 expansion against an axis-1
            expansion (presumably ``[batch, list_size]`` -- confirm against
            the caller).
        qs_label: Float tensor of labels with the same shape as
            ``qs_logits``. Negative labels are treated as invalid and are
            excluded from every pair.

    Returns:
        Scalar tensor with the mean pairwise loss, or 0 when there is no
        valid ordered pair.
    """
    # Entry [..., i, j] holds value[i] - value[j].
    logits_diff = tf.expand_dims(qs_logits, 2) - tf.expand_dims(qs_logits, 1)
    label_diff = tf.expand_dims(qs_label, 2) - tf.expand_dims(qs_label, 1)

    # A pair is usable only if both of its members carry a valid label.
    valid_label = tf.greater_equal(qs_label, 0.)
    pair_valid = tf.math.logical_and(
        tf.expand_dims(valid_label, 2), tf.expand_dims(valid_label, 1))

    # Combine the conditions as booleans and cast once; multiplying a float
    # tensor by a bool tensor is a dtype error in TensorFlow.
    pair_selected = tf.math.logical_and(tf.greater(label_diff, 0.), pair_valid)
    label_mask = tf.cast(pair_selected, logits_diff.dtype)

    # Normalize by the pairs that actually contribute to the numerator, so
    # pairs involving an invalid label do not dilute the mean.
    num_pairs = tf.reduce_sum(label_mask)
    loss_sum = tf.reduce_sum(safe_loss(logits_diff) * label_mask)

    # The count is a whole number, so clamping to 1 only changes the empty
    # case, where the numerator is 0 and the result becomes 0 instead of NaN.
    return loss_sum / tf.maximum(num_pairs, 1.)