# coding: utf8
from collections import Counter
import pandas as pd
import jieba
import tensorflow as tf
from tensorflow import keras
import numpy as np
from tensorflow.keras.layers import Lambda, Embedding, Dense, Dropout
import tensorflow.keras.backend as K
class Config:
    """Hyper-parameters shared by the text-processing helpers and the model.

    All values are plain instance attributes. ``VOC_SIZE`` is left as
    ``None`` here; the training script assigns it once the vocabulary has
    been built.
    """

    def __init__(self):
        # Length (in tokens) of every query fed to the model.
        self.QUERY_LEN = 25
        # Minimum training-set frequency for a token to enter the vocabulary.
        self.WORD_THRESHOLD = 5
        # Dimensionality of each word embedding.
        self.EMB_SIZE = 256
        # Number of output classes.
        self.CATEGORY_NUMBER = 7
        # Vocabulary size; unknown until the vocabulary is built.
        self.VOC_SIZE = None


# Single shared configuration object read by the module-level helpers.
config = Config()
def get_word_list(line, max_len=None):
    """Segment ``line`` with jieba and keep only its trailing tokens.

    Args:
        line: Raw text to segment (a query string or a CSV cell).
        max_len: Maximum number of tokens to keep; defaults to
            ``config.QUERY_LEN``.

    Returns:
        A list of at most ``max_len`` token strings. When the text yields more
        tokens, the *last* ``max_len`` of them are kept.
    """
    if max_len is None:
        max_len = config.QUERY_LEN
    # Guard the zero/negative case explicitly: ``words[-0:]`` would return
    # every token instead of none.
    if max_len <= 0:
        return []
    words = list(jieba.cut(line))
    return words[-max_len:]
def get_id_list(word_list, word_id_dict, max_len=None):
    """Map tokens to vocabulary ids and left-pad/truncate to a fixed length.

    Args:
        word_list: Sequence of tokens.
        word_id_dict: Mapping from token to id. Id 0 is used both for padding
            and for tokens missing from the mapping, so real ids should be >= 1.
        max_len: Length of the returned list; defaults to ``config.QUERY_LEN``.

    Returns:
        A list of exactly ``max_len`` ints (empty if ``max_len`` <= 0). Short
        inputs are padded with 0 on the left; long inputs keep the ids of
        their last ``max_len`` tokens.
    """
    if max_len is None:
        max_len = config.QUERY_LEN
    # ``ids[-0:]`` would keep everything, so handle zero/negative up front.
    if max_len <= 0:
        return []
    ids = [word_id_dict.get(word, 0) for word in word_list]
    # Keep the tail of the sequence, then left-pad with the padding id 0.
    ids = ids[-max_len:]
    return [0] * (max_len - len(ids)) + ids
def predict_query(query):
    """Return the model's class scores for a single raw query string.

    Relies on the module-level ``word_id_dict`` vocabulary and the trained
    ``mlp`` model.

    Returns:
        Whatever ``mlp.predict`` returns for a batch containing this one
        query; for the model built in this file that is one row of
        ``config.CATEGORY_NUMBER`` softmax scores.
    """
    sample = get_id_list(get_word_list(query), word_id_dict)
    # Build an explicit (1, QUERY_LEN) batch. A nested Python list such as
    # [[sample]] is ambiguous: depending on the Keras version it may be read
    # as a list of model inputs or as a rank-3 array.
    batch = np.array([sample])
    return mlp.predict(batch)
def get_features_labels(df):
    """Convert a tokenized DataFrame into feature and label arrays.

    Column 0 must hold token lists (as produced by ``get_word_list``); every
    other column is treated as a label column. Uses the module-level
    ``word_id_dict`` vocabulary. ``df`` is not modified.

    Returns:
        features: ndarray with one padded id sequence per DataFrame row.
        labels: ndarray of the remaining columns, in their original order.
    """
    # Build the id rows without overwriting the caller's column 0.
    id_rows = [get_id_list(words, word_id_dict) for words in df[0]]
    # Plain ndarray instead of np.mat: np.matrix is discouraged and the
    # np.mat alias was removed in NumPy 2.0.
    features = np.array(id_rows)
    labels = np.array(df.drop(0, axis=1))
    return features, labels
def get_most_similar(word, word_id_dict, id_word_dict, id_embedding, topn=10):
    """Return the vocabulary words whose embeddings score highest against ``word``.

    The score is the raw dot product between embedding rows (not cosine
    similarity), so rows with a large norm tend to rank high. The query word
    itself is not excluded from the result.

    Args:
        word: Query token.
        word_id_dict: Mapping token -> id. Id 0 is treated as the shared
            padding/unknown slot, not a real word.
        id_word_dict: Inverse mapping id -> token.
        id_embedding: 2-D array whose row i is the embedding of id i.
        topn: Maximum number of neighbours to return.

    Returns:
        A list of up to ``topn`` ``(token, score)`` pairs in descending score
        order. Returns an empty list when ``word`` is not in the vocabulary,
        because the neighbours of the padding/unknown vector are meaningless.
    """
    word_id = word_id_dict.get(word, 0)
    if word_id == 0:
        return []
    sims = np.dot(id_embedding, id_embedding[word_id])
    ranked = sims.argsort()[::-1]
    # Drop id 0: it is padding/unknown and has no entry in id_word_dict.
    ranked = ranked[ranked > 0]
    return [(id_word_dict[i], sims[i]) for i in ranked[:topn]]
