dynamic_partition

tf.dynamic_partition(
    data,            # tensor to split
    partitions,      # integer partition ids: a scalar, or one id per leading element of data (see the examples)
    num_partitions,  # number of output tensors; every id in partitions selects one of them
    name=None        # optional name for the op
)

Partitions data into num_partitions tensors using indices from partitions.

根据 partitions 中给出的分区编号，将 data 划分为 num_partitions 个张量。

    # Illustrative pseudo-code: `outputs` stands for the list of tensors
    # returned by tf.dynamic_partition(data, partitions, num_partitions).

    # Scalar partitions.
    # A scalar id routes the whole of `data` to a single output.
    partitions = 1
    num_partitions = 2
    data = [10, 20]
    outputs[0] = []  # Empty with shape [0, 2]
    outputs[1] = [[10, 20]]  # all of data, with a new leading axis of size 1

    # Vector partitions.
    # Element i of data goes to outputs[partitions[i]]; relative order is kept.
    partitions = [0, 0, 1, 1, 0]
    num_partitions = 2
    data = [10, 20, 30, 40, 50]
    outputs[0] = [10, 20, 50]
    outputs[1] = [30, 40]

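dynamic_stitch

Interleaves the values of several data tensors into a single tensor, placing data[m][i] at position indices[m][i]. When indices are obtained by partitioning tf.range(n) with the same partitions, dynamic_stitch is the inverse of dynamic_partition.

将多个 data 张量按 indices 指定的位置交错合并为一个张量；若 indices 是用相同的 partitions 对 tf.range(n) 分区得到的，则 dynamic_stitch 正好是 dynamic_partition 的逆操作。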

import tensorflow as tf

# Example: add 1.0 to every element of x except the -1. placeholders,
# using dynamic_partition to split x and dynamic_stitch to merge it back.
x = tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4])

# True where the element is NOT equal to -1.
condition_mask = tf.not_equal(x, tf.constant(-1.))
# [ True, False,  True,  True, False,  True]

# Partition id per element: 0 for the -1. entries, 1 for all others.
# Computed once and reused so the values and their positions are
# guaranteed to be partitioned identically.
partition_ids = tf.cast(condition_mask, tf.int32)

# Split x into two tensors according to partition_ids.
partitioned_data = tf.dynamic_partition(x, partition_ids, 2)
# partitioned_data[0] = [-1., -1.]
# partitioned_data[1] = [0.1, 5.2, 4.3, 7.4]

# Modify only the values we care about. dynamic_partition returns a
# Python list, so item assignment replaces that one partition.
partitioned_data[1] = partitioned_data[1] + 1.0

# Partition the positions 0..len(x)-1 the same way, recording where
# each partitioned element originally came from.
condition_indices = tf.dynamic_partition(
    tf.range(tf.shape(x)[0]), partition_ids, 2)
# condition_indices[0] = [1, 4]
# condition_indices[1] = [0, 2, 3, 5]

# Put every element back at its original position.
x = tf.dynamic_stitch(condition_indices, partitioned_data)
# Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain
# unchanged.

条件分割和拼接

import tensorflow as tf

# Example: split the rows of y by a per-row condition on x, then stitch
# them back together in the original row order.
x = tf.constant([-2, 3, -1])
condition_mask = tf.greater(x, 0)
# <tf.Tensor: shape=(3,), dtype=bool, numpy=array([False,  True, False])>

# Partition id per row: 0 where x <= 0, 1 where x > 0. Computed once and
# reused so the rows of y and their positions are partitioned identically.
partition_ids = tf.cast(condition_mask, tf.int32)

y = tf.constant(
    [[11, 12],
     [21, 22],
     [31, 32]])

# Rows of y grouped by partition id; relative order within a group is kept.
partitions = tf.dynamic_partition(y, partition_ids, 2)
# [<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
# array([[11, 12],
#        [31, 32]], dtype=int32)>,
# <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[21, 22]], dtype=int32)>]

# Original row positions 0..len(x)-1.
row_indices = tf.range(tf.shape(x)[0])
# <tf.Tensor: shape=(3,), dtype=int32, numpy=array([0, 1, 2], dtype=int32)>

# Positions partitioned exactly like the rows of y.
indices = tf.dynamic_partition(row_indices, partition_ids, 2)
# [<tf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 2], dtype=int32)>,
# <tf.Tensor: shape=(1,), dtype=int32, numpy=array([1], dtype=int32)>]

# dynamic_stitch places row j of partitions[m] at position indices[m][j],
# which restores y's original row order. Bind the result so it is not
# silently discarded when this runs as a script.
stitched = tf.dynamic_stitch(indices, partitions)
# <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
# array([[11, 12],
#        [21, 22],
#        [31, 32]], dtype=int32)>

参考

  • https://vimsky.com/examples/usage/python-tensorflow-dynamic_stitch.html