dynamic_partition
tf.dynamic_partition(
data,
partitions,
num_partitions,
name=None
)
Partitions data into num_partitions tensors using indices from partitions.
使用 partitions 中的分区编号，将 data 划分为 num_partitions 个张量。
# Illustrative pseudo-code (not runnable as-is): outputs[i] stands for the
# i-th tensor in the list returned by
# tf.dynamic_partition(data, partitions, num_partitions).
# Scalar partitions.
# With a scalar `partitions`, the whole of `data` goes to that single
# partition (here index 1), gaining a leading dimension; the other partition
# is empty.
partitions = 1
num_partitions = 2
data = [10, 20]
outputs[0] = [] # Empty with shape [0, 2]
outputs[1] = [[10, 20]]
# Vector partitions.
# With a vector `partitions`, element data[i] is sent to
# outputs[partitions[i]]; relative order within each output is preserved.
partitions = [0, 0, 1, 1, 0]
num_partitions = 2
data = [10, 20, 30, 40, 50]
outputs[0] = [10, 20, 50]
outputs[1] = [30, 40]
dynamic_stitch
import tensorflow as tf

# Goal: add 1.0 to every element of x that is not the sentinel -1., leaving
# the -1. entries untouched, by splitting with dynamic_partition and merging
# back with dynamic_stitch.
x = tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4])
# Mark the elements that are NOT equal to -1. (True = element to modify).
condition_mask = tf.not_equal(x, tf.constant(-1.))
# [ True, False,  True,  True, False,  True]
# dynamic_partition expects integer partition ids: False -> 0, True -> 1.
# Computed once and reused for both the values and their positions.
partition_ids = tf.cast(condition_mask, tf.int32)
# Split x into two tensors according to partition_ids.
partitioned_data = tf.dynamic_partition(x, partition_ids, 2)
# partitioned_data[0] = [-1., -1.]
# partitioned_data[1] = [0.1, 5.2, 4.3, 7.4]
partitioned_data[1] = partitioned_data[1] + 1.0
# partitioned_data[1] = [1.1, 6.2, 5.3, 8.4]
# Partition the positions 0..len(x)-1 the same way, so each element of
# partitioned_data knows which index of x it came from.
condition_indices = tf.dynamic_partition(
    tf.range(tf.shape(x)[0]), partition_ids, 2)
# condition_indices[0] = [1, 4]
# condition_indices[1] = [0, 2, 3, 5]
# Reassemble: element j of partitioned_data[i] is written to position
# condition_indices[i][j] of the result.
x = tf.dynamic_stitch(condition_indices, partitioned_data)
# Here x = [1.1, -1., 6.2, 5.3, -1., 8.4]; the -1. values remain
# unchanged.
条件分割和拼接
import tensorflow as tf

# Split the rows of y by a condition evaluated on x, then stitch them back
# together so the result matches y's original row order.
x = tf.constant([-2, 3, -1])
condition_mask = tf.greater(x, 0)  # element-wise x > 0
# <tf.Tensor: shape=(3,), dtype=bool, numpy=array([False,  True, False])>

# dynamic_partition wants integer partition ids: False -> 0, True -> 1.
# Computed once and shared by the row split and the index split below.
partition_ids = tf.cast(condition_mask, tf.int32)

y = tf.constant([
    [11, 12],
    [21, 22],
    [31, 32],
])
# Row i of y is sent to partitions[partition_ids[i]].
partitions = tf.dynamic_partition(y, partition_ids, 2)
# [<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
#  array([[11, 12],
#         [31, 32]], dtype=int32)>,
#  <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[21, 22]], dtype=int32)>]

# Row positions 0..len(x)-1, split the same way so every row in
# `partitions` remembers where it originally lived.
indexs = tf.range(tf.shape(x)[0])
# <tf.Tensor: shape=(3,), dtype=int32, numpy=array([0, 1, 2], dtype=int32)>
indices = tf.dynamic_partition(indexs, partition_ids, 2)
# [<tf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 2], dtype=int32)>,
#  <tf.Tensor: shape=(1,), dtype=int32, numpy=array([1], dtype=int32)>]

# Reassemble: partitions[i][j] is written to row indices[i][j], which
# restores the original y.
tf.dynamic_stitch(indices, partitions)
# <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
# array([[11, 12],
#        [21, 22],
#        [31, 32]], dtype=int32)>
参考
- https://vimsky.com/examples/usage/python-tensorflow-dynamic_stitch.html