函数定义
sequence_mask(
lengths,
maxlen=None,
dtype=tf.bool,
name=None
)
返回数据
return mask类型数据
示例代码
# 比如这个lenght就是记录了第一个句子2个单词,第二个句子2个单词,第三个句子4个单词
lenght = [2,2,4]
mask_data = tf.sequence_mask(lengths=lenght)
# 长度为max(lenght)
array([[ True, True, False, False],
[ True, True, False, False],
[ True, True, True, True]])
# 定义maxlen时
mask_data = tf.sequence_mask(lengths=lenght,maxlen=6)
# 长度为maxlen
array([[ True, True, False, False, False, False],
[ True, True, False, False, False, False],
[ True, True, True, True, False, False]])
# 定义dtype时
mask_data = tf.sequence_mask(lengths=lenght,maxlen=6,dtype=tf.float32)
# 长度为maxlen,数据格式为float32
array([[1., 1., 0., 0., 0., 0.],
[1., 1., 0., 0., 0., 0.],
[1., 1., 1., 1., 0., 0.]], dtype=float32)