函数定义

sequence_mask(
    lengths,
    maxlen=None,
    dtype=tf.bool,
    name=None
)

返回数据

return mask类型数据

示例代码

# 比如这个lenght就是记录了第一个句子2个单词,第二个句子2个单词,第三个句子4个单词
lenght = [2,2,4] 
mask_data = tf.sequence_mask(lengths=lenght)
# 长度为max(lenght)
array([[ True,  True, False, False],
       [ True,  True, False, False],
       [ True,  True,  True,  True]])

# 定义maxlen时
mask_data = tf.sequence_mask(lengths=lenght,maxlen=6)
# 长度为maxlen
array([[ True,  True, False, False, False, False],
       [ True,  True, False, False, False, False],
       [ True,  True,  True,  True, False, False]])

# 定义dtype时
mask_data = tf.sequence_mask(lengths=lenght,maxlen=6,dtype=tf.float32)
# 长度为maxlen,数据格式为float32
array([[1., 1., 0., 0., 0., 0.],
       [1., 1., 0., 0., 0., 0.],
       [1., 1., 1., 1., 0., 0.]], dtype=float32)