# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """Functions for computing moving statistics."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from tensorflow.python.framework import ops
     21 from tensorflow.python.ops import array_ops
     22 from tensorflow.python.ops import init_ops
     23 from tensorflow.python.ops import math_ops
     24 from tensorflow.python.ops import state_ops
     25 from tensorflow.python.ops import variable_scope
     26 
     27 
     28 __all__ = [
     29     "assign_moving_mean_variance",
     30     "assign_log_moving_mean_exp",
     31     "moving_mean_variance",
     32 ]
     33 
     34 
     35 def assign_moving_mean_variance(
     36     mean_var, variance_var, value, decay, name=None):
     37   """Compute exponentially weighted moving {mean,variance} of a streaming value.
     38 
     39   The `value` updated exponentially weighted moving `mean_var` and
     40   `variance_var` are given by the following recurrence relations:
     41 
     42   ```python
     43   variance_var = decay * (variance_var + (1-decay) * (value - mean_var)**2)
     44   mean_var     = decay * mean_var + (1 - decay) * value
     45   ```
     46 
     47   Note: `mean_var` is updated *after* `variance_var`, i.e., `variance_var` uses
     48   the lag-1 mean.
     49 
     50   For derivation justification, see equation 143 of:
     51     T. Finch, Feb 2009. "Incremental calculation of weighted mean and variance".
     52     http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf
     53 
     54   Args:
     55     mean_var: `float`-like `Variable` representing the exponentially weighted
     56       moving mean. Same shape as `variance_var` and `value`.
     57     variance_var: `float`-like `Variable` representing the
     58       exponentially weighted moving variance. Same shape as `mean_var` and
     59       `value`.
     60     value: `float`-like `Tensor`. Same shape as `mean_var` and `variance_var`.
     61     decay: A `float`-like `Tensor`. The moving mean decay. Typically close to
     62       `1.`, e.g., `0.999`.
     63     name: Optional name of the returned operation.
     64 
     65   Returns:
     66     mean_var: `Variable` representing the `value`-updated exponentially weighted
     67       moving mean.
     68     variance_var: `Variable` representing the `value`-updated
     69       exponentially weighted moving variance.
     70 
     71   Raises:
     72     TypeError: if `mean_var` does not have float type `dtype`.
     73     TypeError: if `mean_var`, `variance_var`, `value`, `decay` have different
     74       `base_dtype`.
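
  #### Example

  A minimal, illustrative sketch. It assumes TF 1.x graph mode, a scalar
  stream, and that this function is in scope (e.g., imported from this
  module); the distribution and `decay` are arbitrary choices:

  ```python
  import numpy as np
  import tensorflow as tf

  mean_var = tf.Variable(0., trainable=False)
  variance_var = tf.Variable(0., trainable=False)
  value = tf.placeholder(tf.float32, shape=[])
  update_op = assign_moving_mean_variance(
      mean_var, variance_var, value, decay=0.99)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for x in np.random.randn(1000):
      moving_mean, moving_variance = sess.run(update_op, feed_dict={value: x})
  # For draws from N(0, 1), moving_mean ~= 0. and moving_variance ~= 1.,
  # up to initialization bias and sampling noise.
  ```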
     75   """
     76   with ops.name_scope(name, "assign_moving_mean_variance",
     77                       [variance_var, mean_var, value, decay]):
     78     with ops.colocate_with(variance_var):
     79       with ops.colocate_with(mean_var):
     80         base_dtype = mean_var.dtype.base_dtype
     81         if not base_dtype.is_floating:
     82           raise TypeError(
     83               "mean_var.base_dtype({}) does not have float type "
     84               "`dtype`.".format(base_dtype.name))
     85         if base_dtype != variance_var.dtype.base_dtype:
     86           raise TypeError(
     87               "mean_var.base_dtype({}) != variance_var.base_dtype({})".format(
     88                   base_dtype.name,
     89                   variance_var.dtype.base_dtype.name))
     90         value = ops.convert_to_tensor(value, dtype=base_dtype, name="value")
     91         decay = ops.convert_to_tensor(decay, dtype=base_dtype, name="decay")
     92         delta = value - mean_var
     93         with ops.control_dependencies([delta]):
     94           mean_var = state_ops.assign_add(
     95               mean_var,
     96               (1. - decay) * delta)
     97           variance_var = state_ops.assign_sub(
     98               variance_var,
     99               (1. - decay) * (variance_var - decay * math_ops.square(delta)))
    100         return mean_var, variance_var
    101 
    102 
    103 def assign_log_moving_mean_exp(
    104     log_mean_exp_var, log_value, decay, name=None):
    105   """Compute the log of the exponentially weighted moving mean of the exp.
    106 
    107   If `log_value` is a draw from a stationary random variable, this function
    108   approximates `log(E[exp(log_value)])`, i.e., a weighted log-sum-exp. More
    109   precisely, a `tf.Variable`, `log_mean_exp_var`, is updated by `log_value`
    110   using the following identity:
    111 
    112   ```none
    113   log_mean_exp_var =
    114   = log(decay exp(log_mean_exp_var) + (1 - decay) exp(log_value))
    115   = log(exp(log_mean_exp_var + log(decay)) + exp(log_value + log1p(-decay)))
    116   = log_mean_exp_var
    117     + log(  exp(log_mean_exp_var   - log_mean_exp_var + log(decay))
    118           + exp(log_value - log_mean_exp_var + log1p(-decay)))
    119   = log_mean_exp_var
    120     + log_sum_exp([log(decay), log_value - log_mean_exp_var + log1p(-decay)]).
    121   ```
    122 
    123   In addition to numerical stability, this formulation is advantageous because
    124   `log_mean_exp_var` can be updated in a lock-free manner, i.e., using
    125   `assign_add`. (Note: the updates are not thread-safe; it's just that the
    126   update to the tf.Variable is presumed efficient due to being lock-free.)
    127 
    128   Args:
    129     log_mean_exp_var: `float`-like `Variable` representing the log of the
    130       exponentially weighted moving mean of the exp. Same shape as `log_value`.
    131     log_value: `float`-like `Tensor` representing a new (streaming) observation.
    132       Same shape as `log_mean_exp_var`.
    133     decay: A `float`-like `Tensor`. The moving mean decay. Typically close to
    134       `1.`, e.g., `0.999`.
    135     name: Optional name of the returned operation.
    136 
    137   Returns:
    138     log_mean_exp_var: A reference to the input 'Variable' tensor with the
    139       `log_value`-updated log of the exponentially weighted moving mean of exp.
    140 
    141   Raises:
    142     TypeError: if `log_mean_exp_var` does not have float type `dtype`.
    143     TypeError: if `log_mean_exp_var`, `log_value`, `decay` have different
    144       `base_dtype`.
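
  #### Example

  A minimal, illustrative sketch. It assumes TF 1.x graph mode, a scalar
  stream, and that this function is in scope (e.g., imported from this
  module); the distribution and `decay` are arbitrary choices:

  ```python
  import numpy as np
  import tensorflow as tf

  log_mean_exp_var = tf.Variable(0., trainable=False)
  log_value = tf.placeholder(tf.float32, shape=[])
  update_op = assign_log_moving_mean_exp(
      log_mean_exp_var, log_value, decay=0.99)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for v in np.random.randn(1000):
      log_mean_exp = sess.run(update_op, feed_dict={log_value: v})
  # For log_value ~ N(0, 1), log(E[exp(log_value)]) = 0.5, so
  # log_mean_exp ~= 0.5 up to initialization bias and sampling noise.
  ```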
    145   """
    146   with ops.name_scope(name, "assign_log_moving_mean_exp",
    147                       [log_mean_exp_var, log_value, decay]):
    148     # We want to update the variable in a numerically stable and lock-free way.
    149     # To do this, observe that variable `x` updated by `v` is:
    150     # x = log(w exp(x) + (1-w) exp(v))
    151     #   = log(exp(x + log(w)) + exp(v + log1p(-w)))
    152     #   = x + log(exp(x - x + log(w)) + exp(v - x + log1p(-w)))
    153     #   = x + lse([log(w), v - x + log1p(-w)])
    154     with ops.colocate_with(log_mean_exp_var):
    155       base_dtype = log_mean_exp_var.dtype.base_dtype
    156       if not base_dtype.is_floating:
    157         raise TypeError(
    158             "log_mean_exp_var.base_dtype({}) does not have float type "
    159             "`dtype`.".format(base_dtype.name))
    160       log_value = ops.convert_to_tensor(log_value, dtype=base_dtype,
    161                                         name="log_value")
    162       decay = ops.convert_to_tensor(decay, dtype=base_dtype, name="decay")
    163       delta = (log_value - log_mean_exp_var)[array_ops.newaxis, ...]
    164       x = array_ops.concat([
    165           math_ops.log(decay) * array_ops.ones_like(delta),
    166           delta + math_ops.log1p(-decay)
    167       ], axis=0)
    168       x = math_ops.reduce_logsumexp(x, axis=0)
    169       return log_mean_exp_var.assign_add(x)
    170 
    171 
    172 def moving_mean_variance(value, decay, collections=None, name=None):
    173   """Compute exponentially weighted moving {mean,variance} of a streaming value.
    174 
    175   The exponentially-weighting moving `mean_var` and `variance_var` are updated
    176   by `value` according to the following recurrence:
    177 
    178   ```python
    179   variance_var = decay * (variance_var + (1-decay) * (value - mean_var)**2)
    180   mean_var     = decay * mean_var + (1 - decay) * value
    181   ```
    182 
    183   Note: `mean_var` is updated *after* `variance_var`, i.e., `variance_var` uses
    184   the lag-`1` mean.
    185 
    186   For derivation justification, see equation 143 of:
    187     T. Finch, Feb 2009. "Incremental calculation of weighted mean and variance".
    188     http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf
    189 
    190   Unlike `assign_moving_mean_variance`, this function handles
    191   variable creation.
    192 
    193   Args:
    194     value: `float`-like `Tensor`. Same shape as `mean_var` and `variance_var`.
    195     decay: A `float`-like `Tensor`. The moving mean decay. Typically close to
    196       `1.`, e.g., `0.999`.
    197     collections: Python list of graph-collections keys to which the internal
    198       variables `mean_var` and `variance_var` are added.
    199       Default value is `[GraphKeys.GLOBAL_VARIABLES]`.
    200     name: Optional name of the returned operation.
    201 
    202   Returns:
    203     mean_var: `Variable` representing the `value`-updated exponentially weighted
    204       moving mean.
    205     variance_var: `Variable` representing the `value`-updated
    206       exponentially weighted moving variance.
    207 
    208   Raises:
    209     TypeError: if `value_var` does not have float type `dtype`.
    210     TypeError: if `value`, `decay` have different `base_dtype`.
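
  #### Example

  A minimal, illustrative sketch. It assumes TF 1.x graph mode and that this
  function is in scope (e.g., imported from this module); the distribution
  and `decay` are arbitrary choices:

  ```python
  import numpy as np
  import tensorflow as tf

  value = tf.placeholder(tf.float32, shape=[])
  mean_var, variance_var = moving_mean_variance(value, decay=0.99)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for x in 2. + 3. * np.random.randn(1000):
      m, v = sess.run([mean_var, variance_var], feed_dict={value: x})
  # For draws from N(2, 3**2), m ~= 2. and v ~= 9., up to initialization
  # bias and sampling noise.
  ```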
    211   """
    212   if collections is None:
    213     collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    214   with variable_scope.variable_scope(
    215       name, "moving_mean_variance", [value, decay]):
    216     value = ops.convert_to_tensor(value, name="value")
    217     base_dtype = value.dtype.base_dtype
    218     if not base_dtype.is_floating:
    219       raise TypeError(
    220           "value.base_dtype({}) does not have float type `dtype`.".format(
    221               base_dtype.name))
    222     decay = ops.convert_to_tensor(decay, dtype=base_dtype, name="decay")
    223     variance_var = variable_scope.get_variable(
    224         "moving_variance",
    225         shape=value.shape,
    226         dtype=value.dtype,
    227         initializer=init_ops.zeros_initializer(),
    228         trainable=False,
    229         collections=collections)
    230     mean_var = variable_scope.get_variable(
    231         "moving_mean",
    232         shape=value.shape,
    233         dtype=value.dtype,
    234         initializer=init_ops.zeros_initializer(),
    235         trainable=False,
    236         collections=collections)
    237     return assign_moving_mean_variance(
    238         mean_var, variance_var, value, decay)
    239