tf.searchsorted
是 TensorFlow 中的一个函数,它用于在排序序列中搜索值的位置。这个函数通常用于确定一个值在排序数组中的插入位置,以保持数组的排序。
参数
- sorted_sequence:一个 N-D
Tensor
,包含已排序的序列。 - values:一个 N-D
Tensor
,包含要搜索的值。 - side:一个字符串,要么是 'left' 要么是 'right'。'left' 对应于 lower_bound,'right' 对应于 upper_bound。
- out_type:输出类型,可以是
int32
或int64
。默认为tf.int32
。 - name:操作的可选名称。
返回值
- 一个 N-D
Tensor
,大小与values
相同,包含将 lower_bound 或 upper_bound(取决于side
)应用于每个值的结果。结果不是整个Tensor
的全局索引,而是最后一个维度中的索引。
使用示例
import tensorflow as tf
# 定义一个已排序的序列
edges = tf.constant([-1, 3.3, 9.1, 10.0])
# 定义要搜索的值
values = tf.constant([0.0, 4.1, 12.0])
# 使用 tf.searchsorted 搜索值在排序序列中的位置
result = tf.searchsorted(edges, values)
# 打印结果
print(result.numpy()) # 输出:[1 2 4]
在这个例子中,edges
是一个已排序的序列,values
是我们要搜索的值。tf.searchsorted
函数返回了一个数组,其中包含了 values
中每个值在 edges
中的位置。
注意事项
- 此操作假设
sorted_sequence
沿最内轴排序。如果序列未排序,则不会引发错误,并且返回的张量的内容没有很好的定义 。 - 如果
sorted_sequence
的最后一个维度有2^31-1
个元素,或者values
的总大小超过2^31 - 1
元素,或者两个张量的第一个N-1
维度不匹配,会抛出ValueError
。 side
参数控制如果值恰好落在边上,则返回哪个索引。例如,如果side
是 'left',则返回的索引是第一个不小于values
中值的索引;如果side
是 'right',则返回的索引是第一个不大于values
中值的索引 。
这个函数的典型用例是 "binning"、"bucketing" 或 "discretizing",即根据 sorted_sequence
中列出的边将 values
分配给 bucket-indices,并返回每个值的 bucket-index 。