Home | History | Annotate | Download | only in ops
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Resampling dataset transformations."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import numpy as np
     21 
     22 from tensorflow.contrib.data.python.ops import batching
     23 from tensorflow.contrib.data.python.ops import scan_ops
     24 from tensorflow.python.data.ops import dataset_ops
     25 from tensorflow.python.framework import dtypes
     26 from tensorflow.python.framework import ops
     27 from tensorflow.python.ops import array_ops
     28 from tensorflow.python.ops import control_flow_ops
     29 from tensorflow.python.ops import logging_ops
     30 from tensorflow.python.ops import math_ops
     31 from tensorflow.python.ops import random_ops
     32 
     33 
     34 def rejection_resample(class_func, target_dist, initial_dist=None, seed=None):
     35   """A transformation that resamples a dataset to achieve a target distribution.
     36 
     37   **NOTE** Resampling is performed via rejection sampling; some fraction
     38   of the input values will be dropped.
     39 
     40   Args:
     41     class_func: A function mapping an element of the input dataset to a scalar
     42       `tf.int32` tensor. Values should be in `[0, num_classes)`.
     43     target_dist: A floating point type tensor, shaped `[num_classes]`.
     44     initial_dist: (Optional.)  A floating point type tensor, shaped
     45       `[num_classes]`.  If not provided, the true class distribution is
     46       estimated live in a streaming fashion.
     47     seed: (Optional.) Python integer seed for the resampler.
     48 
     49   Returns:
     50     A `Dataset` transformation function, which can be passed to
     51     @{tf.data.Dataset.apply}.
     52   """
     53 
     54   def _apply_fn(dataset):
     55     """Function from `Dataset` to `Dataset` that applies the transformation."""
     56     dist_estimation_batch_size = 32
     57     target_dist_t = ops.convert_to_tensor(target_dist, name="initial_dist")
     58     class_values_ds = dataset.map(class_func)
     59     if initial_dist is not None:
     60       initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist")
     61       acceptance_dist = _calculate_acceptance_probs(initial_dist_t,
     62                                                     target_dist_t)
     63       initial_dist_ds = dataset_ops.Dataset.from_tensors(
     64           initial_dist_t).repeat()
     65       acceptance_dist_ds = dataset_ops.Dataset.from_tensors(
     66           acceptance_dist).repeat()
     67     else:
     68       num_classes = (target_dist_t.shape[0].value or
     69                      array_ops.shape(target_dist_t)[0])
     70       smoothing_constant = 10
     71       initial_examples_per_class_seen = array_ops.fill(
     72           [num_classes], np.int64(smoothing_constant))
     73 
     74       def update_estimate_and_tile(num_examples_per_class_seen, c):
     75         updated_examples_per_class_seen, dist = _estimate_data_distribution(
     76             c, num_examples_per_class_seen)
     77         tiled_dist = array_ops.tile(
     78             array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1])
     79         return updated_examples_per_class_seen, tiled_dist
     80 
     81       initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size)
     82                          .apply(scan_ops.scan(initial_examples_per_class_seen,
     83                                               update_estimate_and_tile))
     84                          .apply(batching.unbatch()))
     85       acceptance_dist_ds = initial_dist_ds.map(
     86           lambda initial: _calculate_acceptance_probs(initial, target_dist_t))
     87 
     88     def maybe_warn_on_large_rejection(accept_dist, initial_dist):
     89       proportion_rejected = math_ops.reduce_sum(
     90           (1 - accept_dist) * initial_dist)
     91       return control_flow_ops.cond(
     92           math_ops.less(proportion_rejected, .5),
     93           lambda: accept_dist,
     94           lambda: logging_ops.Print(  # pylint: disable=g-long-lambda
     95               accept_dist, [proportion_rejected, initial_dist, accept_dist],
     96               message="Proportion of examples rejected by sampler is high: ",
     97               summarize=100,
     98               first_n=10))
     99 
    100     acceptance_dist_ds = (dataset_ops.Dataset.zip((acceptance_dist_ds,
    101                                                    initial_dist_ds))
    102                           .map(maybe_warn_on_large_rejection))
    103 
    104     current_probabilities_ds = dataset_ops.Dataset.zip(
    105         (acceptance_dist_ds, class_values_ds)).map(array_ops.gather)
    106     filtered_ds = (
    107         dataset_ops.Dataset.zip((class_values_ds, current_probabilities_ds,
    108                                  dataset))
    109         .filter(lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p))
    110     return filtered_ds.map(lambda class_value, _, data: (class_value, data))
    111 
    112   return _apply_fn
    113 
    114 
    115 def _calculate_acceptance_probs(initial_probs, target_probs):
    116   """Calculate the per-class acceptance rates.
    117 
    118   Args:
    119     initial_probs: The class probabilities of the data.
    120     target_probs: The desired class proportion in minibatches.
    121   Returns:
    122     A list of the per-class acceptance probabilities.
    123 
    124   This method is based on solving the following analysis:
    125 
    126   Let F be the probability of a rejection (on any example).
    127   Let p_i be the proportion of examples in the data in class i (init_probs)
    128   Let a_i is the rate the rejection sampler should *accept* class i
    129   Let t_i is the target proportion in the minibatches for class i (target_probs)
    130 
    131   ```
    132   F = sum_i(p_i * (1-a_i))
    133     = 1 - sum_i(p_i * a_i)     using sum_i(p_i) = 1
    134   ```
    135 
    136   An example with class `i` will be accepted if `k` rejections occur, then an
    137   example with class `i` is seen by the rejector, and it is accepted. This can
    138   be written as follows:
    139 
    140   ```
    141   t_i = sum_k=0^inf(F^k * p_i * a_i)
    142       = p_i * a_j / (1 - F)    using geometric series identity, since 0 <= F < 1
    143       = p_i * a_i / sum_j(p_j * a_j)        using F from above
    144   ```
    145 
    146   Note that the following constraints hold:
    147   ```
    148   0 <= p_i <= 1, sum_i(p_i) = 1
    149   0 <= a_i <= 1
    150   0 <= t_i <= 1, sum_i(t_i) = 1
    151   ```
    152 
    153 
    154   A solution for a_i in terms of the other variabes is the following:
    155     ```a_i = (t_i / p_i) / max_i[t_i / p_i]```
    156   """
    157   # Add tiny to initial_probs to avoid divide by zero.
    158   denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny)
    159   ratio_l = target_probs / denom
    160 
    161   # Calculate list of acceptance probabilities.
    162   max_ratio = math_ops.reduce_max(ratio_l)
    163   return ratio_l / max_ratio
    164 
    165 
    166 def _estimate_data_distribution(c, num_examples_per_class_seen):
    167   """Estimate data distribution as labels are seen.
    168 
    169   Args:
    170     c: The class labels.  Type `int32`, shape `[batch_size]`.
    171     num_examples_per_class_seen: Type `int64`, shape `[num_classes]`,
    172       containing counts.
    173 
    174   Returns:
    175     num_examples_per_lass_seen: Updated counts.  Type `int64`, shape
    176       `[num_classes]`.
    177     dist: The updated distribution.  Type `float32`, shape `[num_classes]`.
    178   """
    179   num_classes = num_examples_per_class_seen.get_shape()[0].value
    180   # Update the class-count based on what labels are seen in batch.
    181   num_examples_per_class_seen = math_ops.add(
    182       num_examples_per_class_seen, math_ops.reduce_sum(
    183           array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0))
    184   init_prob_estimate = math_ops.truediv(
    185       num_examples_per_class_seen,
    186       math_ops.reduce_sum(num_examples_per_class_seen))
    187   dist = math_ops.cast(init_prob_estimate, dtypes.float32)
    188   return num_examples_per_class_seen, dist
    189