1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 16 """Methods to allow pandas.DataFrame.""" 17 18 from __future__ import absolute_import 19 from __future__ import division 20 from __future__ import print_function 21 22 import numpy as np 23 from tensorflow.python.estimator.inputs.queues import feeding_functions 24 from tensorflow.python.util.tf_export import tf_export 25 26 try: 27 # pylint: disable=g-import-not-at-top 28 # pylint: disable=unused-import 29 import pandas as pd 30 HAS_PANDAS = True 31 except IOError: 32 # Pandas writes a temporary file during import. If it fails, don't use pandas. 33 HAS_PANDAS = False 34 except ImportError: 35 HAS_PANDAS = False 36 37 38 @tf_export('estimator.inputs.pandas_input_fn') 39 def pandas_input_fn(x, 40 y=None, 41 batch_size=128, 42 num_epochs=1, 43 shuffle=None, 44 queue_capacity=1000, 45 num_threads=1, 46 target_column='target'): 47 """Returns input function that would feed Pandas DataFrame into the model. 48 49 Note: `y`'s index must match `x`'s index. 50 51 Args: 52 x: pandas `DataFrame` object. 53 y: pandas `Series` object. `None` if absent. 54 batch_size: int, size of batches to return. 55 num_epochs: int, number of epochs to iterate over data. If not `None`, 56 read attempts that would exceed this value will raise `OutOfRangeError`. 57 shuffle: bool, whether to read the records in random order. 58 queue_capacity: int, size of the read queue. If `None`, it will be set 59 roughly to the size of `x`. 60 num_threads: Integer, number of threads used for reading and enqueueing. In 61 order to have predicted and repeatable order of reading and enqueueing, 62 such as in prediction and evaluation mode, `num_threads` should be 1. 63 target_column: str, name to give the target column `y`. 64 65 Returns: 66 Function, that has signature of ()->(dict of `features`, `target`) 67 68 Raises: 69 ValueError: if `x` already contains a column with the same name as `y`, or 70 if the indexes of `x` and `y` don't match. 71 TypeError: `shuffle` is not bool. 72 """ 73 if not HAS_PANDAS: 74 raise TypeError( 75 'pandas_input_fn should not be called without pandas installed') 76 77 if not isinstance(shuffle, bool): 78 raise TypeError('shuffle must be explicitly set as boolean; ' 79 'got {}'.format(shuffle)) 80 81 x = x.copy() 82 if y is not None: 83 if target_column in x: 84 raise ValueError( 85 'Cannot use name %s for target column: DataFrame already has a ' 86 'column with that name: %s' % (target_column, x.columns)) 87 if not np.array_equal(x.index, y.index): 88 raise ValueError('Index for x and y are mismatched.\nIndex for x: %s\n' 89 'Index for y: %s\n' % (x.index, y.index)) 90 x[target_column] = y 91 92 # TODO(mdan): These are memory copies. We probably don't need 4x slack space. 93 # The sizes below are consistent with what I've seen elsewhere. 94 if queue_capacity is None: 95 if shuffle: 96 queue_capacity = 4 * len(x) 97 else: 98 queue_capacity = len(x) 99 min_after_dequeue = max(queue_capacity / 4, 1) 100 101 def input_fn(): 102 """Pandas input function.""" 103 queue = feeding_functions._enqueue_data( # pylint: disable=protected-access 104 x, 105 queue_capacity, 106 shuffle=shuffle, 107 min_after_dequeue=min_after_dequeue, 108 num_threads=num_threads, 109 enqueue_size=batch_size, 110 num_epochs=num_epochs) 111 if num_epochs is None: 112 features = queue.dequeue_many(batch_size) 113 else: 114 features = queue.dequeue_up_to(batch_size) 115 assert len(features) == len(x.columns) + 1, ('Features should have one ' 116 'extra element for the index.') 117 features = features[1:] 118 features = dict(zip(list(x.columns), features)) 119 if y is not None: 120 target = features.pop(target_column) 121 return features, target 122 return features 123 return input_fn 124