简介

本文介绍如何将文本文件转换成 TFRecord 格式文件，以及如何从 TFRecord 文件中批量读取样本。

文件格式

每行包含 output_size + input_size 个以制表符（\t）分隔的浮点数：前 output_size 个表示输出（标签），后 input_size 个表示输入（特征）。

参数说明

  • input_file: 输入的文本文件路径
  • train_file: 输出的训练集 TFRecord 文件路径
  • test_file: 输出的测试集 TFRecord 文件路径
  • input_size: 输入层大小（每个样本的特征数）
  • output_size: 输出层大小（每个样本的标签数）
  • test_data_ratio: 测试集占比，默认 0.2（每个样本以该概率被随机划入测试集）

示例代码

def _parse_line(line, input_size, output_size):
    """Split one tab-separated line into ``(label, features)`` float lists.

    Returns ``None`` for lines that are skipped silently: blank lines, lines
    with fewer than ``output_size + input_size`` fields, and lines whose used
    values contain NaN. Fields past ``output_size + input_size`` are ignored
    so that every record has the same length.

    Raises:
        ValueError: if one of the fields is not a valid float.
    """
    text = line.strip()
    if not text:
        return None
    values = [float(field) for field in text.split("\t")]
    width = output_size + input_size
    if len(values) < width:
        return None
    # Keep exactly the documented number of fields; extra columns would give
    # variable-length records that fixed-length parsing cannot read.
    values = values[:width]
    # Drop samples that contain NaN in any used position.
    if any(math.isnan(v) for v in values):
        return None
    return values[:output_size], values[output_size:]


def encode_to_tfrecords(input_file, train_file, test_file, input_size, output_size, test_data_ratio=0.2):
    """Convert a tab-separated text file into train/test TFRecord files.

    Each usable line holds ``output_size`` label floats followed by
    ``input_size`` feature floats. Every sample is written as a
    ``tf.train.Example`` with float-list entries ``'label'`` and
    ``'features'``. It goes to ``test_file`` when ``random.random()`` is
    below ``test_data_ratio``, and to ``train_file`` otherwise.

    Args:
        input_file: path of the text file to read.
        train_file: path of the TFRecord file that receives training samples.
        test_file: path of the TFRecord file that receives test samples.
        input_size: number of feature values per sample.
        output_size: number of label values per sample.
        test_data_ratio: probability that a sample is routed to the test file.

    Blank, too-short, and NaN-containing lines are skipped. Lines with a
    non-numeric field are reported with ``print`` and skipped. I/O errors
    propagate, and both writers are closed in every case.
    """
    with tf.python_io.TFRecordWriter(train_file) as train_writer, \
            tf.python_io.TFRecordWriter(test_file) as test_writer, \
            open(input_file, "r") as reader:
        for line_no, line in enumerate(reader, 1):
            try:
                parsed = _parse_line(line, input_size, output_size)
            except ValueError as err:
                # A malformed line should not abort the whole conversion.
                print("line %d skipped: %s" % (line_no, err))
                continue
            if parsed is None:
                continue
            label_values, feature_values = parsed
            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        'label': tf.train.Feature(
                            float_list=tf.train.FloatList(value=label_values)),
                        'features': tf.train.Feature(
                            float_list=tf.train.FloatList(value=feature_values)),
                    }
                )
            )
            # One random draw per written sample decides the split.
            writer = test_writer if random.random() < test_data_ratio else train_writer
            writer.write(example.SerializeToString())
# Write pre-built samples to input/train.tfr. Each row of `labels` is stored as
# a float list under 'label', and the matching row of `train[0]` as an int64
# list under 'features'.
# NOTE(review): `labels` and `train` are not defined in this snippet; presumably
# labels rows hold 7 floats and train[0] rows hold 25 ints (the sizes that
# get_batch_sample parses) — confirm against the code that builds them.
if len(train[0]) != len(labels):
    raise ValueError("labels and train[0] differ in length: %d vs %d"
                     % (len(labels), len(train[0])))
# The `with` block closes the writer even if serialization or writing fails.
with tf.python_io.TFRecordWriter("input/train.tfr") as train_writer:
    for label_row, feature_row in zip(labels, train[0]):
        example = tf.train.Example(
            features=tf.train.Features(
                feature={
                    'label': tf.train.Feature(
                        float_list=tf.train.FloatList(value=label_row)),
                    'features': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=feature_row)),
                }
            )
        )
        train_writer.write(example.SerializeToString())
def get_batch_sample(filename, batch_size, input_size=25, output_size=7):
    """Yield shuffled ``(features, label)`` batch tensors read from a TFRecord file.

    The queue-based input pipeline is built once, on the first ``next()``:
    1. a filename queue over ``filename``;
    2. a ``TFRecordReader``;
    3. fixed-length parsing of ``'features'`` (int64, ``input_size`` values)
       and ``'label'`` (float32, ``output_size`` values);
    4. ``tf.train.shuffle_batch``.

    Every ``next()`` then yields the same pair of batch tensors. Each session
    run of that pair dequeues a new batch.

    Args:
        filename: path of the TFRecord file to read.
        batch_size: number of samples per batch.
        input_size: feature values per sample (default 25).
        output_size: label values per sample (default 7).

    NOTE(review): the queue runners must be started by the caller (e.g. via
    ``tf.train.start_queue_runners``) before the tensors are evaluated.
    """
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_sample = reader.read(filename_queue)
    sample = tf.parse_single_example(
        serialized_sample,
        features={
            'label': tf.FixedLenFeature([output_size], tf.float32),
            'features': tf.FixedLenFeature([input_size], tf.int64),
        })
    # Build the batching op once. Calling shuffle_batch inside the loop would
    # re-batch the previous batch (rank 3, so set_shape fails on the second
    # next()) and would add a new queue to the graph on every iteration.
    features_batch, label_batch = tf.train.shuffle_batch(
        [sample['features'], sample['label']],
        batch_size=batch_size,
        capacity=5000,
        min_after_dequeue=1000,
        allow_smaller_final_batch=True,
        num_threads=10)
    # NOTE(review): set_shape pins the batch dimension to batch_size even though
    # allow_smaller_final_batch=True. string_input_producer gets no num_epochs
    # here, so presumably no short final batch ever occurs — confirm.
    features_batch.set_shape([batch_size, input_size])
    label_batch.set_shape([batch_size, output_size])
    while True:
        yield features_batch, label_batch

参考