# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for pandas_io."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.estimator.inputs import pandas_io
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
from tensorflow.python.training import coordinator
from tensorflow.python.training import queue_runner_impl

try:
  # pylint: disable=g-import-not-at-top
  import pandas as pd
  HAS_PANDAS = True
except (IOError, ImportError):
  # Pandas writes a temporary file during import. If it fails (IOError), or if
  # pandas is simply not installed (ImportError), don't use pandas.
  HAS_PANDAS = False


class PandasIoTest(test.TestCase):
  """Tests for `pandas_io.pandas_input_fn`.

  Every test needs pandas; when it could not be imported the tests are
  reported as skipped instead of silently passing.
  """

  def setUp(self):
    # Skip before the base-class setUp so there is nothing to tear down.
    if not HAS_PANDAS:
      self.skipTest('pandas could not be imported')
    super(PandasIoTest, self).setUp()

  def makeTestDataFrame(self):
    """Builds the shared fixture.

    Returns:
      A tuple `(x, y)`: `x` is a 4-row DataFrame with columns 'a' (0..3) and
      'b' (32..35); `y` is a Series holding -32..-29. Both use the index
      100..103.
    """
    index = np.arange(100, 104)
    a = np.arange(4)
    b = np.arange(32, 36)
    x = pd.DataFrame({'a': a, 'b': b}, index=index)
    y = pd.Series(np.arange(-32, -28), index=index)
    return x, y

  def callInputFnOnce(self, input_fn, session):
    """Calls `input_fn`, runs its outputs once, and returns the values.

    Args:
      input_fn: Callable returning the tensors to evaluate.
      session: Session used to start the queue runners and run the tensors.

    Returns:
      The result of a single `session.run` on the outputs of `input_fn`.
    """
    results = input_fn()
    coord = coordinator.Coordinator()
    threads = queue_runner_impl.start_queue_runners(session, coord=coord)
    try:
      return session.run(results)
    finally:
      # Always stop the queue-runner threads, even if session.run raised.
      coord.request_stop()
      coord.join(threads)

  def testPandasInputFn_IndexMismatch(self):
    x, _ = self.makeTestDataFrame()
    # No explicit index, so it does not match the index of x.
    y_noindex = pd.Series(np.arange(-32, -28))
    with self.assertRaises(ValueError):
      pandas_io.pandas_input_fn(
          x, y_noindex, batch_size=2, shuffle=False, num_epochs=1)

  def testPandasInputFn_NonBoolShuffle(self):
    x, _ = self.makeTestDataFrame()
    y_noindex = pd.Series(np.arange(-32, -28))
    with self.assertRaisesRegexp(TypeError,
                                 'shuffle must be explicitly set as boolean'):
      # Default shuffle is None
      pandas_io.pandas_input_fn(x, y_noindex)

  def testPandasInputFn_ProducesExpectedOutputs(self):
    with self.test_session() as session:
      x, y = self.makeTestDataFrame()
      input_fn = pandas_io.pandas_input_fn(
          x, y, batch_size=2, shuffle=False, num_epochs=1)

      features, target = self.callInputFnOnce(input_fn, session)

      self.assertAllEqual(features['a'], [0, 1])
      self.assertAllEqual(features['b'], [32, 33])
      self.assertAllEqual(target, [-32, -31])

  def testPandasInputFn_ProducesOutputsForLargeBatchAndMultipleEpochs(self):
    with self.test_session() as session:
      index = np.arange(100, 102)
      a = np.arange(2)
      b = np.arange(32, 34)
      x = pd.DataFrame({'a': a, 'b': b}, index=index)
      y = pd.Series(np.arange(-32, -30), index=index)
      input_fn = pandas_io.pandas_input_fn(
          x, y, batch_size=128, shuffle=False, num_epochs=2)

      results = input_fn()

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(session, coord=coord)
      try:
        # The batch size exceeds the data, so the single batch is expected to
        # hold both epochs back to back.
        features, target = session.run(results)
        self.assertAllEqual(features['a'], [0, 1, 0, 1])
        self.assertAllEqual(features['b'], [32, 33, 32, 33])
        self.assertAllEqual(target, [-32, -31, -32, -31])

        with self.assertRaises(errors.OutOfRangeError):
          session.run(results)
      finally:
        # Stop the queue-runner threads even when an assertion above fails.
        coord.request_stop()
        coord.join(threads)

  def testPandasInputFn_ProducesOutputsWhenDataSizeNotDividedByBatchSize(self):
    with self.test_session() as session:
      index = np.arange(100, 105)
      a = np.arange(5)
      b = np.arange(32, 37)
      x = pd.DataFrame({'a': a, 'b': b}, index=index)
      y = pd.Series(np.arange(-32, -27), index=index)

      input_fn = pandas_io.pandas_input_fn(
          x, y, batch_size=2, shuffle=False, num_epochs=1)

      results = input_fn()

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(session, coord=coord)
      try:
        features, target = session.run(results)
        self.assertAllEqual(features['a'], [0, 1])
        self.assertAllEqual(features['b'], [32, 33])
        self.assertAllEqual(target, [-32, -31])

        features, target = session.run(results)
        self.assertAllEqual(features['a'], [2, 3])
        self.assertAllEqual(features['b'], [34, 35])
        self.assertAllEqual(target, [-30, -29])

        # Five rows with batch_size=2: the last batch holds the one leftover.
        features, target = session.run(results)
        self.assertAllEqual(features['a'], [4])
        self.assertAllEqual(features['b'], [36])
        self.assertAllEqual(target, [-28])

        with self.assertRaises(errors.OutOfRangeError):
          session.run(results)
      finally:
        # Stop the queue-runner threads even when an assertion above fails.
        coord.request_stop()
        coord.join(threads)

  def testPandasInputFn_OnlyX(self):
    with self.test_session() as session:
      x, _ = self.makeTestDataFrame()
      input_fn = pandas_io.pandas_input_fn(
          x, y=None, batch_size=2, shuffle=False, num_epochs=1)

      features = self.callInputFnOnce(input_fn, session)

      self.assertAllEqual(features['a'], [0, 1])
      self.assertAllEqual(features['b'], [32, 33])

  def testPandasInputFn_ExcludesIndex(self):
    with self.test_session() as session:
      x, y = self.makeTestDataFrame()
      input_fn = pandas_io.pandas_input_fn(
          x, y, batch_size=2, shuffle=False, num_epochs=1)

      features, _ = self.callInputFnOnce(input_fn, session)

      self.assertNotIn('index', features)

  def assertInputsCallableNTimes(self, input_fn, session, n):
    """Asserts the outputs of `input_fn` can be run exactly `n` times.

    Runs the outputs `n` times, then checks that one more run raises
    `errors.OutOfRangeError`.

    Args:
      input_fn: Callable returning the tensors to evaluate.
      session: Session used to start the queue runners and run the tensors.
      n: Number of runs expected to succeed.
    """
    inputs = input_fn()
    coord = coordinator.Coordinator()
    threads = queue_runner_impl.start_queue_runners(session, coord=coord)
    try:
      for _ in range(n):
        session.run(inputs)
      with self.assertRaises(errors.OutOfRangeError):
        session.run(inputs)
    finally:
      # Stop the queue-runner threads even when a run or assertion fails.
      coord.request_stop()
      coord.join(threads)

  def testPandasInputFn_RespectsEpoch_NoShuffle(self):
    with self.test_session() as session:
      x, y = self.makeTestDataFrame()
      input_fn = pandas_io.pandas_input_fn(
          x, y, batch_size=4, shuffle=False, num_epochs=1)

      self.assertInputsCallableNTimes(input_fn, session, 1)

  def testPandasInputFn_RespectsEpoch_WithShuffle(self):
    with self.test_session() as session:
      x, y = self.makeTestDataFrame()
      input_fn = pandas_io.pandas_input_fn(
          x, y, batch_size=4, shuffle=True, num_epochs=1)

      self.assertInputsCallableNTimes(input_fn, session, 1)

  def testPandasInputFn_RespectsEpoch_WithShuffleAutosize(self):
    with self.test_session() as session:
      x, y = self.makeTestDataFrame()
      input_fn = pandas_io.pandas_input_fn(
          x, y, batch_size=2, shuffle=True, queue_capacity=None, num_epochs=2)

      self.assertInputsCallableNTimes(input_fn, session, 4)

  def testPandasInputFn_RespectsEpochUnevenBatches(self):
    x, y = self.makeTestDataFrame()
    with self.test_session() as session:
      input_fn = pandas_io.pandas_input_fn(
          x, y, batch_size=3, shuffle=False, num_epochs=1)

      # Before the last batch, only one element of the epoch should remain.
      self.assertInputsCallableNTimes(input_fn, session, 2)

  def testPandasInputFn_Idempotent(self):
    x, y = self.makeTestDataFrame()
    # Building and calling the input_fn repeatedly must not raise, with or
    # without shuffling.
    for shuffle in (False, True):
      for _ in range(2):
        pandas_io.pandas_input_fn(
            x, y, batch_size=2, shuffle=shuffle, num_epochs=1)()


if __name__ == '__main__':
  test.main()