Home | History | Annotate | Download | only in inputs
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for pandas_io."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.python.estimator.inputs import pandas_io
     24 from tensorflow.python.framework import errors
     25 from tensorflow.python.platform import test
     26 from tensorflow.python.training import coordinator
     27 from tensorflow.python.training import queue_runner_impl
     28 
     29 try:
     30   # pylint: disable=g-import-not-at-top
     31   import pandas as pd
     32   HAS_PANDAS = True
     33 except IOError:
     34   # Pandas writes a temporary file during import. If it fails, don't use pandas.
     35   HAS_PANDAS = False
     36 except ImportError:
     37   HAS_PANDAS = False
     38 
     39 
     40 class PandasIoTest(test.TestCase):
     41 
     42   def makeTestDataFrame(self):
     43     index = np.arange(100, 104)
     44     a = np.arange(4)
     45     b = np.arange(32, 36)
     46     x = pd.DataFrame({'a': a, 'b': b}, index=index)
     47     y = pd.Series(np.arange(-32, -28), index=index)
     48     return x, y
     49 
     50   def callInputFnOnce(self, input_fn, session):
     51     results = input_fn()
     52     coord = coordinator.Coordinator()
     53     threads = queue_runner_impl.start_queue_runners(session, coord=coord)
     54     result_values = session.run(results)
     55     coord.request_stop()
     56     coord.join(threads)
     57     return result_values
     58 
     59   def testPandasInputFn_IndexMismatch(self):
     60     if not HAS_PANDAS:
     61       return
     62     x, _ = self.makeTestDataFrame()
     63     y_noindex = pd.Series(np.arange(-32, -28))
     64     with self.assertRaises(ValueError):
     65       pandas_io.pandas_input_fn(
     66           x, y_noindex, batch_size=2, shuffle=False, num_epochs=1)
     67 
     68   def testPandasInputFn_NonBoolShuffle(self):
     69     if not HAS_PANDAS:
     70       return
     71     x, _ = self.makeTestDataFrame()
     72     y_noindex = pd.Series(np.arange(-32, -28))
     73     with self.assertRaisesRegexp(TypeError,
     74                                  'shuffle must be explicitly set as boolean'):
     75       # Default shuffle is None
     76       pandas_io.pandas_input_fn(x, y_noindex)
     77 
     78   def testPandasInputFn_ProducesExpectedOutputs(self):
     79     if not HAS_PANDAS:
     80       return
     81     with self.test_session() as session:
     82       x, y = self.makeTestDataFrame()
     83       input_fn = pandas_io.pandas_input_fn(
     84           x, y, batch_size=2, shuffle=False, num_epochs=1)
     85 
     86       features, target = self.callInputFnOnce(input_fn, session)
     87 
     88       self.assertAllEqual(features['a'], [0, 1])
     89       self.assertAllEqual(features['b'], [32, 33])
     90       self.assertAllEqual(target, [-32, -31])
     91 
     92   def testPandasInputFn_ProducesOutputsForLargeBatchAndMultipleEpochs(self):
     93     if not HAS_PANDAS:
     94       return
     95     with self.test_session() as session:
     96       index = np.arange(100, 102)
     97       a = np.arange(2)
     98       b = np.arange(32, 34)
     99       x = pd.DataFrame({'a': a, 'b': b}, index=index)
    100       y = pd.Series(np.arange(-32, -30), index=index)
    101       input_fn = pandas_io.pandas_input_fn(
    102           x, y, batch_size=128, shuffle=False, num_epochs=2)
    103 
    104       results = input_fn()
    105 
    106       coord = coordinator.Coordinator()
    107       threads = queue_runner_impl.start_queue_runners(session, coord=coord)
    108 
    109       features, target = session.run(results)
    110       self.assertAllEqual(features['a'], [0, 1, 0, 1])
    111       self.assertAllEqual(features['b'], [32, 33, 32, 33])
    112       self.assertAllEqual(target, [-32, -31, -32, -31])
    113 
    114       with self.assertRaises(errors.OutOfRangeError):
    115         session.run(results)
    116 
    117       coord.request_stop()
    118       coord.join(threads)
    119 
    120   def testPandasInputFn_ProducesOutputsWhenDataSizeNotDividedByBatchSize(self):
    121     if not HAS_PANDAS:
    122       return
    123     with self.test_session() as session:
    124       index = np.arange(100, 105)
    125       a = np.arange(5)
    126       b = np.arange(32, 37)
    127       x = pd.DataFrame({'a': a, 'b': b}, index=index)
    128       y = pd.Series(np.arange(-32, -27), index=index)
    129 
    130       input_fn = pandas_io.pandas_input_fn(
    131           x, y, batch_size=2, shuffle=False, num_epochs=1)
    132 
    133       results = input_fn()
    134 
    135       coord = coordinator.Coordinator()
    136       threads = queue_runner_impl.start_queue_runners(session, coord=coord)
    137 
    138       features, target = session.run(results)
    139       self.assertAllEqual(features['a'], [0, 1])
    140       self.assertAllEqual(features['b'], [32, 33])
    141       self.assertAllEqual(target, [-32, -31])
    142 
    143       features, target = session.run(results)
    144       self.assertAllEqual(features['a'], [2, 3])
    145       self.assertAllEqual(features['b'], [34, 35])
    146       self.assertAllEqual(target, [-30, -29])
    147 
    148       features, target = session.run(results)
    149       self.assertAllEqual(features['a'], [4])
    150       self.assertAllEqual(features['b'], [36])
    151       self.assertAllEqual(target, [-28])
    152 
    153       with self.assertRaises(errors.OutOfRangeError):
    154         session.run(results)
    155 
    156       coord.request_stop()
    157       coord.join(threads)
    158 
    159   def testPandasInputFn_OnlyX(self):
    160     if not HAS_PANDAS:
    161       return
    162     with self.test_session() as session:
    163       x, _ = self.makeTestDataFrame()
    164       input_fn = pandas_io.pandas_input_fn(
    165           x, y=None, batch_size=2, shuffle=False, num_epochs=1)
    166 
    167       features = self.callInputFnOnce(input_fn, session)
    168 
    169       self.assertAllEqual(features['a'], [0, 1])
    170       self.assertAllEqual(features['b'], [32, 33])
    171 
    172   def testPandasInputFn_ExcludesIndex(self):
    173     if not HAS_PANDAS:
    174       return
    175     with self.test_session() as session:
    176       x, y = self.makeTestDataFrame()
    177       input_fn = pandas_io.pandas_input_fn(
    178           x, y, batch_size=2, shuffle=False, num_epochs=1)
    179 
    180       features, _ = self.callInputFnOnce(input_fn, session)
    181 
    182       self.assertFalse('index' in features)
    183 
    184   def assertInputsCallableNTimes(self, input_fn, session, n):
    185     inputs = input_fn()
    186     coord = coordinator.Coordinator()
    187     threads = queue_runner_impl.start_queue_runners(session, coord=coord)
    188     for _ in range(n):
    189       session.run(inputs)
    190     with self.assertRaises(errors.OutOfRangeError):
    191       session.run(inputs)
    192     coord.request_stop()
    193     coord.join(threads)
    194 
    195   def testPandasInputFn_RespectsEpoch_NoShuffle(self):
    196     if not HAS_PANDAS:
    197       return
    198     with self.test_session() as session:
    199       x, y = self.makeTestDataFrame()
    200       input_fn = pandas_io.pandas_input_fn(
    201           x, y, batch_size=4, shuffle=False, num_epochs=1)
    202 
    203       self.assertInputsCallableNTimes(input_fn, session, 1)
    204 
    205   def testPandasInputFn_RespectsEpoch_WithShuffle(self):
    206     if not HAS_PANDAS:
    207       return
    208     with self.test_session() as session:
    209       x, y = self.makeTestDataFrame()
    210       input_fn = pandas_io.pandas_input_fn(
    211           x, y, batch_size=4, shuffle=True, num_epochs=1)
    212 
    213       self.assertInputsCallableNTimes(input_fn, session, 1)
    214 
    215   def testPandasInputFn_RespectsEpoch_WithShuffleAutosize(self):
    216     if not HAS_PANDAS:
    217       return
    218     with self.test_session() as session:
    219       x, y = self.makeTestDataFrame()
    220       input_fn = pandas_io.pandas_input_fn(
    221           x, y, batch_size=2, shuffle=True, queue_capacity=None, num_epochs=2)
    222 
    223       self.assertInputsCallableNTimes(input_fn, session, 4)
    224 
    225   def testPandasInputFn_RespectsEpochUnevenBatches(self):
    226     if not HAS_PANDAS:
    227       return
    228     x, y = self.makeTestDataFrame()
    229     with self.test_session() as session:
    230       input_fn = pandas_io.pandas_input_fn(
    231           x, y, batch_size=3, shuffle=False, num_epochs=1)
    232 
    233       # Before the last batch, only one element of the epoch should remain.
    234       self.assertInputsCallableNTimes(input_fn, session, 2)
    235 
    236   def testPandasInputFn_Idempotent(self):
    237     if not HAS_PANDAS:
    238       return
    239     x, y = self.makeTestDataFrame()
    240     for _ in range(2):
    241       pandas_io.pandas_input_fn(
    242           x, y, batch_size=2, shuffle=False, num_epochs=1)()
    243     for _ in range(2):
    244       pandas_io.pandas_input_fn(
    245           x, y, batch_size=2, shuffle=True, num_epochs=1)()
    246 
    247 
    248 if __name__ == '__main__':
    249   test.main()
    250