1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests `FeedingQueueRunner` using arrays and `DataFrames`.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.python.client import session 24 from tensorflow.python.estimator.inputs.queues.feeding_functions import _enqueue_data as enqueue_data 25 from tensorflow.python.framework import ops 26 from tensorflow.python.platform import test 27 from tensorflow.python.training import coordinator 28 from tensorflow.python.training import queue_runner_impl 29 30 # pylint: disable=g-import-not-at-top 31 try: 32 import pandas as pd 33 HAS_PANDAS = True 34 except ImportError: 35 HAS_PANDAS = False 36 37 38 def get_rows(array, row_indices): 39 rows = [array[i] for i in row_indices] 40 return np.vstack(rows) 41 42 43 class FeedingQueueRunnerTestCase(test.TestCase): 44 """Tests for `FeedingQueueRunner`.""" 45 46 def testArrayFeeding(self): 47 with ops.Graph().as_default(): 48 array = np.arange(32).reshape([16, 2]) 49 q = enqueue_data(array, capacity=100) 50 batch_size = 3 51 dq_op = q.dequeue_many(batch_size) 52 with session.Session() as sess: 53 coord = coordinator.Coordinator() 54 threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord) 55 for i in range(100): 56 indices = [ 57 j % array.shape[0] 58 for j in range(batch_size * i, batch_size * (i + 1)) 59 ] 60 expected_dq = get_rows(array, indices) 61 dq = sess.run(dq_op) 62 np.testing.assert_array_equal(indices, dq[0]) 63 np.testing.assert_array_equal(expected_dq, dq[1]) 64 coord.request_stop() 65 coord.join(threads) 66 67 def testArrayFeedingMultiThread(self): 68 with ops.Graph().as_default(): 69 array = np.arange(256).reshape([128, 2]) 70 q = enqueue_data(array, capacity=128, num_threads=8, shuffle=True) 71 batch_size = 3 72 dq_op = q.dequeue_many(batch_size) 73 with session.Session() as sess: 74 coord = coordinator.Coordinator() 75 threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord) 76 for _ in range(100): 77 dq = sess.run(dq_op) 78 indices = dq[0] 79 expected_dq = get_rows(array, indices) 80 np.testing.assert_array_equal(expected_dq, dq[1]) 81 coord.request_stop() 82 coord.join(threads) 83 84 def testPandasFeeding(self): 85 if not HAS_PANDAS: 86 return 87 with ops.Graph().as_default(): 88 array1 = np.arange(32) 89 array2 = np.arange(32, 64) 90 df = pd.DataFrame({"a": array1, "b": array2}, index=np.arange(64, 96)) 91 q = enqueue_data(df, capacity=100) 92 batch_size = 5 93 dq_op = q.dequeue_many(5) 94 with session.Session() as sess: 95 coord = coordinator.Coordinator() 96 threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord) 97 for i in range(100): 98 indices = [ 99 j % array1.shape[0] 100 for j in range(batch_size * i, batch_size * (i + 1)) 101 ] 102 expected_df_indices = df.index[indices] 103 expected_rows = df.iloc[indices] 104 dq = sess.run(dq_op) 105 np.testing.assert_array_equal(expected_df_indices, dq[0]) 106 for col_num, col in enumerate(df.columns): 107 np.testing.assert_array_equal(expected_rows[col].values, 108 dq[col_num + 1]) 109 coord.request_stop() 110 coord.join(threads) 111 112 def testPandasFeedingMultiThread(self): 113 if not HAS_PANDAS: 114 return 115 with ops.Graph().as_default(): 116 array1 = np.arange(128, 256) 117 array2 = 2 * array1 118 df = pd.DataFrame({"a": array1, "b": array2}, index=np.arange(128)) 119 q = enqueue_data(df, capacity=128, num_threads=8, shuffle=True) 120 batch_size = 5 121 dq_op = q.dequeue_many(batch_size) 122 with session.Session() as sess: 123 coord = coordinator.Coordinator() 124 threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord) 125 for _ in range(100): 126 dq = sess.run(dq_op) 127 indices = dq[0] 128 expected_rows = df.iloc[indices] 129 for col_num, col in enumerate(df.columns): 130 np.testing.assert_array_equal(expected_rows[col].values, 131 dq[col_num + 1]) 132 coord.request_stop() 133 coord.join(threads) 134 135 136 if __name__ == "__main__": 137 test.main() 138