1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Methods to allow generator of dict with numpy arrays.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 from collections import Container 22 from types import FunctionType 23 from types import GeneratorType 24 25 from tensorflow.python.estimator.inputs.queues.feeding_functions import _enqueue_data as enqueue_data 26 27 28 def generator_input_fn(x, 29 target_key=None, 30 batch_size=128, 31 num_epochs=1, 32 shuffle=True, 33 queue_capacity=1000, 34 num_threads=1, 35 pad_value=None): 36 """Returns input function that returns dicts of numpy arrays 37 yielded from a generator. 38 39 It is assumed that every dict of numpy arrays yielded from the dictionary 40 represents a single sample. The generator should consume a single epoch of the 41 data. 42 43 This returns a function outputting `features` and `target` based on the dict 44 of numpy arrays. The dict `features` has the same keys as an element yielded 45 from x. 46 47 Example: 48 ```python 49 def generator(): 50 for index in range(10): 51 yield {'height': np.random.randint(32,36), 52 'age': np.random.randint(18, 80), 53 'label': np.ones(1)} 54 55 with tf.Session() as session: 56 input_fn = generator_io.generator_input_fn( 57 generator, target_key="label", batch_size=2, shuffle=False, 58 num_epochs=1) 59 ``` 60 61 Args: 62 x: Generator Function, returns a `Generator` that will yield the data 63 in `dict` of numpy arrays 64 target_key: String or Container of Strings, the key or Container of keys of 65 the numpy arrays in x dictionaries to use as target. 66 batch_size: Integer, size of batches to return. 67 num_epochs: Integer, number of epochs to iterate over data. If `None` will 68 run forever. 69 shuffle: Boolean, if True shuffles the queue. Avoid shuffle at prediction 70 time. 71 queue_capacity: Integer, size of queue to accumulate. 72 num_threads: Integer, number of threads used for reading and enqueueing. 73 pad_value: default value for dynamic padding of data samples, if provided. 74 75 Returns: 76 Function, that returns a feature `dict` with `Tensors` and an optional 77 label `dict` with `Tensors`, or if target_key is `str` label is a `Tensor` 78 79 Raises: 80 TypeError: `x` is not `FunctionType`. 81 TypeError: `x()` is not `GeneratorType`. 82 TypeError: `next(x())` is not `dict`. 83 TypeError: `target_key` is not `str` or `target_key` is not `Container` 84 of `str`. 85 KeyError: `target_key` not a key or `target_key[index]` not in next(`x()`). 86 KeyError: `key` mismatch between dicts emitted from `x()` 87 """ 88 if not isinstance(x, FunctionType): 89 raise TypeError( 90 'x must be generator function; got {}'.format(type(x).__name__)) 91 generator = x() 92 if not isinstance(generator, GeneratorType): 93 raise TypeError( 94 'x() must be generator; got {}'.format(type(generator).__name__)) 95 data = next(generator) 96 if not isinstance(data, dict): 97 raise TypeError('x() must yield dict; got {}'.format(type(data).__name__)) 98 input_keys = sorted(next(x()).keys()) 99 if target_key is not None: 100 if isinstance(target_key, str): 101 target_key = [target_key] 102 elif isinstance(target_key, Container): 103 for item in target_key: 104 if not isinstance(item, str): 105 raise TypeError('target_key must be str or Container of str; got {}'. 106 format(type(item).__name__)) 107 if item not in input_keys: 108 raise KeyError( 109 'target_key not in yielded dict. Expected {} keys; got {}'.format( 110 input_keys, item)) 111 else: 112 raise TypeError('target_key must be str or Container of str; got {}'. 113 format(type(target_key).__name__)) 114 115 def _generator_input_fn(): 116 """generator input function.""" 117 queue = enqueue_data( 118 x, 119 queue_capacity, 120 shuffle=shuffle, 121 num_threads=num_threads, 122 enqueue_size=batch_size, 123 num_epochs=num_epochs, 124 pad_value=pad_value) 125 126 features = (queue.dequeue_many(batch_size) 127 if num_epochs is None else queue.dequeue_up_to(batch_size)) 128 if not isinstance(features, list): 129 features = [features] 130 features = dict(zip(input_keys, features)) 131 if target_key is not None: 132 if len(target_key) > 1: 133 target = {key: features.pop(key) for key in target_key} 134 else: 135 target = features.pop(target_key[0]) 136 return features, target 137 return features 138 139 return _generator_input_fn 140