目录
三种数据输入方式
- 通过feed_dict
- 通过文件名读取数据:一个输入流水线 在计算图的开始部分从文件中读取数据
- 把数据预加载到一个常量或者变量中 本文主要讲解第二种。
Queue 队列
Queue,队列,用来存放数据(跟Variable似的),tensorflow中的Queue中已经实现了同步机制,所以我们可以放心的往里面添加数据还有读取数据.如果Queue中的数据满了,那么en_queue操作将会阻塞,如果Queue是空的,那么dequeue操作就会阻塞.在常用环境中,一般是有多个en_queue线程同时像Queue中放数据,有一个dequeue操作从Queue中取数据.一般来说enqueue线程就是准备数据的线程,dequeue线程就是训练数据的线程.
Coordinator 协调者
Coordinator就是用来帮助多个线程同时停止.线程组需要一个Coordinator来协调它们之间的工作.
# Thread body: loop until the coordinator indicates a stop was requested.
# If some condition becomes true, ask the coordinator to stop.
#将coord传入到线程中,来帮助它们同时停止工作
def MyLoop(coord):
while not coord.should_stop():
...do something...
if ...some condition...:
coord.request_stop()
# Main thread: create a coordinator.
coord = tf.train.Coordinator()
# Create 10 threads that run 'MyLoop()'
threads = [threading.Thread(target=MyLoop, args=(coord,)) for i in xrange(10)]
# Start the threads and wait for all of them to stop.
for t in threads:
t.start()
coord.join(threads)
QueueRunner
QueueRunner创建多个线程对Queue进行enqueue操作.它是一个op.这些线程可以通过上面所述的Coordinator来协调它们同时停止工作.
输入流水线 input-pipeline
Queue是一个队列,QueueRunner用来创建多个线程对Queue进行enqueue操作.Coordinator可用来协调QueueRunner创建出来的线程共同停止工作.
流程
- 准备文件名
- 创建一个Reader从文件中读取数据
- 定义文件中数据的解码规则
- 解析数据
代码
import tensorflow as tf
#一个Queue,用来保存文件名字.对此Queue,只读取,不dequeue
filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"])
#用来从文件中读取数据, LineReader,每次读一行
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
# Default values, in case of empty columns. Also specifies the type of the
# decoded result.
record_defaults = [[1], [1], [1], [1], [1]]
col1, col2, col3, col4, col5 = tf.decode_csv(
value, record_defaults=record_defaults)
features = tf.stack([col1, col2, col3, col4])
with tf.Session() as sess:
# Start populating the filename queue.
coord = tf.train.Coordinator()
#在调用run或eval执行读取之前,必须
#用tf.train.start_queue_runners来填充队列
threads = tf.train.start_queue_runners(coord=coord)
for i in range(10):
# Retrieve a single instance:
example, label = sess.run([features, col5])
print(example, label)
coord.request_stop()
coord.join(threads)