目录

三种数据输入方式

  • 通过feed_dict
  • 通过文件名读取数据:一个输入流水线 在计算图的开始部分从文件中读取数据
  • 把数据预加载到一个常量或者变量中 本文主要讲解第二种。

Queue 队列

Queue,队列,用来存放数据(跟Variable似的),tensorflow中的Queue中已经实现了同步机制,所以我们可以放心的往里面添加数据还有读取数据.如果Queue中的数据满了,那么en_queue操作将会阻塞,如果Queue是空的,那么dequeue操作就会阻塞.在常用环境中,一般是有多个en_queue线程同时像Queue中放数据,有一个dequeue操作从Queue中取数据.一般来说enqueue线程就是准备数据的线程,dequeue线程就是训练数据的线程.

Coordinator 协调者

Coordinator就是用来帮助多个线程同时停止.线程组需要一个Coordinator来协调它们之间的工作.

# Thread body: loop until the coordinator indicates a stop was requested.
# If some condition becomes true, ask the coordinator to stop.
#将coord传入到线程中,来帮助它们同时停止工作
def MyLoop(coord):
  while not coord.should_stop():
    ...do something...
    if ...some condition...:
      coord.request_stop()

# Main thread: create a coordinator.
coord = tf.train.Coordinator()

# Create 10 threads that run 'MyLoop()'
threads = [threading.Thread(target=MyLoop, args=(coord,)) for i in xrange(10)]

# Start the threads and wait for all of them to stop.
for t in threads:
  t.start()
coord.join(threads)

QueueRunner

QueueRunner创建多个线程对Queue进行enqueue操作.它是一个op.这些线程可以通过上面所述的Coordinator来协调它们同时停止工作.

输入流水线 input-pipeline

Queue是一个队列,QueueRunner用来创建多个线程对Queue进行enqueue操作.Coordinator可用来协调QueueRunner创建出来的线程共同停止工作.

流程

  • 准备文件名
  • 创建一个Reader从文件中读取数据
  • 定义文件中数据的解码规则
  • 解析数据

代码

import tensorflow as tf
#一个Queue,用来保存文件名字.对此Queue,只读取,不dequeue
filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"])

#用来从文件中读取数据, LineReader,每次读一行
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)

# Default values, in case of empty columns. Also specifies the type of the
# decoded result.
record_defaults = [[1], [1], [1], [1], [1]]
col1, col2, col3, col4, col5 = tf.decode_csv(
    value, record_defaults=record_defaults)
features = tf.stack([col1, col2, col3, col4])

with tf.Session() as sess:
  # Start populating the filename queue.
  coord = tf.train.Coordinator()
  #在调用run或eval执行读取之前,必须
  #用tf.train.start_queue_runners来填充队列
  threads = tf.train.start_queue_runners(coord=coord)

  for i in range(10):
    # Retrieve a single instance:
    example, label = sess.run([features, col5])
    print(example, label)
  coord.request_stop()
  coord.join(threads)