# coding: utf8
from collections import Counter
import pandas as pd
import jieba
import tensorflow as tf
from tensorflow import keras
import numpy as np
from tensorflow.keras.layers import Lambda, Embedding, Dense, Dropout
import tensorflow.keras.backend as K
class Config:
    """Hyper-parameters for the text-classification pipeline."""

    def __init__(self):
        # Minimum corpus frequency for a word to enter the vocabulary.
        self.WORD_THRESHOLD = 5
        # Vocabulary size; filled in once the vocabulary has been built.
        self.VOC_SIZE = None
        # Embedding dimensionality.
        self.EMB_SIZE = 256
        # Fixed query length: id sequences are padded/truncated to this.
        self.QUERY_LEN = 25
        # Number of output classes.
        self.CATEGORY_NUMBER = 7


config = Config()
def get_word_list(line):
    """Tokenize `line` with jieba and keep only the last QUERY_LEN tokens."""
    tokens = list(jieba.cut(line))
    return tokens[-config.QUERY_LEN:]
def get_id_list(word_list, word_id_dict):
    """Map tokens to ids and left-pad with 0 so the result is exactly
    QUERY_LEN long.

    Unknown words map to id 0, the same id used for padding.
    """
    ids = [word_id_dict.get(word, 0) for word in word_list]
    padding = [0] * config.QUERY_LEN
    # Left-pad, then keep the rightmost QUERY_LEN entries.
    return (padding + ids)[-config.QUERY_LEN:]
def predict_query(query):
    """Tokenize `query`, convert it to a padded id sequence, and return the
    model's class-probability output for a batch of one.

    Relies on the module-level `word_id_dict` and `mlp`.
    """
    sample = get_id_list(get_word_list(query), word_id_dict)
    # One nesting level only: the model is trained on rank-2 input
    # (batch, QUERY_LEN). The original `[[sample]]` produced a rank-3
    # (1, 1, QUERY_LEN) batch, yielding an output of the wrong shape.
    return mlp.predict(np.array([sample]))
def get_features_labels(df):
    """Split a raw dataframe into (features, labels).

    Column 0 holds tokenized queries; it is converted IN PLACE to padded id
    lists. All remaining columns are treated as the (one-hot) label matrix.
    Relies on the module-level `word_id_dict` via `get_id_list`.

    Returns:
        features: np.ndarray of shape (n_rows, QUERY_LEN)
        labels:   np.ndarray of the remaining columns
    """
    df[0] = df[0].apply(lambda words: get_id_list(words, word_id_dict))
    # np.array instead of the deprecated np.mat (np.matrix subclass);
    # a plain ndarray is what Keras expects anyway.
    features = np.array(df[0].tolist())
    labels = np.array(df.drop(0, axis=1))
    return features, labels
def get_most_similar(word, word_id_dict, id_word_dict, id_embedding, topn=10):
    """Return the `topn` (word, similarity) pairs closest to `word`.

    Similarity is cosine similarity over the embedding rows. The original
    version ranked by raw dot products, which is dominated by embedding
    norms rather than direction; normalizing fixes the ranking.

    Args:
        word: query word; unknown words fall back to id 0.
        word_id_dict: word -> id mapping.
        id_word_dict: id -> word mapping (id 0 has no entry).
        id_embedding: (vocab, dim) embedding matrix; row 0 is padding/OOV.
        topn: number of neighbours to return (default 10, as before).
    """
    word_id = word_id_dict.get(word, 0)
    # Normalize rows; the epsilon guards the all-zero padding row.
    norms = np.linalg.norm(id_embedding, axis=1, keepdims=True)
    normed = id_embedding / np.maximum(norms, 1e-12)
    sims = normed.dot(normed[word_id])
    order = sims.argsort()[::-1]
    # Drop id 0 (padding / unknown) — it has no entry in id_word_dict.
    order = order[order > 0]
    return [(id_word_dict[i], float(sims[i])) for i in order[:topn]]
train_filename = "input/train.csv"
test_filename = "input/test.csv"

# Column 0 is the query text; it is tokenized while the CSV is read.
read_kwargs = dict(converters={0: get_word_list}, header=None, sep="\t")
train = pd.read_csv(train_filename, **read_kwargs)
test = pd.read_csv(test_filename, **read_kwargs)

