目录
可视化代码
# Visualization snippet: draw the layer graph of an already-built Keras
# `model` into model.png, with shape information shown on each layer.
# Imported from the public `keras.utils` path, matching the import used
# later in this file.
from keras.utils import plot_model
plot_model(model, to_file="model.png", show_shapes=True)
# coding: utf8
from keras.layers import Input,Embedding,Lambda
from keras.models import Model
import keras.backend as K
# Model definition. The summed embedding of 10 context word ids (written as
# `5 * 2`, presumably 5 words on each side of the target) is scored against
# 16 candidate words: the true target plus 15 randomly drawn "negative" ids.
# Vocabulary size is 20000 and the embedding width is 128.

# Context branch: (batch, 10) ids -> (batch, 10, 128) -> summed to (batch, 128).
context_words_input = Input(shape=(5 * 2,), dtype='int32')
context_embedding_input = Embedding(20000, 128, name='word2vec')(context_words_input)
context_sum_embedding_input = Lambda(lambda x: K.sum(x, axis=1))(context_embedding_input)

# Candidate branch: the target id followed by 15 ids drawn uniformly between
# 0 and 20000, concatenated to (batch, 16). The target is always column 0.
# The target input is only used by the sampling lambda for its batch size.
# NOTE(review): negatives are not filtered, so one may coincide with the target.
target_word_input = Input(shape=(1,), dtype='int32')
negatives_word_input = Lambda(
    lambda x: K.random_uniform((K.shape(x)[0], 15), 0, 20000, 'int32')
)(target_word_input)
candidate_words_input = Lambda(lambda x: K.concatenate(x))([target_word_input, negatives_word_input])

# Per-candidate output weights (batch, 16, 128) and biases (batch, 16, 1),
# looked up from their own embedding tables named 'W' and 'b'.
candidate_embedding_weights = Embedding(20000, 128, name='W')(candidate_words_input)
candidate_embedding_biases = Embedding(20000, 1, name='b')(candidate_words_input)

# Score each candidate as dot(W[candidate], context_sum) + b[candidate]:
# batch_dot of (batch, 16, 128) with (batch, 128, 1) gives (batch, 16, 1);
# the trailing axis is dropped and a softmax is taken over the 16 candidates.
softmax = Lambda(lambda x: K.softmax((K.batch_dot(x[0], K.expand_dims(x[1], 2)) + x[2])[:, :, 0])
)([candidate_embedding_weights, context_sum_embedding_input, candidate_embedding_biases])

# Use the Keras 2 keyword `inputs` (paired with `outputs`); the previous
# `input=` was the legacy Keras 1 spelling.
model = Model(inputs=[context_words_input, target_word_input], outputs=softmax)
from keras.utils import plot_model
# Write a diagram of `model` to model.png, with shape information shown.
plot_model(
    model,
    to_file='model.png',
    show_shapes=True,
)