
avg_pooling

A CBOW-style word prediction model: the context word embeddings are masked and average-pooled into a single query vector, which is then scored against the true target word plus NEG_NUM randomly sampled negatives through a dot-product softmax.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Dense, Embedding, Input, Lambda


class Model:
    def __init__(self, config):
        voc_size = config.VOC_SIZE
        emb_size = config.EMB_SIZE
        query_len = config.QUERY_LEN
        neg_num = config.NEG_NUM
        learning_rate = config.LEARNING_RATE

        self.WordEmb = Embedding(voc_size, emb_size, name="que_emb")

        # None * LEN
        input_word_ids = Input(shape=(query_len,), dtype=tf.int32, name="input_word_ids")
        input_mask = Input(shape=(query_len,), dtype=tf.int32, name="input_word_mask")
        output_word_ids = Input(shape=(1,), dtype=tf.int32, name="output")

        # input_word
        input_word_emb = self.WordEmb(input_word_ids)

        # expand the padding mask to the embedding dimension:
        # (None, LEN) -> (None, EMB_SIZE, LEN) via K.repeat -> (None, LEN, EMB_SIZE) via transpose
        input_mask_cast = Lambda(lambda x: K.cast(x, K.floatx()), name="mask_cast")(input_mask)
        input_mask_repeat = Lambda(lambda x: K.repeat(x, emb_size), name="mask_repeat")(input_mask_cast)
        input_mask_final = Lambda(lambda x: tf.transpose(x, [0, 2, 1]), name="mask_final")(input_mask_repeat)

        # zero out the embeddings at padded positions: None * LEN * EMB_SIZE
        mask_input = Lambda(lambda x: x[0] * x[1], name="mask_input")([input_word_emb, input_mask_final])
        # masked average pooling: sum of kept embeddings / number of real tokens
        avg_pooling = Lambda(
            lambda x: tf.divide(K.sum(x[0], axis=1), K.sum(x[1], axis=1) + 1e-6),
            name="avg_pooling")([mask_input, input_mask_final])

        # que_hidden_1 = Dense(config.EMB_SIZE * 2, activation=tf.nn.relu, name="relu1")(avg_pooling)
        # que_hidden_2 = Dense(config.EMB_SIZE * 2, activation=tf.nn.relu, name="relu2")(que_hidden_1)
        # que_hidden_3 = Dense(config.EMB_SIZE, activation=tf.nn.sigmoid, name="relu3")(que_hidden_2)

        # que_repeat = Lambda(lambda x: K.repeat(x, config.NEG_NUM + 1), name="que_repeat")(avg_pooling)

        # negative sampling: draw NEG_NUM random word ids per example
        negatives_word_input = Lambda(
            lambda x: K.random_uniform((K.shape(x)[0], neg_num), 0, voc_size, 'int32'))(output_word_ids)
        # candidates = [true word, negatives], so the correct class index is always 0
        candidate_words_input = Lambda(lambda x: K.concatenate(x))([output_word_ids, negatives_word_input])

        # score each candidate by its dot product with the pooled query vector:
        # (None, NEG_NUM + 1, EMB) x (None, EMB, 1) -> (None, NEG_NUM + 1), then softmax
        ans_word_emb = self.WordEmb(candidate_words_input)
        softmax = Lambda(lambda x: K.softmax(K.batch_dot(x[0], K.expand_dims(x[1], 2))[:, :, 0])
                         )([ans_word_emb, avg_pooling])

        # que_norm = Lambda(lambda x: tf.sqrt(tf.reduce_sum(tf.square(x), axis=2)))(que_repeat)
        # ans_norm = Lambda(lambda x: tf.sqrt(tf.reduce_sum(tf.square(x), axis=2)))(ans_word_emb)

        # x1_x2 = Lambda(lambda x: tf.reduce_sum(x[0] * x[1], axis=2))([que_repeat, ans_word_emb])
        # cos_sim = Lambda(lambda x: tf.divide(x[0], x[1] * x[2] + 0.00000001))([x1_x2, que_norm, ans_norm])
        # prob = Lambda(lambda x: tf.nn.softmax(x))(cos_sim)
        # output = Lambda(lambda x: tf.slice(x, [0, 0], [-1, 1]))(cos_sim)
        # output = cos_sim

        self.model = tf.keras.Model(inputs=[input_mask, input_word_ids, output_word_ids], outputs=[softmax])

        self.model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate),
                           loss=keras.losses.sparse_categorical_crossentropy,
                           metrics=[keras.metrics.sparse_categorical_accuracy])

    def train(self, question_mask, question_ids, answer_ids, new_labels, epoch=1, batch_size=200):
        self.model.fit([question_mask, question_ids, answer_ids], new_labels, epochs=epoch, batch_size=batch_size)
        # self.model.fit([question_mask, question_ids, answer_ids], new_labels, epochs=1, batch_size=100)
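
A minimal usage sketch, assuming a plain Config object exposing the attributes read in __init__ (VOC_SIZE, EMB_SIZE, QUERY_LEN, NEG_NUM, LEARNING_RATE). The arrays are random placeholders, not real data; every label is 0 because the true word sits at index 0 of the candidate list.

import numpy as np


class Config:
    VOC_SIZE = 10000
    EMB_SIZE = 128
    QUERY_LEN = 20
    NEG_NUM = 4
    LEARNING_RATE = 0.001


config = Config()
model = Model(config)

batch = 32
question_ids = np.random.randint(0, config.VOC_SIZE, (batch, config.QUERY_LEN), dtype="int32")
question_mask = np.ones((batch, config.QUERY_LEN), dtype="int32")  # all positions are real tokens here
answer_ids = np.random.randint(0, config.VOC_SIZE, (batch, 1), dtype="int32")
new_labels = np.zeros((batch, 1), dtype="int32")  # true word is candidate 0, so the label is always 0

model.train(question_mask, question_ids, answer_ids, new_labels, epoch=1, batch_size=16)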