# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Resampling methods for batches of tensors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import moving_averages


def _repeat_range(counts, name=None):
  """Repeat integers given by range(len(counts)) each the given number of times.

  Each index `i` in `range(len(counts))` appears `counts[i]` times in the
  output, in ascending index order.

  Example behavior:
  [0, 1, 2, 3] -> [1, 2, 2, 3, 3, 3]

  Args:
    counts: 1D tensor with dtype=int32.
    name: optional name for operation.

  Returns:
    1D tensor with dtype=int32 and dynamic length giving the repeated integers.
  """
  with ops.name_scope(name, 'repeat_range', [counts]) as scope:
    counts = ops.convert_to_tensor(counts, name='counts')

    def cond(unused_output, i):
      return i < size

    def body(output, i):
      # counts[i:i+1] is a 1-element shape vector, so this emits the scalar
      # value `i` repeated counts[i] times (an empty tensor when counts[i]==0).
      value = array_ops.fill(counts[i:i+1], i)
      return (output.write(i, value), i + 1)

    size = array_ops.shape(counts)[0]
    # infer_shape=False because each element written has a different
    # (data-dependent) length.
    init_output_array = tensor_array_ops.TensorArray(
        dtype=dtypes.int32, size=size, infer_shape=False)
    output_array, num_writes = control_flow_ops.while_loop(
        cond, body, loop_vars=[init_output_array, 0])

    # The loop runs exactly `size` iterations, so num_writes == size;
    # the cond guards concat() on an empty TensorArray (size == 0), for
    # which we return an empty int32 vector instead.
    return control_flow_ops.cond(
        num_writes > 0,
        output_array.concat,
        lambda: array_ops.zeros(shape=[0], dtype=dtypes.int32),
        name=scope)


def resample_at_rate(inputs, rates, scope=None, seed=None, back_prop=False):
  """Given `inputs` tensors, stochastically resamples each at a given rate.

  For example, if the inputs are `[[a1, a2], [b1, b2]]` and the rates
  tensor contains `[3, 1]`, then the return value may look like `[[a1,
  a2, a1, a1], [b1, b2, b1, b1]]`. However, many other outputs are
  possible, since this is stochastic -- averaged over many repeated
  calls, each set of inputs should appear in the output `rate` times
  the number of invocations.

  The per-row repeat count is drawn from a Poisson distribution with mean
  equal to that row's entry in `rates`, so the count is unbounded but has
  the requested expectation.

  Args:
    inputs: A list of tensors, each of which has a shape of `[batch_size, ...]`
    rates: A tensor of shape `[batch_size]` containing the resampling rates
           for each input.
    scope: Scope for the op.
    seed: Random seed to use.
    back_prop: Whether to allow back-propagation through this op.

  Returns:
    Selections from the input tensors.
  """
  with ops.name_scope(scope, default_name='resample_at_rate',
                      values=list(inputs) + [rates]):
    rates = ops.convert_to_tensor(rates, name='rates')
    # Sample an integer repeat count per batch row: Poisson(rate) per row.
    sample_counts = math_ops.cast(
        random_ops.random_poisson(rates, (), rates.dtype, seed=seed),
        dtypes.int32)
    # Expand counts into a flat vector of row indices to gather.
    sample_indices = _repeat_range(sample_counts)
    if not back_prop:
      sample_indices = array_ops.stop_gradient(sample_indices)
    return [array_ops.gather(x, sample_indices) for x in inputs]


def weighted_resample(inputs, weights, overall_rate, scope=None,
                      mean_decay=0.999, seed=None):
  """Performs an approximate weighted resampling of `inputs`.

  This method chooses elements from `inputs` where each item's rate of
  selection is proportional to its value in `weights`, and the average
  rate of selection across all inputs (and many invocations!) is
  `overall_rate`.

  Args:
    inputs: A list of tensors whose first dimension is `batch_size`.
    weights: A `[batch_size]`-shaped tensor with each batch member's weight.
    overall_rate: Desired overall rate of resampling.
    scope: Scope to use for the op.
    mean_decay: How quickly to decay the running estimate of the mean weight.
    seed: Random seed.

  Returns:
    A list of tensors exactly like `inputs`, but with an unknown (and
    possibly zero) first dimension.
    A tensor containing the effective resampling rate used for each output.
  """
  # Algorithm: Just compute rates as weights/mean_weight *
  # overall_rate. This way the average weight corresponds to the
  # overall rate, and a weight twice the average has twice the rate,
  # etc.
  with ops.name_scope(scope, 'weighted_resample', inputs) as opscope:
    # First: Maintain a running estimated mean weight, with zero debiasing
    # enabled (by default) to avoid throwing the average off.

    with variable_scope.variable_scope(scope, 'estimate_mean', inputs):
      # Local (non-trainable, non-saved) variable holding the running mean.
      estimated_mean = variable_scope.get_local_variable(
          'estimated_mean',
          initializer=math_ops.cast(0, weights.dtype),
          dtype=weights.dtype)

      batch_mean = math_ops.reduce_mean(weights)
      # `mean` is the debiased moving average after folding in this batch.
      mean = moving_averages.assign_moving_average(
          estimated_mean, batch_mean, mean_decay)

    # Then, normalize the weights into rates using the mean weight and
    # overall target rate:
    rates = weights * overall_rate / mean

    # Prepend `rates` to the inputs so the same random selection yields
    # both the resampled inputs and the effective rate used per output.
    results = resample_at_rate([rates] + inputs, rates,
                               scope=opscope, seed=seed, back_prop=False)

    return (results[1:], results[0])