parse_example(
    serialized,
    features,
    name=None,
    example_names=None
)

把 Example 原型解析成张量字典。

解析 serialized 中给定的一系列序列化 Example 原型。serialized 可以看作一个包含 batch_size 个条目的批次，其中每个条目都是一条独立的序列化 Example 原型。

# coding: utf8

import tensorflow as tf

# Vocabulary size (number of distinct ids) per sparse feature group.
group_id_counts_dict ={'item': 4156046, 'title': 179083, 'video': 129024, 'user': 589290, 'face': 2092}
# Embedding dimension per feature group (presumably consumed by model code
# outside this file — confirm against the model definition).
group_emb_size_dict ={'item': 97, 'title': 38, 'video': 34, 'user': 54, 'face': 10}
# TF 1.x interactive session so tensors can be evaluated with .eval()/sess.run
# directly at module level.
sess = tf.InteractiveSession()
# NOTE(review): initializers are run before any variables are created in this
# file, so this is a no-op for the code visible here — confirm whether
# variables are expected to exist at this point.
init = (tf.global_variables_initializer(), tf.local_variables_initializer())
sess.run(init)


def parse_fn(example):
    """Parse a batch of serialized Example protos into a feature dict.

    Dense features: 'finish' (a single float32, the label) plus the string
    ids 'uid', 'itemid' and 'instanceId'. Every feature group listed in the
    module-level group_id_counts_dict is parsed as a variable-length int64
    sparse feature.

    Args:
        example: a string tensor of serialized tf.Example protos (a batch).

    Returns:
        A dict mapping feature name to dense/sparse tensors, as produced by
        tf.parse_example.
    """
    spec = {
        'finish': tf.FixedLenFeature([1], tf.float32),
        'uid': tf.FixedLenFeature([1], tf.string),
        'itemid': tf.FixedLenFeature([1], tf.string),
        'instanceId': tf.FixedLenFeature([1], tf.string),
    }
    # Each id-list group is variable-length, hence a sparse int64 feature.
    for group in group_id_counts_dict:
        spec[group] = tf.VarLenFeature(tf.int64)
    return tf.parse_example(example, features=spec)


def cm_input_fn(data_path, batch_size, worker_num=None, worker_idx=None, is_train=True):
    """Build a TFRecord tf.data pipeline and return the next-batch tensors.

    Args:
        data_path: file glob pattern of TFRecord files.
        batch_size: number of serialized examples per batch.
        worker_num: total number of workers; used for file sharding.
        worker_idx: this worker's shard index.
        is_train: if True (and both worker args are given) shard the file
            list across workers, and repeat the data 5 times; if False,
            repeat 500000 times (effectively unbounded for evaluation).

    Returns:
        The get_next() tensors of a one-shot iterator: a dict of parsed
        feature tensors (see parse_fn).
    """
    file_ds = tf.data.Dataset.list_files(data_path)
    shard_across_workers = is_train and worker_num is not None and worker_idx is not None
    if shard_across_workers:
        file_ds = file_ds.shard(worker_num, worker_idx)
    # Read up to 16 files concurrently, interleaving their records.
    records = file_ds.apply(
        tf.contrib.data.parallel_interleave(tf.data.TFRecordDataset, cycle_length=16))
    records = records.repeat(5 if is_train else 500000)
    # Batch first: parse_fn calls tf.parse_example, which expects a batch.
    batched = records.batch(batch_size).map(parse_fn, num_parallel_calls=16)
    # NOTE(review): after batching, prefetch counts *batches*, so this
    # buffers batch_size*10 batches — confirm that much buffering is intended.
    batched = batched.prefetch(buffer_size=batch_size*10)
    return batched.make_one_shot_iterator().get_next()


# Start legacy queue runners (TF 1.x queue-based input machinery).
# NOTE(review): the tf.data pipelines built below do not use queue runners;
# this is likely a leftover — confirm no queue-based ops exist elsewhere
# before removing.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord, sess=sess)

# Next-batch tensor dicts for the training and prediction record files,
# each with a batch size of 100.
train_dir = "train_part"
train = cm_input_fn(train_dir, 100)

predict_dir = "predict_part"
predict = cm_input_fn(predict_dir, 100)