# coding: utf8

from collections import Counter
import pandas as pd
import jieba
import tensorflow as tf
from tensorflow import keras
import numpy as np
from tensorflow.keras.layers import Lambda, Embedding, Dense, Dropout
import tensorflow.keras.backend as K

class Config:
    """Hyper-parameters shared by the preprocessing helpers and the model."""

    def __init__(self):
        # Fixed token length of every query (inputs are padded/truncated to this).
        self.QUERY_LEN = 25
        # Dimensionality of the word-embedding vectors.
        self.EMB_SIZE = 256
        # Number of output classes.
        self.CATEGORY_NUMBER = 7
        # Minimum corpus frequency for a word to enter the vocabulary.
        self.WORD_THRESHOLD = 5
        # Vocabulary size; filled in once the vocabulary has been built.
        self.VOC_SIZE = None

# Module-level singleton consulted by the helper functions below.
config = Config()

def get_word_list(line):
    """Tokenize *line* with jieba and keep only the last QUERY_LEN tokens.

    Keeping the tail means over-long queries are truncated from the front.
    """
    # list(...) instead of the identity comprehension (flake8-comprehensions C416).
    return list(jieba.cut(line))[-config.QUERY_LEN:]


def get_id_list(word_list, word_id_dict):
    """Map tokens to vocabulary ids, left-padded/truncated to exactly QUERY_LEN.

    Unknown tokens map to id 0, which doubles as the padding id.
    """
    ids = [word_id_dict.get(word, 0) for word in word_list]
    tail = ids[-config.QUERY_LEN:]
    # Left-pad with zeros so the result always has length QUERY_LEN.
    return [0] * (config.QUERY_LEN - len(tail)) + tail


def predict_query(query):
    """Return the model's class-probability output for one raw query string."""
    sample = get_id_list(get_word_list(query), word_id_dict)
    # Build an explicit (1, QUERY_LEN) batch. The original `[[sample]]` form
    # relied on Keras treating the outer list as a one-element list of inputs;
    # an ndarray batch is unambiguous.
    return mlp.predict(np.array([sample]))


def get_features_labels(df):
    """Split a raw dataframe into (features, labels).

    Column 0 holds token lists (produced by the read_csv converter); the
    remaining columns are taken as the label indicators.

    NOTE: mutates df[0] in place, replacing token lists with id lists.
    """
    df[0] = df[0].apply(lambda x: get_id_list(x, word_id_dict))
    # np.array instead of the deprecated np.mat: np.matrix is discouraged by
    # NumPy, and a plain 2-D ndarray is what Keras expects anyway.
    features = np.array(df[0].tolist())
    labels = np.array(df.drop(0, axis=1))
    return features, labels


def get_most_similar(word, word_id_dict, id_word_dict, id_embedding):
    """Return up to 10 (word, score) pairs ranked by dot-product similarity.

    Unknown words fall back to id 0; id 0 itself (the padding/unknown id)
    is excluded from the returned ranking.
    """
    query_id = word_id_dict.get(word, 0)
    # Score every vocabulary row against the query embedding in one product.
    scores = id_embedding.dot(id_embedding[query_id])
    ranked = np.argsort(scores)[::-1]
    ranked = ranked[ranked > 0]  # drop the padding id
    return [(id_word_dict[idx], scores[idx]) for idx in ranked[:10]]

# Tab-separated input files: column 0 is the raw query text; the remaining
# columns are label columns (presumably one-hot -- verify against the data).
train_filename = "input/train.csv"
test_filename = "input/test.csv"

# The converter tokenizes column 0 at read time, so train[0]/test[0] hold
# token lists rather than raw strings.
train = pd.read_csv(train_filename, converters={0:get_word_list}, header=None, sep="\t")
test = pd.read_csv(test_filename, converters={0:get_word_list}, header=None, sep="\t")

# Count word frequencies over the training split only.
word_counter = Counter()

for word_list in train[0]:
    word_counter.update(word_list)

# Keep words seen at least WORD_THRESHOLD times; ids start at 1 so that 0
# stays free for padding / unknown words.
word_freq_dict = dict(item for item in word_counter.items() if item[1] >= config.WORD_THRESHOLD)
word_id_dict = dict([word, index+1] for index,word in enumerate(word_freq_dict.keys()))
id_word_dict = dict([index, word] for word, index in word_id_dict.items())
# +1 is important: id 0 is reserved, so the embedding table needs one extra row.
config.VOC_SIZE = len(word_id_dict.keys()) + 1

train_features, train_labels = get_features_labels(train)
test_features, test_labels = get_features_labels(test)

# TF1-style session setup. NOTE(review): this runs before any Keras layers
# (and hence variables) exist, so the initializer has nothing to initialize;
# Keras initializes its own variables during fit(). Likely removable -- confirm.
init = tf.global_variables_initializer()
sess = tf.InteractiveSession()
sess.run(init)

# Averaged-embedding MLP classifier: embed each token id, mean-pool over the
# query positions, then classify through three dense layers.
mlp = keras.Sequential([
    Embedding(config.VOC_SIZE, config.EMB_SIZE, name='word2vec'),
    # Mean-pool the token embeddings into a single query vector.
    Lambda(lambda x: K.mean(x, axis=1)),
    Dense(256, activation=keras.activations.relu),
    Dropout(0.5),
    Dense(128, activation=keras.activations.relu),
    Dropout(0.5),
    Dense(64, activation=keras.activations.relu),
    Dense(config.CATEGORY_NUMBER, activation=keras.activations.softmax)
])

# Fix: the loss function lives in keras.losses, not keras.metrics (the
# metrics alias happened to resolve, but it is the wrong namespace).
mlp.compile(optimizer=tf.train.AdamOptimizer(0.001),
            loss=keras.losses.categorical_crossentropy,
            metrics=[keras.metrics.categorical_accuracy])

mlp.fit(train_features, train_labels, epochs=5, batch_size=200)

# Report results instead of silently discarding them.
print(mlp.evaluate(test_features, test_labels, batch_size=200))

# Inspect the learned embedding table: nearest neighbours of a sample word,
# plus one end-to-end prediction.
id_embedding = mlp.layers[0].get_weights()[0]
print(get_most_similar("图片", word_id_dict, id_word_dict, id_embedding))
print(predict_query("周星驰图片"))