目录

为什么用DatasetAPI

  • 常规方式:用 python 代码来进行 batch,shuffle,padding 等 numpy 类型的数据处理,再用 placeholder + feed_dict 来将其导入到 graph 中变成 tensor 类型。因此在网络的训练过程中,不得不在 tensorflow 的代码中穿插 python 代码来实现控制。
  • Dataset API:将数据直接放在 graph 中进行处理,整体对数据集进行上述数据操作,使代码更加简洁。

可以与Estimator对接

TensorFlow 中也加入了高级 API (Estimator、Experiment,Dataset)帮助建立网络,和 Keras 等库不一样的是:这些 API 并不注重网络结构的搭建,而是将不同类型的操作分开,帮助周边操作。可以在保证网络结构控制权的基础上,节省工作量。若使用 Dataset API 导入数据,后续还可选择与 Estimator 对接。

# Preprocess 4 files concurrently.
filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords")
dataset = filenames.apply(
    tf.contrib.data.parallel_interleave(
        lambda filename: tf.data.TFRecordDataset(filename),
        cycle_length=4))

参考

  • https://www.sohu.com/a/219765050_717210