import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer, Input, Embedding, Dense, Dropout
from tensorflow.keras.models import Model


class MeanPool(Layer):
    """Mean pooling over the time axis that ignores masked (padded) positions."""

    def __init__(self, **kwargs):
        self.supports_masking = True
        super(MeanPool, self).__init__(**kwargs)

    def compute_mask(self, inputs, input_mask=None):
        # do not pass the mask on to the next layers
        return None

    def call(self, x, mask=None):
        if mask is not None:
            # mask: (batch, time)
            mask = K.cast(mask, K.floatx())
            # mask: (batch, x_dim, time)
            mask = K.repeat(mask, x.shape[-1])
            # mask: (batch, time, x_dim)
            mask = tf.transpose(mask, [0, 2, 1])
            # zero out the padded time steps
            x = x * mask
            # masked mean: sum over time divided by the number of valid positions
            return tf.divide(K.sum(x, axis=1), K.sum(mask, axis=1) + 1e-6)
            # return K.sum(x, axis=1)  # masked sum pooling instead of mean
        else:
            return K.mean(x, axis=1)

    def compute_output_shape(self, input_shape):
        # remove the temporal dimension
        return (input_shape[0], input_shape[2])
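As a quick illustration of what the layer computes (this check is not part of the original snippet and the toy tensors are invented for it), feeding a batch whose last time step is padding should return the mean of the unmasked steps only:

import numpy as np

toy_x = K.constant(np.array([[[1.0, 2.0],
                              [3.0, 4.0],
                              [0.0, 0.0]]]))          # (1, time=3, dim=2), last step is padding
toy_mask = K.constant(np.array([[1.0, 1.0, 0.0]]))    # (1, time=3), 1 = real token

pooled = MeanPool()(toy_x, mask=toy_mask)
# approximately [[2., 3.]] -- the mean of the first two time steps (up to the 1e-6 smoothing term)
print(K.eval(pooled))

Returning None from compute_mask means the mask is consumed here: the following Dense layers see a fixed-size vector and no longer need masking.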
# word ids: (BATCH_SIZE, QUERY_LEN)
input_word_ids = Input(shape=(config.QUERY_LEN,), dtype=tf.int32)
# word embeddings: (BATCH_SIZE, QUERY_LEN, EMB_SIZE)
input_word_emb = Embedding(config.VOC_SIZE, config.EMB_SIZE, name="word2vec")(input_word_ids)
# 0/1 padding mask: (BATCH_SIZE, QUERY_LEN)
input_mask = Input(shape=(config.QUERY_LEN,), dtype=tf.int32)
# masked mean pooling over the query tokens: (BATCH_SIZE, EMB_SIZE)
avg_pooling = MeanPool()(input_word_emb, mask=input_mask)

hidden1 = Dense(768, activation=keras.activations.relu)(avg_pooling)
drop1 = Dropout(0.5)(hidden1)
hidden2 = Dense(512, activation=keras.activations.relu)(drop1)
drop2 = Dropout(0.5)(hidden2)
hidden3 = Dense(256, activation=keras.activations.relu)(drop2)
output = Dense(config.CATEGORY_NUMBER, activation=keras.activations.softmax)(hidden3)

mlp = Model(inputs=[input_mask, input_word_ids], outputs=[output])
mlp.compile(optimizer=tf.train.AdamOptimizer(0.01),
            loss=keras.losses.categorical_crossentropy,
            metrics=[keras.metrics.categorical_accuracy])
mlp.fit([train_mask, train_features], train_labels,
        epochs=20, batch_size=200,
        validation_data=([test_mask, test_features], test_labels))
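The snippet assumes train_features / train_mask (and their test counterparts) are already built. A minimal sketch of how such inputs could be prepared, assuming queries come as token lists and a hypothetical word-to-id vocabulary word2id (the helper name and arguments are illustrative, not from the original):

import numpy as np

def build_inputs(tokenized_queries, word2id, max_len=config.QUERY_LEN, pad_id=0):
    """Pad/truncate token-id sequences to max_len and build the matching 0/1 mask."""
    ids = np.full((len(tokenized_queries), max_len), pad_id, dtype=np.int32)
    mask = np.zeros((len(tokenized_queries), max_len), dtype=np.int32)
    for i, tokens in enumerate(tokenized_queries):
        token_ids = [word2id.get(t, pad_id) for t in tokens][:max_len]
        ids[i, :len(token_ids)] = token_ids
        mask[i, :len(token_ids)] = 1      # 1 = real token, 0 = padding
    return ids, mask

# e.g. train_features, train_mask = build_inputs(train_queries, word2id)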