parse_example(
serialized,
features,
name=None,
example_names=None
)
把 Example 原型解析成张量字典。
解析 serialized 中给定的一系列序列化 Example 原型。serialized 被视为一个批次,其中包含 batch_size 个独立的 Example 原型条目。
# coding: utf8
import tensorflow as tf
# Per-feature-group id counts -- presumably the id-vocabulary size of each
# sparse feature group ('item', 'title', ...); TODO confirm against the
# embedding-table construction elsewhere in this file.
group_id_counts_dict ={'item': 4156046, 'title': 179083, 'video': 129024, 'user': 589290, 'face': 2092}
# Per-feature-group embedding dimension, keyed the same way as the dict above.
group_emb_size_dict ={'item': 97, 'title': 38, 'video': 34, 'user': 54, 'face': 10}
# Interactive session so tensors can be evaluated directly at module level.
sess = tf.InteractiveSession()
# NOTE(review): this runs before any variables are created in this chunk, so
# the initializers are no-ops here -- confirm whether later code re-runs them
# after variable creation.
init = (tf.global_variables_initializer(), tf.local_variables_initializer())
sess.run(init)
def parse_fn(example):
    """Parse a batch of serialized Example protos into a feature dict.

    Args:
        example: A batch of serialized `tf.train.Example` protos (as produced
            by `Dataset.batch` over a TFRecord dataset).

    Returns:
        A dict mapping feature names to tensors: dense `finish`, `uid`,
        `itemid`, `instanceId`, plus one `SparseTensor` of int64 ids per
        feature group in `group_id_counts_dict`.
    """
    spec = {
        'finish': tf.FixedLenFeature([1], tf.float32),
        'uid': tf.FixedLenFeature([1], tf.string),
        'itemid': tf.FixedLenFeature([1], tf.string),
        'instanceId': tf.FixedLenFeature([1], tf.string),
    }
    # Every feature group is a variable-length list of int64 ids.
    spec.update({group: tf.VarLenFeature(tf.int64) for group in group_id_counts_dict})
    return tf.parse_example(example, features=spec)
def cm_input_fn(data_path, batch_size, worker_num=None, worker_idx=None, is_train=True):
    """Build a tf.data pipeline over TFRecord files and return next-batch tensors.

    Args:
        data_path: Glob pattern of TFRecord files.
        batch_size: Number of examples per batch.
        worker_num: Total worker count for file sharding (training only).
        worker_idx: This worker's shard index (training only).
        is_train: Training mode (shard + 5 epochs) vs. eval mode
            (no sharding, effectively unbounded repeat).

    Returns:
        The parsed-feature dict tensors for the next batch.
    """
    file_ds = tf.data.Dataset.list_files(data_path)
    # Shard at file granularity so each worker reads a disjoint file subset.
    if is_train and worker_num is not None and worker_idx is not None:
        file_ds = file_ds.shard(worker_num, worker_idx)
    records = file_ds.apply(
        tf.contrib.data.parallel_interleave(tf.data.TFRecordDataset, cycle_length=16))
    # Train: 5 epochs. Eval: repeat a very large number of times so the
    # pipeline does not exhaust mid-evaluation.
    records = records.repeat(5) if is_train else records.repeat(500000)
    # Batch first, then parse: parse_fn uses tf.parse_example on whole batches.
    batched = records.batch(batch_size).map(parse_fn, num_parallel_calls=16)
    batched = batched.prefetch(buffer_size=batch_size * 10)
    return batched.make_one_shot_iterator().get_next()
# Legacy queue-runner scaffolding.
# NOTE(review): the input pipelines below use tf.data, which does not rely on
# queue runners, so this likely starts no threads -- confirm whether other
# parts of this file still use queue-based readers before removing it.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord, sess=sess)
# Next-batch tensors for the training split (batch size 100).
train_dir = "train_part"
train = cm_input_fn(train_dir, 100)
# Next-batch tensors for the prediction split; cm_input_fn defaults to
# is_train=True here, so the eval data is also sharded/repeated as training
# would be only if worker args were passed (they are not).
predict_dir = "predict_part"
predict = cm_input_fn(predict_dir, 100)