Home | History | Annotate | Download | only in learn_io
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Methods to allow generator of dict with numpy arrays."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from collections import Container
     22 from types import FunctionType
     23 from types import GeneratorType
     24 
     25 from tensorflow.python.estimator.inputs.queues.feeding_functions import _enqueue_data as enqueue_data
     26 
     27 
     28 def generator_input_fn(x,
     29                        target_key=None,
     30                        batch_size=128,
     31                        num_epochs=1,
     32                        shuffle=True,
     33                        queue_capacity=1000,
     34                        num_threads=1,
     35                        pad_value=None):
     36   """Returns input function that returns dicts of numpy arrays
     37      yielded from a generator.
     38 
     39   It is assumed that every dict of numpy arrays yielded from the dictionary
     40   represents a single sample. The generator should consume a single epoch of the
     41   data.
     42 
     43   This returns a function outputting `features` and `target` based on the dict
     44   of numpy arrays. The dict `features` has the same keys as an element yielded
     45   from x.
     46 
     47   Example:
     48     ```python
     49     def generator():
     50       for index in range(10):
     51         yield {'height': np.random.randint(32,36),
     52               'age': np.random.randint(18, 80),
     53               'label': np.ones(1)}
     54 
     55     with tf.Session() as session:
     56       input_fn = generator_io.generator_input_fn(
     57           generator, target_key="label", batch_size=2, shuffle=False,
     58           num_epochs=1)
     59     ```
     60 
     61   Args:
     62     x: Generator Function, returns a `Generator` that will yield the data
     63       in `dict` of numpy arrays
     64     target_key: String or Container of Strings, the key or Container of keys of
     65       the numpy arrays in x dictionaries to use as target.
     66     batch_size: Integer, size of batches to return.
     67     num_epochs: Integer, number of epochs to iterate over data. If `None` will
     68       run forever.
     69     shuffle: Boolean, if True shuffles the queue. Avoid shuffle at prediction
     70       time.
     71     queue_capacity: Integer, size of queue to accumulate.
     72     num_threads: Integer, number of threads used for reading and enqueueing.
     73     pad_value: default value for dynamic padding of data samples, if provided.
     74 
     75   Returns:
     76     Function, that returns a feature `dict` with `Tensors` and an optional
     77      label `dict` with `Tensors`, or if target_key is `str` label is a `Tensor`
     78 
     79   Raises:
     80     TypeError: `x` is not `FunctionType`.
     81     TypeError: `x()` is not `GeneratorType`.
     82     TypeError: `next(x())` is not `dict`.
     83     TypeError: `target_key` is not `str` or `target_key` is not `Container`
     84        of `str`.
     85     KeyError:  `target_key` not a key or `target_key[index]` not in next(`x()`).
     86     KeyError: `key` mismatch between dicts emitted from `x()`
     87   """
     88   if not isinstance(x, FunctionType):
     89     raise TypeError(
     90         'x must be generator function; got {}'.format(type(x).__name__))
     91   generator = x()
     92   if not isinstance(generator, GeneratorType):
     93     raise TypeError(
     94         'x() must be generator; got {}'.format(type(generator).__name__))
     95   data = next(generator)
     96   if not isinstance(data, dict):
     97     raise TypeError('x() must yield dict; got {}'.format(type(data).__name__))
     98   input_keys = sorted(next(x()).keys())
     99   if target_key is not None:
    100     if isinstance(target_key, str):
    101       target_key = [target_key]
    102     elif isinstance(target_key, Container):
    103       for item in target_key:
    104         if not isinstance(item, str):
    105           raise TypeError('target_key must be str or Container of str; got {}'.
    106                           format(type(item).__name__))
    107         if item not in input_keys:
    108           raise KeyError(
    109               'target_key not in yielded dict. Expected {} keys; got {}'.format(
    110                   input_keys, item))
    111     else:
    112       raise TypeError('target_key must be str or Container of str; got {}'.
    113                       format(type(target_key).__name__))
    114 
    115   def _generator_input_fn():
    116     """generator input function."""
    117     queue = enqueue_data(
    118         x,
    119         queue_capacity,
    120         shuffle=shuffle,
    121         num_threads=num_threads,
    122         enqueue_size=batch_size,
    123         num_epochs=num_epochs,
    124         pad_value=pad_value)
    125 
    126     features = (queue.dequeue_many(batch_size)
    127                 if num_epochs is None else queue.dequeue_up_to(batch_size))
    128     if not isinstance(features, list):
    129       features = [features]
    130     features = dict(zip(input_keys, features))
    131     if target_key is not None:
    132       if len(target_key) > 1:
    133         target = {key: features.pop(key) for key in target_key}
    134       else:
    135         target = features.pop(target_key[0])
    136       return features, target
    137     return features
    138 
    139   return _generator_input_fn
    140