Parsing libsvm data
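A libsvm-formatted line starts with the label, followed by space-separated index:value pairs for the non-zero features, for example: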
#1 1:0.5 2:0.03519 3:1 4:0.02567 7:0.03708 8:0.01705 9:0.06296 10:0.18185 11:0.02497 12:1 14:0.02565 15:0.03267 17:0.0247 18:0.03158 20:1 22:1 23:0.13169 24:0.02933 27:0.18159 31:0.0177 34:0.02888 38:1 51:1 63:1 132:1 164:1 236:1
import tensorflow as tf

def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
    print('Parsing', filenames)

    def decode_libsvm(line):
        # Split the line on spaces: the first token is the label,
        # the remaining tokens are "index:value" pairs.
        columns = tf.string_split([line], ' ')
        labels = tf.string_to_number(columns.values[0], out_type=tf.float32)
        splits = tf.string_split(columns.values[1:], ':')
        id_vals = tf.reshape(splits.values, splits.dense_shape)
        # Split each pair into the feature index and the feature value.
        feat_ids, feat_vals = tf.split(id_vals, num_or_size_splits=2, axis=1)
        feat_ids = tf.string_to_number(feat_ids, out_type=tf.int32)
        feat_vals = tf.string_to_number(feat_vals, out_type=tf.float32)
        return {"feat_ids": feat_ids, "feat_vals": feat_vals}, labels

    # Extract lines from the input files using the Dataset API; filenames can be
    # a single filename or a list of filenames.
    dataset = (tf.data.TextLineDataset(filenames)
               .map(decode_libsvm, num_parallel_calls=10)  # multi-threaded pre-processing
               .prefetch(500000))                          # then prefetch

    # Randomize the input using a window of 256 elements (read into memory).
    if perform_shuffle:
        dataset = dataset.shuffle(buffer_size=256)

    # Repeat after shuffling to prevent separate epochs from blending together.
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)  # batch size to use

    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels
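A minimal usage sketch: the (features, labels) pair returned by input_fn can be fed to a tf.estimator.Estimator. The model object and the file name below are placeholders, not from the original post.

model.train(input_fn=lambda: input_fn('tr.libsvm',
                                      batch_size=64,
                                      num_epochs=1,
                                      perform_shuffle=True))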
Input directories
import random

def cm_input_fn(multiple_input_path, batch_size,
                worker_num=None, worker_idx=None, is_train=True):
    # Collect every file under each input directory, then shuffle the file list.
    total_files = []
    for input_path in multiple_input_path:
        tmp = tf.gfile.ListDirectory(input_path)
        files = sorted(map(lambda x: input_path + '/' + x, tmp))
        total_files += files
    random.shuffle(total_files)

    dataset = tf.data.Dataset.from_tensor_slices(total_files)
    if is_train:
        # Each worker reads its own disjoint shard of the file list.
        dataset = dataset.shard(worker_num, worker_idx)
    # Interleave reads from up to 8 TFRecord files in parallel.
    dataset = dataset.apply(tf.contrib.data.parallel_interleave(
        tf.data.TFRecordDataset, cycle_length=8))
    if is_train:
        dataset = dataset.repeat(1)        # a single pass over the data when training
    else:
        dataset = dataset.repeat(500000)   # loop (practically) forever otherwise
    # Batch the serialized records first, then parse each batch with parse_fn
    # (defined elsewhere).
    dataset = dataset.batch(batch_size).map(parse_fn, num_parallel_calls=8)
    dataset = dataset.prefetch(buffer_size=batch_size)
    iterator = dataset.make_one_shot_iterator()
    return iterator.get_next()
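parse_fn is not shown in the original post. As a rough sketch of what it might look like for this pipeline (batch first, then map), it could call tf.parse_example on the whole batch of serialized tf.train.Example protos; the feature names and shapes below are assumptions, not from the original code.

def parse_fn(serialized_batch):
    # Hypothetical feature spec; adjust names and shapes to the actual TFRecord schema.
    feature_spec = {
        "feat_ids": tf.FixedLenFeature([39], tf.int64),
        "feat_vals": tf.FixedLenFeature([39], tf.float32),
        "label": tf.FixedLenFeature([], tf.float32),
    }
    # parse_example parses a whole batch of serialized Example protos at once,
    # which matches batching before the map() call above.
    parsed = tf.parse_example(serialized_batch, feature_spec)
    labels = parsed.pop("label")
    return parsed, labels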