import numpy as np
# Element-wise mask over train_features: 1.0 where the value is non-zero,
# 0.0 where it is zero, stored as float32.
# NOTE(review): presumably 0 is the padding token id, so this marks the real
# (non-padding) positions of each sequence — confirm against the padding code.
train_mask = (train_features != 0).astype("float32")
def mlp_mean_pooling_with_mask(word_emb_size=100, output_size=7, max_len=15, vocab_size=None):
    """Build an MLP classifier over mask-aware mean-pooled word embeddings.

    The model takes two inputs, ``input_ids`` and ``input_mask``, both of
    shape ``(batch, max_len)``. Token ids are embedded, the embeddings are
    multiplied by the mask, and the result is summed over the sequence axis
    and divided by the summed mask, so positions whose mask value is 0 do
    not contribute to the average. The pooled vector then goes through a
    100-unit ReLU layer and a softmax output layer.

    Args:
        word_emb_size: Dimension of each word embedding.
        output_size: Number of units in the softmax output layer.
        max_len: Length of the id and mask sequences. Defaults to 15, the
            value that used to be hard-coded.
        vocab_size: Number of rows in the embedding table. When ``None``,
            it is derived from the module-level ``tokenizer`` as
            ``len(tokenizer.word_index) + 1``.

    Returns:
        An uncompiled ``tf.keras.Model`` with inputs
        ``[input_ids, input_mask]`` and a single softmax output.
    """
    if vocab_size is None:
        # NOTE(review): the +1 presumably reserves index 0 (padding), since
        # Keras tokenizers number words starting from 1 — confirm.
        vocab_size = len(tokenizer.word_index) + 1

    input_ids = Input(shape=(max_len,), dtype=tf.int32, name="input_ids")
    input_mask = Input(shape=(max_len,), dtype=tf.float32, name="input_mask")

    # Expand the mask so it lines up element-wise with the embeddings.
    # Assuming `K` is the Keras backend, K.repeat turns (batch, max_len) into
    # (batch, word_emb_size, max_len); the transpose then swaps the last two
    # axes to give (batch, max_len, word_emb_size).
    input_mask_repeat = Lambda(
        lambda x: K.repeat(x, word_emb_size),
        name="mask_repeat",
    )(input_mask)
    input_mask_transpose = Lambda(
        lambda x: tf.transpose(x, [0, 2, 1]),
        name="mask_transpose",
    )(input_mask_repeat)

    input_word_emb = Embedding(vocab_size, word_emb_size, name="class_emb")(input_ids)

    # Zero out the embeddings at masked positions. The layer name
    # "mask_softmax" is kept as-is so lookups by layer name keep working,
    # even though no softmax is applied here.
    mask_word_emb = Lambda(
        lambda x: x[0] * x[1],
        name="mask_softmax",
    )([input_word_emb, input_mask_transpose])

    # Masked mean over the sequence axis. The small epsilon avoids a division
    # by zero when a sample's mask is all zeros.
    mean_pooling = Lambda(
        lambda x: tf.divide(K.sum(x[0], axis=1), K.sum(x[1], axis=1) + 0.000001)
    )([mask_word_emb, input_mask_transpose])

    hidden1 = Dense(100, activation=tf.nn.relu)(mean_pooling)
    output = Dense(output_size, activation=tf.nn.softmax)(hidden1)

    model = tf.keras.Model(inputs=[input_ids, input_mask], outputs=[output])
    return model
002 算法实战 | mask
算法实战相关文章
最近热门
- 7.1.1 设置spark.driver.maxResultSize
- C++ 整型数转字符串
- 论文:Capturing Delayed Feedback in Conversion Rate Prediction via Elapsed-Time Sampling
- Minimum Detectable Effect(MDE)最小可检测效应
- werkzeug ImportError: cannot import name 'secure_filename'
- SFT(Supervised Fine-Tuning,即有监督微调)
- STT模型(Speech-to-Text)
- 论文阅读 TOKEN MERGING: YOUR VIT BUT FASTER(ToMe模型)
- 因果推断 | uplift | 营销增长 | 增长算法 | 智能营销
- SSB - Sample Selection Bias - 样本选择偏差问题
最常浏览
- 016 推荐系统 | 排序学习(LTR - Learning To Rank)
- 偏微分符号
- i.i.d(又称IID)
- 利普希茨连续条件(Lipschitz continuity)
- (error) MOVED 原因和解决方案
- TextCNN详解
- 找不到com.google.protobuf.GeneratedMessageV3的类文件
- Deployment failed: repository element was not specified in the POM inside distributionManagement
- cannot access com.google.protobuf.GeneratedMessageV3 解决方案
- CLUSTERDOWN Hash slot not served 问题原因和解决办法
×