Home | History | Annotate | Download | only in inputs
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 """Methods to allow pandas.DataFrame."""
     17 
     18 from __future__ import absolute_import
     19 from __future__ import division
     20 from __future__ import print_function
     21 
     22 import numpy as np
     23 from tensorflow.python.estimator.inputs.queues import feeding_functions
     24 from tensorflow.python.util.tf_export import tf_export
     25 
     26 try:
     27   # pylint: disable=g-import-not-at-top
     28   # pylint: disable=unused-import
     29   import pandas as pd
     30   HAS_PANDAS = True
     31 except IOError:
     32   # Pandas writes a temporary file during import. If it fails, don't use pandas.
     33   HAS_PANDAS = False
     34 except ImportError:
     35   HAS_PANDAS = False
     36 
     37 
     38 @tf_export('estimator.inputs.pandas_input_fn')
     39 def pandas_input_fn(x,
     40                     y=None,
     41                     batch_size=128,
     42                     num_epochs=1,
     43                     shuffle=None,
     44                     queue_capacity=1000,
     45                     num_threads=1,
     46                     target_column='target'):
     47   """Returns input function that would feed Pandas DataFrame into the model.
     48 
     49   Note: `y`'s index must match `x`'s index.
     50 
     51   Args:
     52     x: pandas `DataFrame` object.
     53     y: pandas `Series` object. `None` if absent.
     54     batch_size: int, size of batches to return.
     55     num_epochs: int, number of epochs to iterate over data. If not `None`,
     56       read attempts that would exceed this value will raise `OutOfRangeError`.
     57     shuffle: bool, whether to read the records in random order.
     58     queue_capacity: int, size of the read queue. If `None`, it will be set
     59       roughly to the size of `x`.
     60     num_threads: Integer, number of threads used for reading and enqueueing. In
     61       order to have predicted and repeatable order of reading and enqueueing,
     62       such as in prediction and evaluation mode, `num_threads` should be 1.
     63     target_column: str, name to give the target column `y`.
     64 
     65   Returns:
     66     Function, that has signature of ()->(dict of `features`, `target`)
     67 
     68   Raises:
     69     ValueError: if `x` already contains a column with the same name as `y`, or
     70       if the indexes of `x` and `y` don't match.
     71     TypeError: `shuffle` is not bool.
     72   """
     73   if not HAS_PANDAS:
     74     raise TypeError(
     75         'pandas_input_fn should not be called without pandas installed')
     76 
     77   if not isinstance(shuffle, bool):
     78     raise TypeError('shuffle must be explicitly set as boolean; '
     79                     'got {}'.format(shuffle))
     80 
     81   x = x.copy()
     82   if y is not None:
     83     if target_column in x:
     84       raise ValueError(
     85           'Cannot use name %s for target column: DataFrame already has a '
     86           'column with that name: %s' % (target_column, x.columns))
     87     if not np.array_equal(x.index, y.index):
     88       raise ValueError('Index for x and y are mismatched.\nIndex for x: %s\n'
     89                        'Index for y: %s\n' % (x.index, y.index))
     90     x[target_column] = y
     91 
     92   # TODO(mdan): These are memory copies. We probably don't need 4x slack space.
     93   # The sizes below are consistent with what I've seen elsewhere.
     94   if queue_capacity is None:
     95     if shuffle:
     96       queue_capacity = 4 * len(x)
     97     else:
     98       queue_capacity = len(x)
     99   min_after_dequeue = max(queue_capacity / 4, 1)
    100 
    101   def input_fn():
    102     """Pandas input function."""
    103     queue = feeding_functions._enqueue_data(  # pylint: disable=protected-access
    104         x,
    105         queue_capacity,
    106         shuffle=shuffle,
    107         min_after_dequeue=min_after_dequeue,
    108         num_threads=num_threads,
    109         enqueue_size=batch_size,
    110         num_epochs=num_epochs)
    111     if num_epochs is None:
    112       features = queue.dequeue_many(batch_size)
    113     else:
    114       features = queue.dequeue_up_to(batch_size)
    115     assert len(features) == len(x.columns) + 1, ('Features should have one '
    116                                                  'extra element for the index.')
    117     features = features[1:]
    118     features = dict(zip(list(x.columns), features))
    119     if y is not None:
    120       target = features.pop(target_column)
    121       return features, target
    122     return features
    123   return input_fn
    124