import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer, Input, Embedding, Dense, Dropout
from tensorflow.keras.models import Model


class MeanPool(Layer):
    """Mean-pools over the time axis, averaging only the unmasked timesteps."""

    def __init__(self, **kwargs):
        self.supports_masking = True
        super(MeanPool, self).__init__(**kwargs)

    def compute_mask(self, inputs, input_mask=None):
        # Do not pass the mask to the next layers: the time axis is gone
        # after pooling, so the mask no longer applies.
        return None

    def call(self, x, mask=None):
        if mask is not None:
            # mask: (batch, time) -> float so it can multiply the embeddings
            mask = K.cast(mask, K.floatx())
            # (batch, time) -> (batch, x_dim, time)
            mask = K.repeat(mask, K.int_shape(x)[-1])
            # (batch, x_dim, time) -> (batch, time, x_dim)
            mask = K.permute_dimensions(mask, (0, 2, 1))
            x = x * mask
            # Sum over time and divide by the number of unmasked steps;
            # K.epsilon() guards against division by zero on all-padding rows.
            return K.sum(x, axis=1) / (K.sum(mask, axis=1) + K.epsilon())
        return K.mean(x, axis=1)

    def compute_output_shape(self, input_shape):
        # Remove the temporal dimension: (batch, time, dim) -> (batch, dim)
        return (input_shape[0], input_shape[2])
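
# A quick sanity check of MeanPool on toy values (a hedged sketch: the probe
# model, `toy_emb`, and `toy_mask` are illustrative and not part of the
# original pipeline). The two valid timesteps [1, 3] and [3, 5] average to
# [2, 4]; the padded third step contributes nothing.
import numpy as np

probe_emb = Input(shape=(3, 2), dtype=tf.float32)
probe_mask = Input(shape=(3,), dtype=tf.int32)
probe = Model(inputs=[probe_mask, probe_emb],
              outputs=[MeanPool()(probe_emb, mask=probe_mask)])

toy_emb = np.array([[[1., 3.], [3., 5.], [9., 9.]]], dtype="float32")
toy_mask = np.array([[1, 1, 0]], dtype="int32")
print(probe.predict([toy_mask, toy_emb]))  # -> [[2. 4.]]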

# shape: (BATCH_SIZE, QUERY_LEN)
input_word_ids = Input(shape=(config.QUERY_LEN,), dtype=tf.int32)
# shape: (BATCH_SIZE, QUERY_LEN, EMB_SIZE)
input_word_emb = Embedding(config.VOC_SIZE, config.EMB_SIZE, name="word2vec")(input_word_ids)
# shape: (BATCH_SIZE, QUERY_LEN); 1 marks real tokens, 0 marks padding
input_mask = Input(shape=(config.QUERY_LEN,), dtype=tf.int32)
# shape: (BATCH_SIZE, EMB_SIZE) after masked mean-pooling
avg_pooling = MeanPool()(input_word_emb, mask=input_mask)
hidden1 = Dense(768, activation=keras.activations.relu)(avg_pooling)
drop1 = Dropout(0.5)(hidden1)
hidden2 = Dense(512, activation=keras.activations.relu)(drop1)
drop2 = Dropout(0.5)(hidden2)
hidden3 = Dense(256, activation=keras.activations.relu)(drop2)
output = Dense(config.CATEGORY_NUMBER, activation=keras.activations.softmax)(hidden3)

mlp = Model(inputs=[input_mask, input_word_ids], outputs=[output])

mlp.compile(optimizer=tf.train.AdamOptimizer(0.01),
            loss=keras.losses.categorical_crossentropy,
            metrics=[keras.metrics.categorical_accuracy])
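
# One way the mask arrays fed to fit() below can be built (an assumption:
# the original code does not show this step, and padding id 0 is presumed):
def make_mask(ids, pad_id=0):
    # 1 where the token is real, 0 where it is padding.
    return (np.asarray(ids) != pad_id).astype("int32")

# e.g. train_mask = make_mask(train_features); test_mask = make_mask(test_features)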

mlp.fit([train_mask, train_features], train_labels,
        epochs=20, batch_size=200,
        validation_data=([test_mask, test_features], test_labels))
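
# Inference uses the same input ordering as the model definition (mask first,
# then word ids); argmax over the softmax output recovers the category index.
preds = mlp.predict([test_mask, test_features])
print(preds.argmax(axis=-1)[:10])  # first ten predicted category indices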