# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """A helper class for inferring Distribution shape."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import contextlib
     21 
     22 from tensorflow.python.framework import dtypes
     23 from tensorflow.python.framework import ops
     24 from tensorflow.python.framework import tensor_util
     25 from tensorflow.python.ops import array_ops
     26 from tensorflow.python.ops import check_ops
     27 from tensorflow.python.ops import control_flow_ops
     28 from tensorflow.python.ops import math_ops
     29 from tensorflow.python.ops.distributions import util as distribution_util
     30 from tensorflow.python.util import deprecation
     31 
     32 
class _DistributionShape(object):
  """Manage and manipulate `Distribution` shape.

  #### Terminology

  Recall that a `Tensor` has:
    - `shape`: size of `Tensor` dimensions,
    - `ndims`: size of `shape`; number of `Tensor` dimensions,
    - `dims`: indexes into `shape`; useful for transpose, reduce.

  `Tensor`s sampled from a `Distribution` can be partitioned by `sample_dims`,
  `batch_dims`, and `event_dims`. To understand the semantics of these
  dimensions, consider when two of the three are fixed and the remaining
  is varied:
    - `sample_dims`: indexes independent draws from identical
                     parameterizations of the `Distribution`.
    - `batch_dims`:  indexes independent draws from non-identical
                     parameterizations of the `Distribution`.
    - `event_dims`:  indexes event coordinates from one sample.

  The `sample`, `batch`, and `event` dimensions constitute the entirety of a
  `Distribution` `Tensor`'s shape.

  The dimensions are always in `sample`, `batch`, `event` order.

  #### Purpose

  This class partitions `Tensor` notions of `shape`, `ndims`, and `dims` into
  `Distribution` notions of `sample`, `batch`, and `event` dimensions. That
  is, it computes any of:

  ```
  sample_shape     batch_shape     event_shape
  sample_dims      batch_dims      event_dims
  sample_ndims     batch_ndims     event_ndims
  ```

  for a given `Tensor`, e.g., the result of
  `Distribution.sample(sample_shape=...)`.

  For a given `Tensor`, this class computes the above table using minimal
  information: `batch_ndims` and `event_ndims`.

  #### Examples

  We show examples of distribution shape semantics.

    - Sample dimensions:
      Computing summary statistics, e.g., the average, is a reduction over
      sample dimensions.

      ```python
      sample_dims = [0]
      tf.reduce_mean(Normal(loc=1.3, scale=1.).sample_n(1000),
                     axis=sample_dims)  # ~= 1.3
      ```

    - Batch dimensions:
      Monte Carlo estimation of a marginal probability:
      Average over batch dimensions where batch dimensions are associated with
      random draws from a prior.
      E.g., suppose we want to find the Monte Carlo estimate of the marginal
      distribution of a `Normal` with a random `Laplace` location:

      ```
      P(X=x) = integral P(X=x|y) P(Y=y) dy
            ~= 1/n sum_{i=1}^n P(X=x|y_i),   y_i ~iid Laplace(0,1)
             = tf.reduce_mean(Normal(loc=Laplace(0., 1.).sample_n(n=1000),
                                     scale=tf.ones(1000)).prob(x),
                              axis=batch_dims)
      ```

      The `Laplace` distribution generates a `Tensor` of shape `[1000]`. When
      fed to a `Normal`, this is interpreted as 1000 different locations, i.e.,
      1000 non-identical Normals. Therefore a single call to `prob(x)` yields
      1000 probabilities, one for every location. The average over this batch
      yields the marginal.

    - Event dimensions:
      Computing the determinant of the Jacobian of a function of a random
      variable involves a reduction over event dimensions.
      E.g., the Jacobian of the transform `Y = g(X) = exp(X)`:

      ```python
      tf.div(1., tf.reduce_prod(x, event_dims))
      ```

  We show examples using this class.

  Write `S, B, E` for `sample_shape`, `batch_shape`, and `event_shape`.

  ```python
  # 150 iid samples from one two-dimensional multivariate Normal.
  mu = [0., 0]
  sigma = [[1., 0],
           [0,  1]]
  mvn = MultivariateNormal(mu, sigma)
  rand_mvn = mvn.sample(sample_shape=[3, 50])
  shaper = DistributionShape(batch_ndims=0, event_ndims=1)
  S, B, E = shaper.get_shape(rand_mvn)
  # S = [3, 50]
  # B = []
  # E = [2]

  # 12 iid samples from one Wishart with 2x2 events.
  sigma = [[1., 0],
           [2,  1]]
  wishart = Wishart(df=5, scale=sigma)
  rand_wishart = wishart.sample(sample_shape=[3, 4])
  shaper = DistributionShape(batch_ndims=0, event_ndims=2)
  S, B, E = shaper.get_shape(rand_wishart)
  # S = [3, 4]
  # B = []
  # E = [2, 2]

  # 100 iid samples from two non-identical trivariate Normal distributions.
  mu    = ...  # shape(2, 3)
  sigma = ...  # shape(2, 3, 3)
  X = MultivariateNormal(mu, sigma).sample(sample_shape=[4, 25])
  # S = [4, 25]
  # B = [2]
  # E = [3]
  ```

  #### Argument Validation

  When `validate_args=True`, checks that cannot be done during
  graph construction are performed at graph execution. This may result in a
  performance degradation because data must be switched from GPU to CPU.

  For example, when `validate_args=True` and `event_ndims` is a
  non-constant `Tensor`, it is checked to be a non-negative integer at graph
  execution. (The same applies to `batch_ndims`.) Constant `Tensor`s and
  non-`Tensor` arguments are always checked for correctness since this can be
  done for "free," i.e., during graph construction.
    168   """
    169 
    170   @deprecation.deprecated(
    171       "2018-10-01",
    172       "The TensorFlow Distributions library has moved to "
    173       "TensorFlow Probability "
    174       "(https://github.com/tensorflow/probability). You "
    175       "should update all references to use `tfp.distributions` "
    176       "instead of `tf.contrib.distributions`.",
    177       warn_once=True)
    178   def __init__(self,
    179                batch_ndims=None,
    180                event_ndims=None,
    181                validate_args=False,
    182                name="DistributionShape"):
    183     """Construct `DistributionShape` with fixed `batch_ndims`, `event_ndims`.
    184 
    185     `batch_ndims` and `event_ndims` are fixed throughout the lifetime of a
    186     `Distribution`. They may only be known at graph execution.
    187 
    188     If both `batch_ndims` and `event_ndims` are python scalars (rather than
    189     either being a `Tensor`), functions in this class automatically perform
    190     sanity checks during graph construction.
    191 
    192     Args:
    193       batch_ndims: `Tensor`. Number of `dims` (`rank`) of the batch portion of
    194         indexes of a `Tensor`. A "batch" is a non-identical distribution, i.e,
    195         Normal with different parameters.
    196       event_ndims: `Tensor`. Number of `dims` (`rank`) of the event portion of
    197         indexes of a `Tensor`. An "event" is what is sampled from a
    198         distribution, i.e., a trivariate Normal has an event shape of [3] and a
    199         4 dimensional Wishart has an event shape of [4, 4].
    200       validate_args: Python `bool`, default `False`. When `True`,
    201         non-`tf.constant` `Tensor` arguments are checked for correctness.
    202         (`tf.constant` arguments are always checked.)
    203       name: Python `str`. The name prepended to Ops created by this class.
    204 
    205     Raises:
    206       ValueError: if either `batch_ndims` or `event_ndims` are: `None`,
    207         negative, not `int32`.
    208     """
    if batch_ndims is None: raise ValueError("batch_ndims cannot be None")
    if event_ndims is None: raise ValueError("event_ndims cannot be None")
    self._batch_ndims = batch_ndims
    self._event_ndims = event_ndims
    self._validate_args = validate_args
    with ops.name_scope(name):
      self._name = name
      with ops.name_scope("init"):
        self._batch_ndims = self._assert_non_negative_int32_scalar(
            ops.convert_to_tensor(
                batch_ndims, name="batch_ndims"))
        self._batch_ndims_static, self._batch_ndims_is_0 = (
            self._introspect_ndims(self._batch_ndims))
        self._event_ndims = self._assert_non_negative_int32_scalar(
            ops.convert_to_tensor(
                event_ndims, name="event_ndims"))
        self._event_ndims_static, self._event_ndims_is_0 = (
            self._introspect_ndims(self._event_ndims))

  @property
  def name(self):
    """Name given to ops created by this class."""
    return self._name

  @property
  def batch_ndims(self):
    """Returns number of dimensions corresponding to non-identical draws."""
    return self._batch_ndims

  @property
  def event_ndims(self):
    """Returns number of dimensions needed to index a sample's coordinates."""
    return self._event_ndims

  @property
  def validate_args(self):
    """Returns True if graph-runtime `Tensor` checks are enabled."""
    return self._validate_args

  def get_ndims(self, x, name="get_ndims"):
    """Get `Tensor` number of dimensions (rank).

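    Example (a small sketch; when the rank of `x` is known statically, the
    result is a constant `int32` `Tensor`):

    ```python
    x = tf.ones([4, 3, 2])
    shaper = _DistributionShape(batch_ndims=1, event_ndims=1)
    ndims = shaper.get_ndims(x)  # Constant Tensor holding 3.
    ```
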
    Args:
      x: `Tensor`.
      name: Python `str`. The name to give this op.

    Returns:
      ndims: Scalar number of dimensions associated with a `Tensor`.
    """
    with self._name_scope(name, values=[x]):
      x = ops.convert_to_tensor(x, name="x")
      ndims = x.get_shape().ndims
      if ndims is None:
        return array_ops.rank(x, name="ndims")
      return ops.convert_to_tensor(ndims, dtype=dtypes.int32, name="ndims")

  def get_sample_ndims(self, x, name="get_sample_ndims"):
    """Returns number of dimensions corresponding to iid draws ("sample").

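    Example (a sketch; `sample_ndims` is the rank left over after removing
    `batch_ndims` and `event_ndims`):

    ```python
    x = tf.ones([6, 5, 4, 3])
    shaper = _DistributionShape(batch_ndims=1, event_ndims=1)
    sample_ndims = shaper.get_sample_ndims(x)  # 4 - 1 - 1 == 2.
    ```
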
    Args:
      x: `Tensor`.
      name: Python `str`. The name to give this op.

    Returns:
      sample_ndims: `Tensor` (0D, `int32`).

    Raises:
      ValueError: if `sample_ndims` is calculated to be negative.
    """
    with self._name_scope(name, values=[x]):
      ndims = self.get_ndims(x, name=name)
      if self._is_all_constant_helper(ndims, self.batch_ndims,
                                      self.event_ndims):
        ndims = tensor_util.constant_value(ndims)
        sample_ndims = (ndims - self._batch_ndims_static -
                        self._event_ndims_static)
        if sample_ndims < 0:
          raise ValueError(
              "expected batch_ndims(%d) + event_ndims(%d) <= ndims(%d)" %
              (self._batch_ndims_static, self._event_ndims_static, ndims))
        return ops.convert_to_tensor(sample_ndims, name="sample_ndims")
      else:
        with ops.name_scope(name="sample_ndims"):
          sample_ndims = ndims - self.batch_ndims - self.event_ndims
          if self.validate_args:
            sample_ndims = control_flow_ops.with_dependencies(
                [check_ops.assert_non_negative(sample_ndims)], sample_ndims)
        return sample_ndims

  def get_dims(self, x, name="get_dims"):
    """Returns dimensions indexing `sample_shape`, `batch_shape`, `event_shape`.

    Example:

    ```python
    x = ...  # Tensor with shape [4, 3, 2, 1]
    sample_dims, batch_dims, event_dims = _DistributionShape(
      batch_ndims=2, event_ndims=1).get_dims(x)
    # sample_dims == [0]
    # batch_dims == [1, 2]
    # event_dims == [3]
    # Note that these are not the shape parts, but rather indexes into shape.
    ```

    Args:
      x: `Tensor`.
      name: Python `str`. The name to give this op.

    Returns:
      sample_dims: `Tensor` (1D, `int32`).
      batch_dims: `Tensor` (1D, `int32`).
      event_dims: `Tensor` (1D, `int32`).
    """
    with self._name_scope(name, values=[x]):
      def make_dims(start_sum, size, name):
        """Closure to make dims range."""
        start_sum = start_sum if start_sum else [
            array_ops.zeros([], dtype=dtypes.int32, name="zero")]
        if self._is_all_constant_helper(size, *start_sum):
          start = sum(tensor_util.constant_value(s) for s in start_sum)
          stop = start + tensor_util.constant_value(size)
          return ops.convert_to_tensor(
              list(range(start, stop)), dtype=dtypes.int32, name=name)
        else:
          start = sum(start_sum)
          return math_ops.range(start, start + size)
      sample_ndims = self.get_sample_ndims(x, name=name)
      return (make_dims([], sample_ndims, name="sample_dims"),
              make_dims([sample_ndims], self.batch_ndims, name="batch_dims"),
              make_dims([sample_ndims, self.batch_ndims],
                        self.event_ndims, name="event_dims"))

  def get_shape(self, x, name="get_shape"):
    """Returns `Tensor`'s shape partitioned into `sample`, `batch`, `event`.

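    Example (a sketch mirroring the first example in the class docstring):

    ```python
    x = tf.ones([3, 50, 2])  # 150 draws from a two-dimensional event space.
    shaper = _DistributionShape(batch_ndims=0, event_ndims=1)
    S, B, E = shaper.get_shape(x)
    # S == [3, 50], B == [], E == [2]
    ```
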
    Args:
      x: `Tensor`.
      name: Python `str`. The name to give this op.

    Returns:
      sample_shape: `Tensor` (1D, `int32`).
      batch_shape: `Tensor` (1D, `int32`).
      event_shape: `Tensor` (1D, `int32`).
    """
    with self._name_scope(name, values=[x]):
      x = ops.convert_to_tensor(x, name="x")
      def slice_shape(start_sum, size, name):
        """Closure to slice out shape."""
        start_sum = start_sum if start_sum else [
            array_ops.zeros([], dtype=dtypes.int32, name="zero")]
        if (x.get_shape().ndims is not None and
            self._is_all_constant_helper(size, *start_sum)):
          start = sum(tensor_util.constant_value(s) for s in start_sum)
          stop = start + tensor_util.constant_value(size)
          slice_ = x.get_shape()[start:stop].as_list()
          if all(s is not None for s in slice_):
            return ops.convert_to_tensor(slice_, dtype=dtypes.int32, name=name)
        return array_ops.slice(array_ops.shape(x), [sum(start_sum)], [size])
      sample_ndims = self.get_sample_ndims(x, name=name)
      return (slice_shape([], sample_ndims,
                          name="sample_shape"),
              slice_shape([sample_ndims], self.batch_ndims,
                          name="batch_shape"),
              slice_shape([sample_ndims, self.batch_ndims], self.event_ndims,
                          name="event_shape"))

  # TODO(jvdillon): Remove `expand_batch_dim` and make `expand_batch_dim=False`
  # the default behavior.
  def make_batch_of_event_sample_matrices(
      self, x, expand_batch_dim=True,
      name="make_batch_of_event_sample_matrices"):
    """Reshapes/transposes `Distribution` `Tensor` from S+B+E to B_+E_+S_.

    Where:
      - `B_ = B if B or not expand_batch_dim else [1]`,
      - `E_ = E if E else [1]`,
      - `S_ = [tf.reduce_prod(S)]`.

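    For example, a sketch (with `S = [3, 50]`, `B = []`, `E = [2]`, and
    `expand_batch_dim=True`, the result has shape `B_+E_+S_ == [1, 2, 150]`):

    ```python
    x = tf.ones([3, 50, 2])  # S+B+E
    shaper = _DistributionShape(batch_ndims=0, event_ndims=1)
    y, sample_shape = shaper.make_batch_of_event_sample_matrices(x)
    # y.shape == [1, 2, 150]; sample_shape == [3, 50]
    ```
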
    Args:
      x: `Tensor`.
      expand_batch_dim: Python `bool`. If `True` the batch dims will be expanded
        such that `batch_ndims >= 1`.
      name: Python `str`. The name to give this op.

    Returns:
      x: `Tensor`. Input transposed/reshaped to `B_+E_+S_`.
      sample_shape: `Tensor` (1D, `int32`).
    """
    with self._name_scope(name, values=[x]):
      x = ops.convert_to_tensor(x, name="x")
      # x.shape: S+B+E
      sample_shape, batch_shape, event_shape = self.get_shape(x)
      event_shape = distribution_util.pick_vector(
          self._event_ndims_is_0, [1], event_shape)
      if expand_batch_dim:
        batch_shape = distribution_util.pick_vector(
            self._batch_ndims_is_0, [1], batch_shape)
      new_shape = array_ops.concat([[-1], batch_shape, event_shape], 0)
      x = array_ops.reshape(x, shape=new_shape)
      # x.shape: [prod(S)]+B_+E_
      x = distribution_util.rotate_transpose(x, shift=-1)
      # x.shape: B_+E_+[prod(S)]
      return x, sample_shape

  # TODO(jvdillon): Remove `expand_batch_dim` and make `expand_batch_dim=False`
  # the default behavior.
  def undo_make_batch_of_event_sample_matrices(
      self, x, sample_shape, expand_batch_dim=True,
      name="undo_make_batch_of_event_sample_matrices"):
    """Reshapes/transposes `Distribution` `Tensor` from B_+E_+S_ to S+B+E.

    Where:
      - `B_ = B if B or not expand_batch_dim else [1]`,
      - `E_ = E if E else [1]`,
      - `S_ = [tf.reduce_prod(S)]`.

    This function "reverses" `make_batch_of_event_sample_matrices`.

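    For example, a sketch (inverting the sketch shown in
    `make_batch_of_event_sample_matrices`, where `x` has `S = [3, 50]`,
    `B = []`, `E = [2]`):

    ```python
    x = tf.ones([3, 50, 2])  # S+B+E
    shaper = _DistributionShape(batch_ndims=0, event_ndims=1)
    y, sample_shape = shaper.make_batch_of_event_sample_matrices(x)
    x_recovered = shaper.undo_make_batch_of_event_sample_matrices(
        y, sample_shape)
    # x_recovered.shape == x.shape == [3, 50, 2]
    ```
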
    Args:
      x: `Tensor` of shape `B_+E_+S_`.
      sample_shape: `Tensor` (1D, `int32`).
      expand_batch_dim: Python `bool`. If `True` the batch dims will be expanded
        such that `batch_ndims >= 1`.
      name: Python `str`. The name to give this op.

    Returns:
      x: `Tensor`. Input transposed/reshaped to `S+B+E`.
    """
    with self._name_scope(name, values=[x, sample_shape]):
      x = ops.convert_to_tensor(x, name="x")
      # x.shape: B_+E_+[prod(S)]
      sample_shape = ops.convert_to_tensor(sample_shape, name="sample_shape")
      x = distribution_util.rotate_transpose(x, shift=1)
      # x.shape: [prod(S)]+B_+E_
      if self._is_all_constant_helper(self.batch_ndims, self.event_ndims):
        if self._batch_ndims_is_0 or self._event_ndims_is_0:
          squeeze_dims = []
          if self._event_ndims_is_0:
            squeeze_dims += [-1]
          if self._batch_ndims_is_0 and expand_batch_dim:
            squeeze_dims += [1]
          if squeeze_dims:
            x = array_ops.squeeze(x, axis=squeeze_dims)
            # x.shape: [prod(S)]+B+E
        _, batch_shape, event_shape = self.get_shape(x)
      else:
        s = (x.get_shape().as_list() if x.get_shape().is_fully_defined()
             else array_ops.shape(x))
        batch_shape = s[1:1+self.batch_ndims]
        # Since the (flattened) sample dim is a single left-most dim, we add 1
        # to the number of batch_ndims to get the event start dim.
        event_start = array_ops.where(
            math_ops.logical_and(expand_batch_dim, self._batch_ndims_is_0),
            2, 1 + self.batch_ndims)
        event_shape = s[event_start:event_start+self.event_ndims]
      new_shape = array_ops.concat([sample_shape, batch_shape, event_shape], 0)
      x = array_ops.reshape(x, shape=new_shape)
      # x.shape: S+B+E
      return x

  @contextlib.contextmanager
  def _name_scope(self, name=None, values=None):
    """Helper function to standardize op scope."""
    with ops.name_scope(self.name):
      with ops.name_scope(name, values=(
          (values or []) + [self.batch_ndims, self.event_ndims])) as scope:
        yield scope

  def _is_all_constant_helper(self, *args):
    """Helper which returns True if all inputs are constant_value."""
    return all(tensor_util.constant_value(x) is not None for x in args)

  def _assert_non_negative_int32_scalar(self, x):
    """Helper which ensures that input is a non-negative, int32, scalar."""
    x = ops.convert_to_tensor(x, name="x")
    if x.dtype.base_dtype != dtypes.int32.base_dtype:
      raise TypeError("%s.dtype=%s is not %s" % (x.name, x.dtype, dtypes.int32))
    x_value_static = tensor_util.constant_value(x)
    if x.get_shape().ndims is not None and x_value_static is not None:
      if x.get_shape().ndims != 0:
        raise ValueError("%s.ndims=%d is not 0 (scalar)" %
                         (x.name, x.get_shape().ndims))
      if x_value_static < 0:
        raise ValueError("%s.value=%d cannot be negative" %
                         (x.name, x_value_static))
      return x
    if self.validate_args:
      x = control_flow_ops.with_dependencies([
          check_ops.assert_rank(x, 0),
          check_ops.assert_non_negative(x)], x)
    return x

  def _introspect_ndims(self, ndims):
    """Helper to establish some properties of input ndims args."""
    if self._is_all_constant_helper(ndims):
      return (tensor_util.constant_value(ndims),
              tensor_util.constant_value(ndims) == 0)
    return None, math_ops.equal(ndims, 0)