目录
模型训练
# coding: utf8
import tensorflow as tf
import pandas as pd
import numpy as np
import jieba
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.layers import Input, Embedding, Lambda, Layer
import tensorflow.keras.backend as K
# Fixed sizes of the model below. MAX_LEN must match the padding length used
# for the input ids/masks; NUM_CLASSES is both the embedding width and the
# number of output classes (the softmax runs over the embedding axis).
MAX_LEN = 15
NUM_CLASSES = 7  # presumably equals len(labels) — TODO confirm

train = pd.read_csv("input/train.csv", header=None, sep="\t")

# In TF 1.x an InteractiveSession registers itself as the default session and
# tf.keras.backend.get_session() returns the default session, so the model is
# built and trained inside `sess`; the visualisation section reuses it to
# evaluate intermediate layers. No global_variables_initializer() is run here:
# created before any model variable exists it would initialise nothing, and
# Keras initialises its own variables when training.
sess = tf.InteractiveSession()

# NOTE(review): `tokenizer`, `train_features`, `train_mask` and `labels` come
# from preprocessing code that is not part of this file.

# Token ids padded to MAX_LEN (id 0 = padding).
input_ids = Input(shape=(MAX_LEN,), dtype=tf.int32, name="input_ids")
# Per-token mask; assumed 1.0 for real tokens and 0.0 for padding (this is how
# the test mask is built from `ids != 0` in the visualisation section).
input_mask = Input(shape=(MAX_LEN,), dtype=tf.float32, name="input_mask")

# Each token maps directly to NUM_CLASSES logits. +1 because the Keras
# Tokenizer numbers words from 1 and leaves index 0 for padding.
input_word_emb = Embedding(len(tokenizer.word_index) + 1, NUM_CLASSES,
                           name="class_emb")(input_ids)

# Broadcast the mask to the logits' shape:
# (batch, MAX_LEN) -> (batch, NUM_CLASSES, MAX_LEN) -> (batch, MAX_LEN, NUM_CLASSES)
input_mask_repeat = Lambda(lambda x: K.repeat(x, NUM_CLASSES),
                           name="mask_repeat")(input_mask)
input_mask_transpose = Lambda(lambda x: tf.transpose(x, [0, 2, 1]),
                              name="mask_transpose")(input_mask_repeat)

# Zero the logits at padding positions, then softmax over the class axis to
# obtain one class distribution per token.
multiply_result = Lambda(lambda x: x[0] * x[1],
                         name="multiply_result")([input_word_emb, input_mask_transpose])
multiply_softmax = Lambda(lambda x: tf.nn.softmax(x),
                          name="multiply_softmax")(multiply_result)
# A zeroed logit row softmaxes to a uniform distribution, so padding rows
# have to be masked out a second time.
mask_softmax = Lambda(lambda x: x[0] * x[1],
                      name="mask_softmax")([multiply_softmax, input_mask_transpose])

# Mean of the per-token distributions over real tokens only: sum over the
# token axis divided by the number of real tokens. The 1e-6 guards against
# 0/0 for an all-padding row. Result shape: (batch, NUM_CLASSES).
output = Lambda(
    lambda x: tf.divide(K.sum(x[0], axis=1), K.sum(x[1], axis=1) + 0.000001),
    name="softmax_mean",
)([mask_softmax, input_mask_transpose])

model = tf.keras.Model(inputs=[input_ids, input_mask], outputs=[output])
model.compile(
    optimizer=tf.train.AdamOptimizer(0.01),
    # Same function the original referenced via tf.keras.metrics; the losses
    # module is its canonical home when used as a training objective.
    loss=tf.keras.losses.categorical_crossentropy,
    metrics=[tf.keras.metrics.categorical_accuracy],
)
model.fit([train_features, train_mask], train[labels].values,
          batch_size=200, epochs=1)
可视化
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib.font_manager import FontProperties
def text2words(text, cut=None):
    """Turn raw text into the space-separated token string fed to the tokenizer.

    The text is segmented into words (``jieba.cut`` by default), every word is
    lowercased, and the sequence is wrapped in ``#BEGIN`` / ``#END`` markers.

    Args:
        text: Raw input string.
        cut: Optional segmentation callable ``str -> iterable of str``.
            Defaults to ``jieba.cut``; pass another callable to use a
            different tokenizer (or to run without jieba).

    Returns:
        A string of the form ``"#BEGIN w1 w2 ... #END"``.
    """
    # NOTE(review): the Keras Tokenizer's default `filters` strip '#', so
    # unless the tokenizer was built with custom filters these markers end up
    # as "begin"/"end" — confirm against the tokenizer configuration.
    if cut is None:
        cut = jieba.cut
    words = " ".join(word.lower() for word in cut(text))
    return "#BEGIN " + words + " #END"
# `%matplotlib inline` is an IPython/Jupyter magic, not Python syntax; it is
# kept as a comment so this file remains valid Python. Run it in a notebook
# cell if inline figures are wanted.
# %matplotlib inline

# Sequence length the model was built with, read from the graph so it cannot
# drift from the Input definition.
max_len = K.int_shape(input_ids)[1]

test_words = text2words("林书豪成名战完整版")
test_ids = pad_sequences(tokenizer.texts_to_sequences([test_words]), maxlen=max_len)
test_mask = (test_ids != 0).astype("float32")

# Placeholder path: must point at a font file with CJK glyphs so the Chinese
# tick labels can be drawn.
myfont = FontProperties(fname='your_chinese_font.ttc')
plt.rcParams['axes.unicode_minus'] = False  # so minus signs display correctly
cmap = sns.color_palette("Blues")
f, ax = plt.subplots(figsize=(8, 8), nrows=1)

# Evaluate both inspected layers in a single pass through the graph. Layers
# are looked up by name instead of by position in `model.layers`, whose order
# is an implementation detail of how Keras sorts the functional graph.
feed = {input_ids: test_ids, input_mask: test_mask}
a, b = sess.run(
    [model.get_layer("mask_softmax").output,
     model.get_layer("softmax_mean").output],
    feed_dict=feed,
)

# Rows: the per-token class distributions of the single test query (padding
# rows are zeroed by the mask), followed by one row holding the averaged
# prediction. Every tick is forced on so the custom labels below line up
# one-to-one with columns/rows.
sns.heatmap(np.r_[a[0], b], annot=True, ax=ax, cmap=cmap,
            xticklabels=True, yticklabels=True)
# Class names: Q&A, news, video, image, document library, map, other.
ax.set_xticklabels("问答 新闻 视频 图片 文库 地图 其他".split(" "), fontproperties=myfont)

# Left-pad the word list to max_len to mirror pad_sequences' default "pre"
# padding/truncation.
# NOTE(review): assumes each space-separated word became exactly one id;
# words unknown to the tokenizer are dropped by texts_to_sequences and would
# shift these labels — confirm.
token_labels = (["PADDING"] * max_len + test_words.split(" "))[-max_len:]
ax.set_yticklabels(token_labels + [test_words], rotation=0, fontproperties=myfont)