
Parsing libsvm data

Each line of a libsvm file holds the label followed by space-separated index:value feature pairs, for example:


#1 1:0.5 2:0.03519 3:1 4:0.02567 7:0.03708 8:0.01705 9:0.06296 10:0.18185 11:0.02497 12:1 14:0.02565 15:0.03267 17:0.0247 18:0.03158 20:1 22:1 23:0.13169 24:0.02933 27:0.18159 31:0.0177 34:0.02888 38:1 51:1 63:1 132:1 164:1 236:1
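
For reference, a minimal pure-Python sketch of the same parse, assuming the leading token is the label (the function name parse_libsvm_line is only illustrative):

def parse_libsvm_line(line):
    # Split on whitespace: first token is the label, the rest are "index:value" pairs.
    tokens = line.strip().split()
    label = float(tokens[0])
    feat_ids, feat_vals = [], []
    for pair in tokens[1:]:
        idx, val = pair.split(':')
        feat_ids.append(int(idx))
        feat_vals.append(float(val))
    return {"feat_ids": feat_ids, "feat_vals": feat_vals}, label

The TensorFlow version below does the same split inside the tf.data pipeline, so parsing runs as part of the graph.
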
import tensorflow as tf

def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
    print('Parsing', filenames)
    def decode_libsvm(line):
        # Split the line on spaces; the first token is the label.
        columns = tf.string_split([line], ' ')
        labels = tf.string_to_number(columns.values[0], out_type=tf.float32)
        # Split the remaining "index:value" tokens into feature ids and values.
        splits = tf.string_split(columns.values[1:], ':')
        id_vals = tf.reshape(splits.values, splits.dense_shape)
        feat_ids, feat_vals = tf.split(id_vals, num_or_size_splits=2, axis=1)
        feat_ids = tf.string_to_number(feat_ids, out_type=tf.int32)
        feat_vals = tf.string_to_number(feat_vals, out_type=tf.float32)
        return {"feat_ids": feat_ids, "feat_vals": feat_vals}, labels

    # Extract lines from the input files using the Dataset API; a single filename
    # or a list of filenames can be passed.
    dataset = (tf.data.TextLineDataset(filenames)
               .map(decode_libsvm, num_parallel_calls=10)  # multi-threaded pre-processing
               .prefetch(500000))                          # then prefetch

    # Randomizes input using a window of 256 elements (read into memory)
    if perform_shuffle:
        dataset = dataset.shuffle(buffer_size=256)

    # We call repeat after shuffling, rather than before, to prevent separate
    # epochs from blending together.
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size) # Batch size to use

    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels
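
A minimal usage sketch, assuming TensorFlow 1.x graph mode; the file name train.libsvm is only illustrative:

# Build the input pipeline and pull batches in a session (TF 1.x style).
features, labels = input_fn(['train.libsvm'], batch_size=32, num_epochs=1,
                            perform_shuffle=True)
with tf.Session() as sess:
    try:
        while True:
            feat_ids, feat_vals, y = sess.run(
                [features['feat_ids'], features['feat_vals'], labels])
            print(feat_ids.shape, feat_vals.shape, y.shape)
    except tf.errors.OutOfRangeError:
        pass  # the one-shot iterator is exhausted after num_epochs passes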

Input directories

The next function reads TFRecord files from several input directories and shards the file list across workers:


import random

def cm_input_fn(multiple_input_path, batch_size,
                worker_num=None, worker_idx=None, is_train=True):
    # Collect all files under each input directory.
    total_files = []
    for input_path in multiple_input_path:
        tmp = tf.gfile.ListDirectory(input_path)
        files = sorted(map(lambda x: input_path + '/' + x, tmp))
        total_files += files
    random.shuffle(total_files)
    dataset = tf.data.Dataset.from_tensor_slices(total_files)
    if is_train:
        # Each training worker reads only its own shard of the file list.
        dataset = dataset.shard(worker_num, worker_idx)

    # Read several TFRecord files in parallel and interleave their records.
    dataset = dataset.apply(tf.contrib.data.parallel_interleave(
        tf.data.TFRecordDataset, cycle_length=8))
    if is_train:
        dataset = dataset.repeat(1)
    else:
        dataset = dataset.repeat(500000)
    # parse_fn (defined elsewhere) parses a whole batch of serialized Examples.
    dataset = dataset.batch(batch_size).map(parse_fn, num_parallel_calls=8)
    dataset = dataset.prefetch(buffer_size=batch_size)
    iterator = dataset.make_one_shot_iterator()
    return iterator.get_next()
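
parse_fn is not defined above. A minimal sketch of what it might look like, using tf.parse_example for batched parsing; the feature names feat_ids/feat_vals/label and FIELD_SIZE are assumptions, not taken from the original code:

FIELD_SIZE = 39  # illustrative: number of feature fields per example

def parse_fn(serialized_batch):
    # The map above runs after batch(), so this receives a vector of serialized
    # tf.Example protos and parses them all in one call.
    parsed = tf.parse_example(serialized_batch, features={
        'feat_ids': tf.FixedLenFeature([FIELD_SIZE], tf.int64),
        'feat_vals': tf.FixedLenFeature([FIELD_SIZE], tf.float32),
        'label': tf.FixedLenFeature([1], tf.float32),
    })
    labels = parsed.pop('label')
    return parsed, labels

Batching before map lets parse_fn parse a whole vector of serialized Examples at once, which is usually faster than parsing record by record.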