from tensorflow.python.keras.layers import Layer
import tensorflow as tf

class bi_interaction(Layer):
    """Pairwise (second-order) interaction pooling over axis 1 of the input.

    For an input ``x`` the layer returns

        0.5 * ((sum_i x_i) ** 2 - sum_i (x_i ** 2))

    where the sums run over axis 1 and all other operations are
    element-wise. Algebraically this equals the sum over all pairs
    ``i < j`` of the element-wise product ``x_i * x_j``, computed in linear
    rather than quadratic time in the size of axis 1.

    The layer has no trainable weights.

    NOTE(review): the original author's comment gives the output shape as
    (batch, 1, embed_size), which suggests inputs of shape
    (batch, num_fields, embed_size) -- confirm against the caller.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments to the base ``Layer`` constructor."""
        super(bi_interaction, self).__init__(**kwargs)

    def build(self, input_shape):
        """Defer to the base ``Layer.build``; no weights are created here."""
        super(bi_interaction, self).build(input_shape)

    def call(self, inputs):
        """Return the pooled pairwise interactions of ``inputs`` along axis 1.

        Axis 1 is reduced but kept with size 1, so the result has the same
        rank as ``inputs``.
        """
        concat_embed_value = inputs
        # `keepdims` is the supported keyword; the old `keep_dims` alias was
        # deprecated in TF 1.x and removed in TF 2.x, where it raises TypeError.
        square_of_sum = tf.square(
            tf.reduce_sum(concat_embed_value, axis=1, keepdims=True))
        sum_of_square = tf.reduce_sum(
            tf.square(concat_embed_value), axis=1, keepdims=True)
        # (a + b + ...)^2 - (a^2 + b^2 + ...) == 2 * (sum of pairwise
        # products), hence the factor of 0.5.
        cross = 0.5 * (square_of_sum - sum_of_square)
        return cross