1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Resampling dataset transformations.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import numpy as np 21 22 from tensorflow.contrib.data.python.ops import batching 23 from tensorflow.contrib.data.python.ops import scan_ops 24 from tensorflow.python.data.ops import dataset_ops 25 from tensorflow.python.framework import dtypes 26 from tensorflow.python.framework import ops 27 from tensorflow.python.ops import array_ops 28 from tensorflow.python.ops import control_flow_ops 29 from tensorflow.python.ops import logging_ops 30 from tensorflow.python.ops import math_ops 31 from tensorflow.python.ops import random_ops 32 33 34 def rejection_resample(class_func, target_dist, initial_dist=None, seed=None): 35 """A transformation that resamples a dataset to achieve a target distribution. 36 37 **NOTE** Resampling is performed via rejection sampling; some fraction 38 of the input values will be dropped. 39 40 Args: 41 class_func: A function mapping an element of the input dataset to a scalar 42 `tf.int32` tensor. Values should be in `[0, num_classes)`. 43 target_dist: A floating point type tensor, shaped `[num_classes]`. 44 initial_dist: (Optional.) A floating point type tensor, shaped 45 `[num_classes]`. If not provided, the true class distribution is 46 estimated live in a streaming fashion. 47 seed: (Optional.) Python integer seed for the resampler. 48 49 Returns: 50 A `Dataset` transformation function, which can be passed to 51 @{tf.data.Dataset.apply}. 52 """ 53 54 def _apply_fn(dataset): 55 """Function from `Dataset` to `Dataset` that applies the transformation.""" 56 dist_estimation_batch_size = 32 57 target_dist_t = ops.convert_to_tensor(target_dist, name="initial_dist") 58 class_values_ds = dataset.map(class_func) 59 if initial_dist is not None: 60 initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist") 61 acceptance_dist = _calculate_acceptance_probs(initial_dist_t, 62 target_dist_t) 63 initial_dist_ds = dataset_ops.Dataset.from_tensors( 64 initial_dist_t).repeat() 65 acceptance_dist_ds = dataset_ops.Dataset.from_tensors( 66 acceptance_dist).repeat() 67 else: 68 num_classes = (target_dist_t.shape[0].value or 69 array_ops.shape(target_dist_t)[0]) 70 smoothing_constant = 10 71 initial_examples_per_class_seen = array_ops.fill( 72 [num_classes], np.int64(smoothing_constant)) 73 74 def update_estimate_and_tile(num_examples_per_class_seen, c): 75 updated_examples_per_class_seen, dist = _estimate_data_distribution( 76 c, num_examples_per_class_seen) 77 tiled_dist = array_ops.tile( 78 array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1]) 79 return updated_examples_per_class_seen, tiled_dist 80 81 initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size) 82 .apply(scan_ops.scan(initial_examples_per_class_seen, 83 update_estimate_and_tile)) 84 .apply(batching.unbatch())) 85 acceptance_dist_ds = initial_dist_ds.map( 86 lambda initial: _calculate_acceptance_probs(initial, target_dist_t)) 87 88 def maybe_warn_on_large_rejection(accept_dist, initial_dist): 89 proportion_rejected = math_ops.reduce_sum( 90 (1 - accept_dist) * initial_dist) 91 return control_flow_ops.cond( 92 math_ops.less(proportion_rejected, .5), 93 lambda: accept_dist, 94 lambda: logging_ops.Print( # pylint: disable=g-long-lambda 95 accept_dist, [proportion_rejected, initial_dist, accept_dist], 96 message="Proportion of examples rejected by sampler is high: ", 97 summarize=100, 98 first_n=10)) 99 100 acceptance_dist_ds = (dataset_ops.Dataset.zip((acceptance_dist_ds, 101 initial_dist_ds)) 102 .map(maybe_warn_on_large_rejection)) 103 104 current_probabilities_ds = dataset_ops.Dataset.zip( 105 (acceptance_dist_ds, class_values_ds)).map(array_ops.gather) 106 filtered_ds = ( 107 dataset_ops.Dataset.zip((class_values_ds, current_probabilities_ds, 108 dataset)) 109 .filter(lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p)) 110 return filtered_ds.map(lambda class_value, _, data: (class_value, data)) 111 112 return _apply_fn 113 114 115 def _calculate_acceptance_probs(initial_probs, target_probs): 116 """Calculate the per-class acceptance rates. 117 118 Args: 119 initial_probs: The class probabilities of the data. 120 target_probs: The desired class proportion in minibatches. 121 Returns: 122 A list of the per-class acceptance probabilities. 123 124 This method is based on solving the following analysis: 125 126 Let F be the probability of a rejection (on any example). 127 Let p_i be the proportion of examples in the data in class i (init_probs) 128 Let a_i is the rate the rejection sampler should *accept* class i 129 Let t_i is the target proportion in the minibatches for class i (target_probs) 130 131 ``` 132 F = sum_i(p_i * (1-a_i)) 133 = 1 - sum_i(p_i * a_i) using sum_i(p_i) = 1 134 ``` 135 136 An example with class `i` will be accepted if `k` rejections occur, then an 137 example with class `i` is seen by the rejector, and it is accepted. This can 138 be written as follows: 139 140 ``` 141 t_i = sum_k=0^inf(F^k * p_i * a_i) 142 = p_i * a_j / (1 - F) using geometric series identity, since 0 <= F < 1 143 = p_i * a_i / sum_j(p_j * a_j) using F from above 144 ``` 145 146 Note that the following constraints hold: 147 ``` 148 0 <= p_i <= 1, sum_i(p_i) = 1 149 0 <= a_i <= 1 150 0 <= t_i <= 1, sum_i(t_i) = 1 151 ``` 152 153 154 A solution for a_i in terms of the other variabes is the following: 155 ```a_i = (t_i / p_i) / max_i[t_i / p_i]``` 156 """ 157 # Add tiny to initial_probs to avoid divide by zero. 158 denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny) 159 ratio_l = target_probs / denom 160 161 # Calculate list of acceptance probabilities. 162 max_ratio = math_ops.reduce_max(ratio_l) 163 return ratio_l / max_ratio 164 165 166 def _estimate_data_distribution(c, num_examples_per_class_seen): 167 """Estimate data distribution as labels are seen. 168 169 Args: 170 c: The class labels. Type `int32`, shape `[batch_size]`. 171 num_examples_per_class_seen: Type `int64`, shape `[num_classes]`, 172 containing counts. 173 174 Returns: 175 num_examples_per_lass_seen: Updated counts. Type `int64`, shape 176 `[num_classes]`. 177 dist: The updated distribution. Type `float32`, shape `[num_classes]`. 178 """ 179 num_classes = num_examples_per_class_seen.get_shape()[0].value 180 # Update the class-count based on what labels are seen in batch. 181 num_examples_per_class_seen = math_ops.add( 182 num_examples_per_class_seen, math_ops.reduce_sum( 183 array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0)) 184 init_prob_estimate = math_ops.truediv( 185 num_examples_per_class_seen, 186 math_ops.reduce_sum(num_examples_per_class_seen)) 187 dist = math_ops.cast(init_prob_estimate, dtypes.float32) 188 return num_examples_per_class_seen, dist 189