# Word frequencies are counted over the training queries only.
word_counter = Counter()
for tokens in train[0]:
    word_counter.update(tokens)

# Keep only words at or above the frequency threshold.
word_freq_dict = {w: c for w, c in word_counter.items() if c >= config.WORD_THRESHOLD}
# Ids start at 1; id 0 is reserved for padding / unknown words.
word_id_dict = {w: i + 1 for i, w in enumerate(word_freq_dict)}
id_word_dict = {i: w for w, i in word_id_dict.items()}
# +1 accounts for the reserved id 0.
config.VOC_SIZE = len(word_id_dict) + 1

train_features, train_labels = get_features_labels(train)
test_features, test_labels = get_features_labels(test)
# NOTE(review): this TF1-style session setup looks misplaced — at this point
# no variables exist yet (the Keras model below is built afterwards), so the
# initializer op initializes nothing, and tf.keras manages variable
# initialization itself. Left byte-identical to preserve behavior; confirm
# whether it can be removed.
init = tf.global_variables_initializer()
sess = tf.InteractiveSession()
sess.run(init)
# Bag-of-words text classifier: average the word embeddings of a query,
# then feed the result through an MLP head with a softmax output.
mlp = keras.Sequential()
mlp.add(Embedding(config.VOC_SIZE, config.EMB_SIZE, name='word2vec'))
# Mean over the sequence axis collapses (batch, len, emb) -> (batch, emb).
# NOTE(review): padding id 0 has a trainable embedding and is included in
# the mean — confirm this is intended (no masking is applied).
mlp.add(Lambda(lambda t: K.mean(t, axis=1)))
mlp.add(Dense(256, activation=keras.activations.relu))
mlp.add(Dropout(0.5))
mlp.add(Dense(128, activation=keras.activations.relu))
mlp.add(Dropout(0.5))
mlp.add(Dense(64, activation=keras.activations.relu))
mlp.add(Dense(config.CATEGORY_NUMBER, activation=keras.activations.softmax))
# Labels are one-hot rows of the dataframe, hence categorical (not sparse)
# cross-entropy. Take the loss from keras.losses — the original pulled it
# from keras.metrics, which happens to expose the same function but is the
# wrong namespace for a training loss.
mlp.compile(optimizer=tf.train.AdamOptimizer(0.001),
            loss=keras.losses.categorical_crossentropy,
            metrics=[keras.metrics.categorical_accuracy])
mlp.fit(train_features, train_labels, epochs=5, batch_size=200)
mlp.evaluate(test_features, test_labels, batch_size=200)
# Extract the learned embedding matrix (rows indexed by word id).
id_embedding = mlp.layers[0].get_weights()[0]
# Print the demo results — the original discarded both return values,
# so the script produced no visible output for these calls.
print(get_most_similar("图片", word_id_dict, id_word_dict, id_embedding))
print(predict_query("周星驰图片"))
001 项目实战 | 文本分类
tensorflow相关文章
nlp相关文章
最近热门
- 模型学习率预热 warm up
- Pandas实战:将分组后的结果展平并以列的形式展示
- 推荐系统 | LFM(Latent Factor Model) 隐因子模型的原理与应用
- Kendall秩相关系数,肯德尔秩相关系数
- 特征工程、特征设计、特征梳理
- mysql 显示运行的线程
- TimesNet:一个用于时间序列分析的通用基础模型
- 腾讯终身交叉网络LCN模型:Cross-Domain LifeLong Sequential Modeling for Online Click-Through Rate Prediction
- Can't reconnect until invalid transaction is rolled back
- PSI(Population Stability Index,群体稳定性指标)
最常浏览
- 016 推荐系统 | 排序学习(LTR - Learning To Rank)
- 偏微分符号
- i.i.d(又称IID)
- 利普希茨连续条件(Lipschitz continuity)
- (error) MOVED 原因和解决方案
- TextCNN详解
- 找不到com.google.protobuf.GeneratedMessageV3的类文件
- Deployment failed: repository element was not specified in the POM inside distributionManagement
- cannot access com.google.protobuf.GeneratedMessageV3 解决方案
- CLUSTERDOWN Hash slot not served 问题原因和解决办法
×