简介
本文介绍如何将文本文件转换成 TFRecord 格式文件，以及如何从 TFRecord 文件中按批读取样本。
文件格式
每行包含 output_size + input_size 个浮点数，字段之间以制表符（\t）分隔：前 output_size 个数为输出（标签），后 input_size 个数为输入（特征）。
参数说明
- input_file: 输入文本文件的路径
- train_file: 输出的训练集 TFRecord 文件路径
- test_file: 输出的测试集 TFRecord 文件路径
- input_size: 每行中输入（特征）的个数，即输入层大小
- output_size: 每行中输出（标签）的个数，即输出层大小
- test_data_ratio: 测试集占比（默认 0.2），每个样本以该概率被随机划入测试集
示例代码
def _parse_float_row(line, expected_width, delimiter):
    """Return the floats on one text line, or None if the row must be skipped."""
    try:
        values = [float(field) for field in line.strip().split(delimiter)]
    except ValueError:
        # Blank line ("".split -> [""] -> float("") fails) or a non-numeric field.
        return None
    # Reject short AND long rows: extra values would make 'features' variable-length.
    if len(values) != expected_width:
        return None
    # Filter out NaN samples (the original intent); infinities are dropped as well.
    if any(math.isnan(v) or math.isinf(v) for v in values):
        return None
    return values


def _make_float_example(label_values, feature_values):
    """Build a tf.train.Example with float lists under 'label' and 'features'."""
    return tf.train.Example(
        features=tf.train.Features(
            feature={
                'label': tf.train.Feature(float_list=tf.train.FloatList(value=label_values)),
                'features': tf.train.Feature(float_list=tf.train.FloatList(value=feature_values)),
            }
        )
    )


def encode_to_tfrecords(input_file, train_file, test_file, input_size, output_size,
                        test_data_ratio=0.2, delimiter="\t", seed=None):
    """Convert a delimited text file of float rows into train/test TFRecord files.

    Each usable line holds exactly ``output_size + input_size`` numbers separated
    by ``delimiter``. The first ``output_size`` values become the float feature
    ``'label'``; the remaining ``input_size`` values become the float feature
    ``'features'``. Each example goes to ``test_file`` with probability
    ``test_data_ratio`` and to ``train_file`` otherwise.

    Lines that are blank, contain non-numeric fields, have the wrong number of
    values, or contain NaN/infinity are skipped. Write errors propagate instead
    of being printed and ignored.

    Args:
        input_file: path of the text file to read.
        train_file: path of the TFRecord file receiving training examples.
        test_file: path of the TFRecord file receiving test examples.
        input_size: number of input (feature) values per line.
        output_size: number of output (label) values per line.
        test_data_ratio: probability that an example is routed to ``test_file``.
        delimiter: field separator used in ``input_file`` (default: tab).
        seed: optional seed for a private RNG, making the split reproducible.
            When None, the global ``random`` module is used, as before.
    """
    expected_width = input_size + output_size
    # Default to the module-level RNG so callers that seed `random` globally still
    # draw from it; an explicit seed gives an isolated, reproducible split.
    draw = random.random if seed is None else random.Random(seed).random
    # Context managers guarantee both writers are closed even if reading fails.
    with tf.python_io.TFRecordWriter(train_file) as train_writer, \
            tf.python_io.TFRecordWriter(test_file) as test_writer, \
            open(input_file, "r") as reader:
        for line in reader:
            values = _parse_float_row(line, expected_width, delimiter)
            if values is None:
                continue
            example = _make_float_example(values[:output_size], values[output_size:])
            writer = test_writer if draw() < test_data_ratio else train_writer
            writer.write(example.SerializeToString())
# Write one Example per entry of `labels` to "input/train.tfr": a float list under
# 'label' and an int64 list under 'features', taken from train[0] at the same index.
# NOTE(review): `labels` and `train` are defined outside this snippet — presumably
# a sequence of label vectors and a container whose first element holds integer
# feature rows aligned with `labels`; confirm against the producing code.
train_writer = tf.python_io.TFRecordWriter("input/train.tfr")
try:
    for index, label_values in enumerate(labels):
        label = tf.train.Feature(float_list=tf.train.FloatList(value=label_values))
        features = tf.train.Feature(int64_list=tf.train.Int64List(value=train[0][index]))
        example = tf.train.Example(
            features=tf.train.Features(
                feature={
                    'label': label,
                    'features': features
                }
            )
        )
        train_writer.write(example.SerializeToString())
finally:
    # Close the writer even if a row fails to convert, so the file handle is released.
    train_writer.close()
def get_batch_sample(filename, batch_size, label_size=7, feature_size=25,
                     feature_dtype=None):
    """Yield shuffled (features, label) batch tensors read from a TFRecord file.

    Builds a TF1 queue-based input pipeline exactly once, then yields the same
    pair of batch tensors on every ``next()``. Evaluating those tensors in a
    session (with queue runners started, per the TF1 queue-input API) produces
    a new shuffled batch each time.

    Args:
        filename: path of the TFRecord file to read.
        batch_size: number of examples per batch.
        label_size: length of the float32 ``'label'`` feature (default 7).
        feature_size: length of the ``'features'`` feature (default 25).
        feature_dtype: dtype of ``'features'``; None means ``tf.int64`` (the
            previous hard-coded type). Use ``tf.float32`` for files whose
            features were written as float lists.

    Yields:
        ``(features, label)`` tensors of static shapes
        ``[batch_size, feature_size]`` and ``[batch_size, label_size]``.
    """
    if feature_dtype is None:
        feature_dtype = tf.int64
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_sample = reader.read(filename_queue)
    sample = tf.parse_single_example(
        serialized_sample,
        features={
            'label': tf.FixedLenFeature([label_size], tf.float32),
            'features': tf.FixedLenFeature([feature_size], feature_dtype),
        })
    features_batch, label_batch = tf.train.shuffle_batch(
        [sample['features'], sample['label']],
        batch_size=batch_size,
        capacity=5000,
        min_after_dequeue=1000,
        allow_smaller_final_batch=True,
        num_threads=10)
    # allow_smaller_final_batch leaves the batch dimension unknown; pinning it is
    # safe here because string_input_producer is created without num_epochs and
    # therefore cycles indefinitely, so a short final batch does not occur.
    features_batch.set_shape([batch_size, feature_size])
    label_batch.set_shape([batch_size, label_size])
    # Build once, yield forever. Rebuilding shuffle_batch inside the loop (as
    # before) re-batched the previous batch tensors, growing the graph and
    # failing on the second next() with a rank mismatch in set_shape.
    while True:
        yield features_batch, label_batch