目录
为什么用DatasetAPI
- 常规方式:用 python 代码来进行 batch,shuffle,padding 等 numpy 类型的数据处理,再用 placeholder + feed_dict 来将其导入到 graph 中变成 tensor 类型。因此在网络的训练过程中,不得不在 tensorflow 的代码中穿插 python 代码来实现控制。
- Dataset API:将数据直接放在 graph 中进行处理,整体对数据集进行上述数据操作,使代码更加简洁。
可以与Estimator对接
TensorFlow 中也加入了高级 API (Estimator、Experiment,Dataset)帮助建立网络,和 Keras 等库不一样的是:这些 API 并不注重网络结构的搭建,而是将不同类型的操作分开,帮助周边操作。可以在保证网络结构控制权的基础上,节省工作量。若使用 Dataset API 导入数据,后续还可选择与 Estimator 对接。
# Preprocess 4 files concurrently.
filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords")
dataset = filenames.apply(
tf.contrib.data.parallel_interleave(
lambda filename: tf.data.TFRecordDataset(filename),
cycle_length=4))
参考
- https://www.sohu.com/a/219765050_717210