Home | History | Annotate | Download | only in learn_io
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for pandas_io."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.contrib.learn.python.learn.learn_io import pandas_io
     24 from tensorflow.python.framework import errors
     25 from tensorflow.python.platform import test
     26 from tensorflow.python.training import coordinator
     27 from tensorflow.python.training import queue_runner_impl
     28 
# pandas is an optional dependency here: record whether it can be imported in
# HAS_PANDAS so that the tests below can return early instead of failing with
# an ImportError when it is missing.
# pylint: disable=g-import-not-at-top
try:
  import pandas as pd
  HAS_PANDAS = True
except ImportError:
  HAS_PANDAS = False
     35 
     36 
     37 class PandasIoTest(test.TestCase):
     38 
     39   def makeTestDataFrame(self):
     40     index = np.arange(100, 104)
     41     a = np.arange(4)
     42     b = np.arange(32, 36)
     43     x = pd.DataFrame({'a': a, 'b': b}, index=index)
     44     y = pd.Series(np.arange(-32, -28), index=index)
     45     return x, y
     46 
     47   def callInputFnOnce(self, input_fn, session):
     48     results = input_fn()
     49     coord = coordinator.Coordinator()
     50     threads = queue_runner_impl.start_queue_runners(session, coord=coord)
     51     result_values = session.run(results)
     52     coord.request_stop()
     53     coord.join(threads)
     54     return result_values
     55 
     56   def testPandasInputFn_IndexMismatch(self):
     57     if not HAS_PANDAS:
     58       return
     59     x, _ = self.makeTestDataFrame()
     60     y_noindex = pd.Series(np.arange(-32, -28))
     61     with self.assertRaises(ValueError):
     62       pandas_io.pandas_input_fn(
     63           x, y_noindex, batch_size=2, shuffle=False, num_epochs=1)
     64 
     65   def testPandasInputFn_ProducesExpectedOutputs(self):
     66     if not HAS_PANDAS:
     67       return
     68     with self.test_session() as session:
     69       x, y = self.makeTestDataFrame()
     70       input_fn = pandas_io.pandas_input_fn(
     71           x, y, batch_size=2, shuffle=False, num_epochs=1)
     72 
     73       features, target = self.callInputFnOnce(input_fn, session)
     74 
     75       self.assertAllEqual(features['a'], [0, 1])
     76       self.assertAllEqual(features['b'], [32, 33])
     77       self.assertAllEqual(target, [-32, -31])
     78 
     79   def testPandasInputFn_ProducesOutputsForLargeBatchAndMultipleEpochs(self):
     80     if not HAS_PANDAS:
     81       return
     82     with self.test_session() as session:
     83       index = np.arange(100, 102)
     84       a = np.arange(2)
     85       b = np.arange(32, 34)
     86       x = pd.DataFrame({'a': a, 'b': b}, index=index)
     87       y = pd.Series(np.arange(-32, -30), index=index)
     88       input_fn = pandas_io.pandas_input_fn(
     89           x, y, batch_size=128, shuffle=False, num_epochs=2)
     90 
     91       results = input_fn()
     92 
     93       coord = coordinator.Coordinator()
     94       threads = queue_runner_impl.start_queue_runners(session, coord=coord)
     95 
     96       features, target = session.run(results)
     97       self.assertAllEqual(features['a'], [0, 1, 0, 1])
     98       self.assertAllEqual(features['b'], [32, 33, 32, 33])
     99       self.assertAllEqual(target, [-32, -31, -32, -31])
    100 
    101       with self.assertRaises(errors.OutOfRangeError):
    102         session.run(results)
    103 
    104       coord.request_stop()
    105       coord.join(threads)
    106 
    107   def testPandasInputFn_ProducesOutputsWhenDataSizeNotDividedByBatchSize(self):
    108     if not HAS_PANDAS:
    109       return
    110     with self.test_session() as session:
    111       index = np.arange(100, 105)
    112       a = np.arange(5)
    113       b = np.arange(32, 37)
    114       x = pd.DataFrame({'a': a, 'b': b}, index=index)
    115       y = pd.Series(np.arange(-32, -27), index=index)
    116 
    117       input_fn = pandas_io.pandas_input_fn(
    118           x, y, batch_size=2, shuffle=False, num_epochs=1)
    119 
    120       results = input_fn()
    121 
    122       coord = coordinator.Coordinator()
    123       threads = queue_runner_impl.start_queue_runners(session, coord=coord)
    124 
    125       features, target = session.run(results)
    126       self.assertAllEqual(features['a'], [0, 1])
    127       self.assertAllEqual(features['b'], [32, 33])
    128       self.assertAllEqual(target, [-32, -31])
    129 
    130       features, target = session.run(results)
    131       self.assertAllEqual(features['a'], [2, 3])
    132       self.assertAllEqual(features['b'], [34, 35])
    133       self.assertAllEqual(target, [-30, -29])
    134 
    135       features, target = session.run(results)
    136       self.assertAllEqual(features['a'], [4])
    137       self.assertAllEqual(features['b'], [36])
    138       self.assertAllEqual(target, [-28])
    139 
    140       with self.assertRaises(errors.OutOfRangeError):
    141         session.run(results)
    142 
    143       coord.request_stop()
    144       coord.join(threads)
    145 
    146   def testPandasInputFn_OnlyX(self):
    147     if not HAS_PANDAS:
    148       return
    149     with self.test_session() as session:
    150       x, _ = self.makeTestDataFrame()
    151       input_fn = pandas_io.pandas_input_fn(
    152           x, y=None, batch_size=2, shuffle=False, num_epochs=1)
    153 
    154       features = self.callInputFnOnce(input_fn, session)
    155 
    156       self.assertAllEqual(features['a'], [0, 1])
    157       self.assertAllEqual(features['b'], [32, 33])
    158 
    159   def testPandasInputFn_ExcludesIndex(self):
    160     if not HAS_PANDAS:
    161       return
    162     with self.test_session() as session:
    163       x, y = self.makeTestDataFrame()
    164       input_fn = pandas_io.pandas_input_fn(
    165           x, y, batch_size=2, shuffle=False, num_epochs=1)
    166 
    167       features, _ = self.callInputFnOnce(input_fn, session)
    168 
    169       self.assertFalse('index' in features)
    170 
    171   def assertInputsCallableNTimes(self, input_fn, session, n):
    172     inputs = input_fn()
    173     coord = coordinator.Coordinator()
    174     threads = queue_runner_impl.start_queue_runners(session, coord=coord)
    175     for _ in range(n):
    176       session.run(inputs)
    177     with self.assertRaises(errors.OutOfRangeError):
    178       session.run(inputs)
    179     coord.request_stop()
    180     coord.join(threads)
    181 
    182   def testPandasInputFn_RespectsEpoch_NoShuffle(self):
    183     if not HAS_PANDAS:
    184       return
    185     with self.test_session() as session:
    186       x, y = self.makeTestDataFrame()
    187       input_fn = pandas_io.pandas_input_fn(
    188           x, y, batch_size=4, shuffle=False, num_epochs=1)
    189 
    190       self.assertInputsCallableNTimes(input_fn, session, 1)
    191 
    192   def testPandasInputFn_RespectsEpoch_WithShuffle(self):
    193     if not HAS_PANDAS:
    194       return
    195     with self.test_session() as session:
    196       x, y = self.makeTestDataFrame()
    197       input_fn = pandas_io.pandas_input_fn(
    198           x, y, batch_size=4, shuffle=True, num_epochs=1)
    199 
    200       self.assertInputsCallableNTimes(input_fn, session, 1)
    201 
    202   def testPandasInputFn_RespectsEpoch_WithShuffleAutosize(self):
    203     if not HAS_PANDAS:
    204       return
    205     with self.test_session() as session:
    206       x, y = self.makeTestDataFrame()
    207       input_fn = pandas_io.pandas_input_fn(
    208           x, y, batch_size=2, shuffle=True, queue_capacity=None, num_epochs=2)
    209 
    210       self.assertInputsCallableNTimes(input_fn, session, 4)
    211 
    212   def testPandasInputFn_RespectsEpochUnevenBatches(self):
    213     if not HAS_PANDAS:
    214       return
    215     x, y = self.makeTestDataFrame()
    216     with self.test_session() as session:
    217       input_fn = pandas_io.pandas_input_fn(
    218           x, y, batch_size=3, shuffle=False, num_epochs=1)
    219 
    220       # Before the last batch, only one element of the epoch should remain.
    221       self.assertInputsCallableNTimes(input_fn, session, 2)
    222 
    223   def testPandasInputFn_Idempotent(self):
    224     if not HAS_PANDAS:
    225       return
    226     x, y = self.makeTestDataFrame()
    227     for _ in range(2):
    228       pandas_io.pandas_input_fn(
    229           x, y, batch_size=2, shuffle=False, num_epochs=1)()
    230     for _ in range(2):
    231       pandas_io.pandas_input_fn(
    232           x, y, batch_size=2, shuffle=True, num_epochs=1)()
    233 
    234 
# When this file is executed directly (not imported), hand control to the
# `test` module's main() entry point.
if __name__ == '__main__':
  test.main()
    237