import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Lambda, Embedding, Dense

# Build a padding mask from the tokenized inputs: 1.0 for real tokens, 0.0 for padding (id 0).
train_mask = (train_features != 0).astype("float32")


def mlp_mean_pooling_with_mask(word_emb_size=100, output_size=7):
    # Token ids and the matching padding mask, both of sequence length 15.
    input_ids = Input(shape=(15,), dtype=tf.int32, name="input_ids")
    input_mask = Input(shape=(15,), dtype=tf.float32, name="input_mask")
    # Expand the (batch, 15) mask to (batch, 15, word_emb_size): repeat it along a
    # new axis, then transpose so it lines up with the embedding tensor.
    input_mask_repeat = Lambda(lambda x: K.repeat(x, word_emb_size), name="mask_repeat")(input_mask)
    input_mask_transpose = Lambda(lambda x: tf.transpose(x, [0, 2, 1]), name="mask_transpose")(input_mask_repeat)
    # Trainable word embeddings; index 0 is reserved for padding.
    input_word_emb = Embedding(len(tokenizer.word_index) + 1, word_emb_size, name="class_emb")(input_ids)
    # Zero out the embedding vectors at padding positions.
    mask_word_emb = Lambda(lambda x: x[0] * x[1], name="mask_softmax")([input_word_emb, input_mask_transpose])
    # Mean pooling over real tokens only: sum the masked embeddings and divide by the
    # number of non-padding tokens (the epsilon guards against all-padding sequences).
    mean_pooling = Lambda(lambda x: tf.divide(K.sum(x[0], axis=1), K.sum(x[1], axis=1) + 1e-6), name="mean_pooling")([mask_word_emb, input_mask_transpose])
    hidden1 = Dense(100, activation=tf.nn.relu)(mean_pooling)
    output = Dense(output_size, activation=tf.nn.softmax)(hidden1)
    model = tf.keras.Model(inputs=[input_ids, input_mask], outputs=[output])
    return model
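

# A minimal training sketch, assuming the `train_features`/`train_mask` arrays built
# above and a hypothetical one-hot label array `train_labels` of shape (n_samples, 7);
# the optimizer, loss, batch size, and epoch count are illustrative choices, not
# taken from the original.
model = mlp_mean_pooling_with_mask(word_emb_size=100, output_size=7)
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
model.fit(
    {"input_ids": train_features, "input_mask": train_mask},
    train_labels,
    batch_size=32,
    epochs=5,
    validation_split=0.1,
)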