目录

可视化代码

# Render the layer graph of `model` to model.png, annotating each layer with
# its tensor shapes (show_shapes=True).
# `model` must already be a built Keras model when this call runs.
from keras.utils import plot_model

plot_model(model, to_file="model.png", show_shapes=True)
# coding: utf8

from keras.layers import Input,Embedding,Lambda
from keras.models import Model
import keras.backend as K

# CBOW-style word2vec model scored against randomly sampled negative words.
# NOTE(review): the literals below read as window = 5 words per side,
# vocabulary = 20000 ids, embedding size = 128 and 15 negatives per example;
# confirm against the data pipeline that feeds this model.

# Context word ids: 5 * 2 = 10 int32 ids per example -> (batch, 10).
context_words_input = Input(shape=(5 * 2,), dtype='int32')

# Look up a 128-dim vector for every context id -> (batch, 10, 128).
context_embedding_input = Embedding(20000, 128, name='word2vec')(context_words_input)

# Sum the vectors over the context-position axis -> (batch, 128).
context_sum_embedding_input = Lambda(lambda x: K.sum(x, axis=1))(context_embedding_input)

# Id of the word to predict, one per example -> (batch, 1).
target_word_input = Input(shape=(1,), dtype='int32')

# Draw 15 random int32 ids in [0, 20000) per example -> (batch, 15).
# The target tensor is only used to read the runtime batch size. Draws are not
# checked against the target, so a "negative" can coincide with the true word.
negatives_word_input = Lambda(
    lambda x: K.random_uniform((K.shape(x)[0], 15), 0, 20000, 'int32')
)(target_word_input)

# Candidates = [target, negative_1, ..., negative_15] -> (batch, 16).
# The true word therefore always sits at index 0.
candidate_words_input = Lambda(lambda x: K.concatenate(x))(
    [target_word_input, negatives_word_input]
)

# Per-candidate output weights (batch, 16, 128) and biases (batch, 16, 1),
# stored as embedding tables so only the sampled rows are looked up.
candidate_embedding_weights = Embedding(20000, 128, name='W')(candidate_words_input)
candidate_embedding_biases = Embedding(20000, 1, name='b')(candidate_words_input)

# Score each candidate: (batch, 16, 128) . (batch, 128, 1) -> (batch, 16, 1),
# add the bias, drop the trailing axis, then softmax over the 16 candidates.
softmax = Lambda(
    lambda x: K.softmax(
        (K.batch_dot(x[0], K.expand_dims(x[1], 2)) + x[2])[:, :, 0]
    )
)([candidate_embedding_weights, context_sum_embedding_input, candidate_embedding_biases])

# Keras 2 spells the functional-API keywords `inputs=` / `outputs=`.
model = Model(inputs=[context_words_input, target_word_input], outputs=softmax)

from keras.utils import plot_model

# Write a diagram of the model built above to model.png; show_shapes=True
# annotates every layer with its tensor shapes.
plot_model(
    model,
    to_file='model.png',
    show_shapes=True,
)