Home | History | Annotate | Download | only in training
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 """Resampling methods for batches of tensors."""
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.python.framework import dtypes
     22 from tensorflow.python.framework import ops
     23 from tensorflow.python.ops import array_ops
     24 from tensorflow.python.ops import control_flow_ops
     25 from tensorflow.python.ops import math_ops
     26 from tensorflow.python.ops import random_ops
     27 from tensorflow.python.ops import tensor_array_ops
     28 from tensorflow.python.ops import variable_scope
     29 from tensorflow.python.training import moving_averages
     30 
     31 
def _repeat_range(counts, name=None):
  """Repeat integers given by range(len(counts)) each the given number of times.

  Example behavior:
  [0, 1, 2, 3] -> [1, 2, 2, 3, 3, 3]

  Args:
    counts: 1D tensor with dtype=int32.
    name: optional name for operation.

  Returns:
    1D tensor with dtype=int32 and dynamic length giving the repeated integers.
  """
  with ops.name_scope(name, 'repeat_range', [counts]) as scope:
    counts = ops.convert_to_tensor(counts, name='counts')

    def cond(unused_output, i):
      # Iterate once per entry of `counts`.  `size` is defined below the
      # loop functions; Python closures capture it by reference, so this is
      # valid by the time the while_loop is constructed.
      return i < size

    def body(output, i):
      # Emit `counts[i]` copies of the integer `i`.  The slice counts[i:i+1]
      # (rather than counts[i]) keeps the dims argument of fill a 1-D tensor.
      value = array_ops.fill(counts[i:i+1], i)
      return (output.write(i, value), i + 1)

    size = array_ops.shape(counts)[0]
    # infer_shape=False because each element written to the TensorArray has a
    # different, dynamically-determined length.
    init_output_array = tensor_array_ops.TensorArray(
        dtype=dtypes.int32, size=size, infer_shape=False)
    output_array, num_writes = control_flow_ops.while_loop(
        cond, body, loop_vars=[init_output_array, 0])

    # TensorArray.concat cannot be used on an array with zero writes, so fall
    # back to an explicit empty int32 vector when `counts` was empty.
    return control_flow_ops.cond(
        num_writes > 0,
        output_array.concat,
        lambda: array_ops.zeros(shape=[0], dtype=dtypes.int32),
        name=scope)
     66 
     67 
     68 def resample_at_rate(inputs, rates, scope=None, seed=None, back_prop=False):
     69   """Given `inputs` tensors, stochastically resamples each at a given rate.
     70 
     71   For example, if the inputs are `[[a1, a2], [b1, b2]]` and the rates
     72   tensor contains `[3, 1]`, then the return value may look like `[[a1,
     73   a2, a1, a1], [b1, b2, b1, b1]]`. However, many other outputs are
     74   possible, since this is stochastic -- averaged over many repeated
     75   calls, each set of inputs should appear in the output `rate` times
     76   the number of invocations.
     77 
     78   Args:
     79     inputs: A list of tensors, each of which has a shape of `[batch_size, ...]`
     80     rates: A tensor of shape `[batch_size]` contiaining the resampling rates
     81        for each input.
     82     scope: Scope for the op.
     83     seed: Random seed to use.
     84     back_prop: Whether to allow back-propagation through this op.
     85 
     86   Returns:
     87     Selections from the input tensors.
     88   """
     89   with ops.name_scope(scope, default_name='resample_at_rate',
     90                       values=list(inputs) + [rates]):
     91     rates = ops.convert_to_tensor(rates, name='rates')
     92     sample_counts = math_ops.cast(
     93         random_ops.random_poisson(rates, (), rates.dtype, seed=seed),
     94         dtypes.int32)
     95     sample_indices = _repeat_range(sample_counts)
     96     if not back_prop:
     97       sample_indices = array_ops.stop_gradient(sample_indices)
     98     return [array_ops.gather(x, sample_indices) for x in inputs]
     99 
    100 
def weighted_resample(inputs, weights, overall_rate, scope=None,
                      mean_decay=0.999, seed=None):
  """Performs an approximate weighted resampling of `inputs`.

  This method chooses elements from `inputs` where each item's rate of
  selection is proportional to its value in `weights`, and the average
  rate of selection across all inputs (and many invocations!) is
  `overall_rate`.

  Args:
    inputs: A list of tensors whose first dimension is `batch_size`.
    weights: A `[batch_size]`-shaped tensor with each batch member's weight.
    overall_rate: Desired overall rate of resampling.
    scope: Scope to use for the op.
    mean_decay: How quickly to decay the running estimate of the mean weight.
    seed: Random seed.

  Returns:
    A list of tensors exactly like `inputs`, but with an unknown (and
      possibly zero) first dimension.
    A tensor containing the effective resampling rate used for each output.
  """
  # Algorithm: Just compute rates as weights/mean_weight *
  # overall_rate. This way the average weight corresponds to the
  # overall rate, and a weight twice the average has twice the rate,
  # etc.
  with ops.name_scope(scope, 'weighted_resample', inputs) as opscope:
    # First: Maintain a running estimated mean weight, with zero debiasing
    # enabled (by default) to avoid throwing the average off.

    with variable_scope.variable_scope(scope, 'estimate_mean', inputs):
      # Local (non-trainable, non-saved) variable holding the running mean;
      # initialized to zero in the weights' dtype.
      estimated_mean = variable_scope.get_local_variable(
          'estimated_mean',
          initializer=math_ops.cast(0, weights.dtype),
          dtype=weights.dtype)

      batch_mean = math_ops.reduce_mean(weights)
      # Side effect: updates `estimated_mean` in place; `mean` is the
      # (zero-debiased) moving-average value after the update.
      mean = moving_averages.assign_moving_average(
          estimated_mean, batch_mean, mean_decay)

    # Then, normalize the weights into rates using the mean weight and
    # overall target rate:
    rates = weights * overall_rate / mean

    # Resample the rates tensor alongside the inputs so each surviving row
    # is paired with the rate that was actually used to select it.
    results = resample_at_rate([rates] + inputs, rates,
                               scope=opscope, seed=seed, back_prop=False)

    # results[0] is the resampled rates; results[1:] are the resampled inputs.
    return (results[1:], results[0])
    149