tf.searchsorted 是 TensorFlow 中的一个函数,它用于在排序序列中搜索值的位置。这个函数通常用于确定一个值在排序数组中的插入位置,以保持数组的排序。

参数

  • sorted_sequence:一个 N-D Tensor,包含已排序的序列。
  • values:一个 N-D Tensor,包含要搜索的值。
  • side:一个字符串,要么是 'left' 要么是 'right'。'left' 对应于 lower_bound,'right' 对应于 upper_bound。
  • out_type:输出类型,可以是 int32int64。默认为 tf.int32
  • name:操作的可选名称。

返回值

  • 一个 N-D Tensor,大小与 values 相同,包含将 lower_bound 或 upper_bound(取决于 side)应用于每个值的结果。结果不是整个 Tensor 的全局索引,而是最后一个维度中的索引。

使用示例

import tensorflow as tf

# 定义一个已排序的序列
edges = tf.constant([-1, 3.3, 9.1, 10.0])

# 定义要搜索的值
values = tf.constant([0.0, 4.1, 12.0])

# 使用 tf.searchsorted 搜索值在排序序列中的位置
result = tf.searchsorted(edges, values)

# 打印结果
print(result.numpy())  # 输出:[1 2 4]

在这个例子中,edges 是一个已排序的序列,values 是我们要搜索的值。tf.searchsorted 函数返回了一个数组,其中包含了 values 中每个值在 edges 中的位置。

注意事项

  • 此操作假设 sorted_sequence 沿最内轴排序。如果序列未排序,则不会引发错误,并且返回的张量的内容没有很好的定义 。
  • 如果 sorted_sequence 的最后一个维度有 2^31-1 个元素,或者 values 的总大小超过 2^31 - 1 元素,或者两个张量的第一个 N-1 维度不匹配,会抛出 ValueError
  • side 参数控制如果值恰好落在边上,则返回哪个索引。例如,如果 side 是 'left',则返回的索引是第一个不小于 values 中值的索引;如果 side 是 'right',则返回的索引是第一个不大于 values 中值的索引 。

这个函数的典型用例是 "binning"、"bucketing" 或 "discretizing",即根据 sorted_sequence 中列出的边将 values 分配给 bucket-indices,并返回每个值的 bucket-index 。