Home | History | Annotate | Download | only in training
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests `FeedingQueueRunner` using arrays and `DataFrames`."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.python.client import session
     24 from tensorflow.python.estimator.inputs.queues.feeding_functions import _enqueue_data as enqueue_data
     25 from tensorflow.python.framework import ops
     26 from tensorflow.python.platform import test
     27 from tensorflow.python.training import coordinator
     28 from tensorflow.python.training import queue_runner_impl
     29 
     30 # pylint: disable=g-import-not-at-top
     31 try:
     32   import pandas as pd
     33   HAS_PANDAS = True
     34 except ImportError:
     35   HAS_PANDAS = False
     36 
     37 
     38 def get_rows(array, row_indices):
     39   rows = [array[i] for i in row_indices]
     40   return np.vstack(rows)
     41 
     42 
     43 class FeedingQueueRunnerTestCase(test.TestCase):
     44   """Tests for `FeedingQueueRunner`."""
     45 
     46   def testArrayFeeding(self):
     47     with ops.Graph().as_default():
     48       array = np.arange(32).reshape([16, 2])
     49       q = enqueue_data(array, capacity=100)
     50       batch_size = 3
     51       dq_op = q.dequeue_many(batch_size)
     52       with session.Session() as sess:
     53         coord = coordinator.Coordinator()
     54         threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
     55         for i in range(100):
     56           indices = [
     57               j % array.shape[0]
     58               for j in range(batch_size * i, batch_size * (i + 1))
     59           ]
     60           expected_dq = get_rows(array, indices)
     61           dq = sess.run(dq_op)
     62           np.testing.assert_array_equal(indices, dq[0])
     63           np.testing.assert_array_equal(expected_dq, dq[1])
     64         coord.request_stop()
     65         coord.join(threads)
     66 
     67   def testArrayFeedingMultiThread(self):
     68     with ops.Graph().as_default():
     69       array = np.arange(256).reshape([128, 2])
     70       q = enqueue_data(array, capacity=128, num_threads=8, shuffle=True)
     71       batch_size = 3
     72       dq_op = q.dequeue_many(batch_size)
     73       with session.Session() as sess:
     74         coord = coordinator.Coordinator()
     75         threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
     76         for _ in range(100):
     77           dq = sess.run(dq_op)
     78           indices = dq[0]
     79           expected_dq = get_rows(array, indices)
     80           np.testing.assert_array_equal(expected_dq, dq[1])
     81         coord.request_stop()
     82         coord.join(threads)
     83 
     84   def testPandasFeeding(self):
     85     if not HAS_PANDAS:
     86       return
     87     with ops.Graph().as_default():
     88       array1 = np.arange(32)
     89       array2 = np.arange(32, 64)
     90       df = pd.DataFrame({"a": array1, "b": array2}, index=np.arange(64, 96))
     91       q = enqueue_data(df, capacity=100)
     92       batch_size = 5
     93       dq_op = q.dequeue_many(5)
     94       with session.Session() as sess:
     95         coord = coordinator.Coordinator()
     96         threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
     97         for i in range(100):
     98           indices = [
     99               j % array1.shape[0]
    100               for j in range(batch_size * i, batch_size * (i + 1))
    101           ]
    102           expected_df_indices = df.index[indices]
    103           expected_rows = df.iloc[indices]
    104           dq = sess.run(dq_op)
    105           np.testing.assert_array_equal(expected_df_indices, dq[0])
    106           for col_num, col in enumerate(df.columns):
    107             np.testing.assert_array_equal(expected_rows[col].values,
    108                                           dq[col_num + 1])
    109         coord.request_stop()
    110         coord.join(threads)
    111 
    112   def testPandasFeedingMultiThread(self):
    113     if not HAS_PANDAS:
    114       return
    115     with ops.Graph().as_default():
    116       array1 = np.arange(128, 256)
    117       array2 = 2 * array1
    118       df = pd.DataFrame({"a": array1, "b": array2}, index=np.arange(128))
    119       q = enqueue_data(df, capacity=128, num_threads=8, shuffle=True)
    120       batch_size = 5
    121       dq_op = q.dequeue_many(batch_size)
    122       with session.Session() as sess:
    123         coord = coordinator.Coordinator()
    124         threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
    125         for _ in range(100):
    126           dq = sess.run(dq_op)
    127           indices = dq[0]
    128           expected_rows = df.iloc[indices]
    129           for col_num, col in enumerate(df.columns):
    130             np.testing.assert_array_equal(expected_rows[col].values,
    131                                           dq[col_num + 1])
    132         coord.request_stop()
    133         coord.join(threads)
    134 
    135 
# When executed as a script, hand control to the TensorFlow test entry point
# (`tensorflow.python.platform.test.main`).
if __name__ == "__main__":
  test.main()
    138