# ---------------------------------------------------------------------------
# Training script: build the vocabulary, train the classifier, inspect results.
# ---------------------------------------------------------------------------
# Tab-separated files without a header: column 0 is the raw text (tokenized on
# load by get_word_list); the remaining columns are the labels.
train_filename = "input/train.csv"
test_filename = "input/test.csv"
train = pd.read_csv(train_filename, converters={0: get_word_list}, header=None, sep="\t")
test = pd.read_csv(test_filename, converters={0: get_word_list}, header=None, sep="\t")

# Token frequencies are counted on the training set only.
word_counter = Counter()
for word_list in train[0]:
    word_counter.update(word_list)
word_freq_dict = {
    word: freq
    for word, freq in word_counter.items()
    if freq >= config.WORD_THRESHOLD
}
# Ids start at 1: id 0 is reserved for padding and out-of-vocabulary tokens.
word_id_dict = {word: index for index, word in enumerate(word_freq_dict, start=1)}
id_word_dict = {index: word for word, index in word_id_dict.items()}
# +1 is important: the Embedding layer needs input_dim > the largest id, and
# id 0 is reserved in addition to the len(word_id_dict) real words.
config.VOC_SIZE = len(word_id_dict) + 1

train_features, train_labels = get_features_labels(train)
test_features, test_labels = get_features_labels(test)

# No explicit tf.InteractiveSession / global_variables_initializer: the old
# initializer was created before any model variable existed, so it initialized
# nothing, and Keras initializes the model's variables itself.
mlp = keras.Sequential([
    Embedding(config.VOC_SIZE, config.EMB_SIZE, name='word2vec'),
    # NOTE: averages over all QUERY_LEN positions, including the left padding
    # (id 0), because no mask is applied.
    Lambda(lambda x: K.mean(x, axis=1)),
    Dense(256, activation=keras.activations.relu),
    Dropout(0.5),
    Dense(128, activation=keras.activations.relu),
    Dropout(0.5),
    Dense(64, activation=keras.activations.relu),
    Dense(config.CATEGORY_NUMBER, activation=keras.activations.softmax),
])
# Keras Adam instead of the TF1-only tf.train.AdamOptimizer; epsilon=1e-8
# matches tf.train.AdamOptimizer's default.
mlp.compile(optimizer=keras.optimizers.Adam(0.001, epsilon=1e-8),
            loss=keras.losses.categorical_crossentropy,
            metrics=[keras.metrics.categorical_accuracy])
mlp.fit(train_features, train_labels, epochs=5, batch_size=200)
print(mlp.evaluate(test_features, test_labels, batch_size=200))

# Row i of the first layer's weight matrix is the learned vector for word id i.
id_embedding = mlp.layers[0].get_weights()[0]
print(get_most_similar("图片", word_id_dict, id_word_dict, id_embedding))
print(predict_query("周星驰图片"))
# --- Residue of the web page this script was copied from (article title and
# --- sidebar link lists); kept as comments so the file remains valid Python.
# 001 项目实战 | 文本分类
# tensorflow相关文章
# nlp相关文章
# - ELECTRA+Biaffine
# - Self Instruct技术
# - SO-PMI(Semantic Orientation Pointwise Mutual Information,情感倾向点互信息算法)
# - Masked Language Model(MLM)掩码语言模型
# - Siamese CBOW(Continuous Bag of Words)一种Word2Vec的变体
# - RoBERTa-wwm(RoBERTa-Whole Word Masking)模型
# - cs224n
# - No Language Left Behind(NLLB,不让任何一门语言掉队)
# - 共指消解(Coreference Resolution)
# - Coherence模型(一致性模型)
# 最近热门
# - 论文《Applying Deep Learning To Airbnb Search》阅读笔记
# - SFT(Supervised Fine-Tuning,即有监督微调)
# - 论文 | PAL: A Position-bias Aware Learning Framework for CTR Prediction in Live Recommender Systems
# - 凸优化中的 Slater 条件
# - 因果推断 | uplift | 营销增长 | 增长算法 | 智能营销
# - Context Parallel(简称CP)并行化技术
# - ITC(Image-Text Contrastive)loss和ITM(Image-Text Matching)loss
# - tf.losses.log_loss
# - 论文:Dataset Regeneration for Sequential Recommendation
# - 论文 | POSO: Personalized Cold Start Modules for Large-scale Recommender Systems
# 最常浏览
# - 016 推荐系统 | 排序学习(LTR - Learning To Rank)
# - 偏微分符号
# - i.i.d(又称IID)
# - 利普希茨连续条件(Lipschitz continuity)
# - (error) MOVED 原因和解决方案
# - TextCNN详解
# - 找不到com.google.protobuf.GeneratedMessageV3的类文件
# - Deployment failed: repository element was not specified in the POM inside distributionManagement
# - cannot access com.google.protobuf.GeneratedMessageV3 解决方案
# - CLUSTERDOWN Hash slot not served 问题原因和解决办法
# ×