Home | History | Annotate | Download | only in ops
      1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Stateless random ops which take seed as a tensor input."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.python.ops import gen_stateless_random_ops
     22 
     23 from tensorflow.python.framework import dtypes
     24 from tensorflow.python.framework import ops
     25 from tensorflow.python.ops import random_ops
     26 from tensorflow.python.ops import math_ops
     27 from tensorflow.python.util import deprecation
     28 from tensorflow.python.util.tf_export import tf_export
     29 
     30 ops.NotDifferentiable("StatelessMultinomial")
     31 ops.NotDifferentiable("StatelessRandomNormal")
     32 ops.NotDifferentiable("StatelessRandomUniform")
     33 ops.NotDifferentiable("StatelessRandomUniformInt")
     34 ops.NotDifferentiable("StatelessTruncatedNormal")
     35 
     36 
     37 @tf_export("random.stateless_uniform")
     38 def stateless_random_uniform(shape,
     39                              seed,
     40                              minval=0,
     41                              maxval=None,
     42                              dtype=dtypes.float32,
     43                              name=None):
     44   """Outputs deterministic pseudorandom values from a uniform distribution.
     45 
     46   This is a stateless version of `tf.random_uniform`: if run twice with the
     47   same seeds, it will produce the same pseudorandom numbers.  The output is
     48   consistent across multiple runs on the same hardware (and between CPU
     49   and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
     50   hardware.
     51 
     52   The generated values follow a uniform distribution in the range
     53   `[minval, maxval)`. The lower bound `minval` is included in the range, while
     54   the upper bound `maxval` is excluded.
     55 
     56   For floats, the default range is `[0, 1)`.  For ints, at least `maxval` must
     57   be specified explicitly.
     58 
     59   In the integer case, the random integers are slightly biased unless
     60   `maxval - minval` is an exact power of two.  The bias is small for values of
     61   `maxval - minval` significantly smaller than the range of the output (either
     62   `2**32` or `2**64`).
     63 
     64   Args:
     65     shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
     66     seed: A shape [2] integer Tensor of seeds to the random number generator.
     67     minval: A 0-D Tensor or Python value of type `dtype`. The lower bound on the
     68       range of random values to generate.  Defaults to 0.
     69     maxval: A 0-D Tensor or Python value of type `dtype`. The upper bound on the
     70       range of random values to generate.  Defaults to 1 if `dtype` is floating
     71       point.
     72     dtype: The type of the output: `float16`, `float32`, `float64`, `int32`, or
     73       `int64`.
     74     name: A name for the operation (optional).
     75 
     76   Returns:
     77     A tensor of the specified shape filled with random uniform values.
     78 
     79   Raises:
     80     ValueError: If `dtype` is integral and `maxval` is not specified.
     81   """
     82   dtype = dtypes.as_dtype(dtype)
     83   if dtype not in (dtypes.float16, dtypes.bfloat16, dtypes.float32,
     84                    dtypes.float64, dtypes.int32, dtypes.int64):
     85     raise ValueError("Invalid dtype %r" % dtype)
     86   if maxval is None:
     87     if dtype.is_integer:
     88       raise ValueError("Must specify maxval for integer dtype %r" % dtype)
     89     maxval = 1
     90   with ops.name_scope(name, "stateless_random_uniform",
     91                       [shape, seed, minval, maxval]) as name:
     92     shape = random_ops._ShapeTensor(shape)  # pylint: disable=protected-access
     93     minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
     94     maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
     95     if dtype.is_integer:
     96       return gen_stateless_random_ops.stateless_random_uniform_int(
     97           shape, seed=seed, minval=minval, maxval=maxval, name=name)
     98     else:
     99       rnd = gen_stateless_random_ops.stateless_random_uniform(
    100           shape, seed=seed, dtype=dtype)
    101       return math_ops.add(rnd * (maxval - minval), minval, name=name)
    102 
    103 
    104 @tf_export("random.stateless_normal")
    105 def stateless_random_normal(shape,
    106                             seed,
    107                             mean=0.0,
    108                             stddev=1.0,
    109                             dtype=dtypes.float32,
    110                             name=None):
    111   """Outputs deterministic pseudorandom values from a normal distribution.
    112 
    113   This is a stateless version of `tf.random_normal`: if run twice with the
    114   same seeds, it will produce the same pseudorandom numbers.  The output is
    115   consistent across multiple runs on the same hardware (and between CPU
    116   and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
    117   hardware.
    118 
    119   Args:
    120     shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    121     seed: A shape [2] integer Tensor of seeds to the random number generator.
    122     mean: A 0-D Tensor or Python value of type `dtype`. The mean of the normal
    123       distribution.
    124     stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
    125       of the normal distribution.
    126     dtype: The type of the output.
    127     name: A name for the operation (optional).
    128 
    129   Returns:
    130     A tensor of the specified shape filled with random normal values.
    131   """
    132   with ops.name_scope(name, "stateless_random_normal",
    133                       [shape, seed, mean, stddev]) as name:
    134     shape = random_ops._ShapeTensor(shape)  # pylint: disable=protected-access
    135     mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
    136     stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
    137     rnd = gen_stateless_random_ops.stateless_random_normal(shape, seed, dtype)
    138     return math_ops.add(rnd * stddev, mean, name=name)
    139 
    140 
    141 @tf_export("random.stateless_truncated_normal")
    142 def stateless_truncated_normal(shape,
    143                                seed,
    144                                mean=0.0,
    145                                stddev=1.0,
    146                                dtype=dtypes.float32,
    147                                name=None):
    148   """Outputs deterministic pseudorandom values, truncated normally distributed.
    149 
    150   This is a stateless version of `tf.truncated_normal`: if run twice with the
    151   same seeds, it will produce the same pseudorandom numbers.  The output is
    152   consistent across multiple runs on the same hardware (and between CPU
    153   and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
    154   hardware.
    155 
    156   The generated values follow a normal distribution with specified mean and
    157   standard deviation, except that values whose magnitude is more than 2 standard
    158   deviations from the mean are dropped and re-picked.
    159 
    160   Args:
    161     shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    162     seed: A shape [2] integer Tensor of seeds to the random number generator.
    163     mean: A 0-D Tensor or Python value of type `dtype`. The mean of the
    164       truncated normal distribution.
    165     stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
    166       of the normal distribution, before truncation.
    167     dtype: The type of the output.
    168     name: A name for the operation (optional).
    169 
    170   Returns:
    171     A tensor of the specified shape filled with random truncated normal values.
    172   """
    173   with ops.name_scope(name, "stateless_truncated_normal",
    174                       [shape, seed, mean, stddev]) as name:
    175     shape = random_ops._ShapeTensor(shape)  # pylint: disable=protected-access
    176     mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
    177     stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
    178     rnd = gen_stateless_random_ops.stateless_truncated_normal(
    179         shape, seed, dtype)
    180     return math_ops.add(rnd * stddev, mean, name=name)
    181 
    182 
    183 @tf_export(v1=["random.stateless_multinomial"])
    184 @deprecation.deprecated(
    185     date=None, instructions="Use `tf.random.stateless_categorical` instead.")
    186 def stateless_multinomial(logits,
    187                           num_samples,
    188                           seed,
    189                           output_dtype=dtypes.int64,
    190                           name=None):
    191   """Draws deterministic pseudorandom samples from a multinomial distribution.
    192 
    193   This is a stateless version of `tf.multinomial`: if run twice with the
    194   same seeds, it will produce the same pseudorandom numbers.  The output is
    195   consistent across multiple runs on the same hardware (and between CPU
    196   and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
    197   hardware.
    198 
    199   Example:
    200 
    201   ```python
    202   # samples has shape [1, 5], where each value is either 0 or 1 with equal
    203   # probability.
    204   samples = tf.random.stateless_multinomial(
    205       tf.log([[10., 10.]]), 5, seed=[7, 17])
    206   ```
    207 
    208   Args:
    209     logits: 2-D Tensor with shape `[batch_size, num_classes]`.  Each slice
    210       `[i, :]` represents the unnormalized log-probabilities for all classes.
    211     num_samples: 0-D.  Number of independent samples to draw for each row slice.
    212     seed: A shape [2] integer Tensor of seeds to the random number generator.
    213     output_dtype: integer type to use for the output. Defaults to int64.
    214     name: Optional name for the operation.
    215 
    216   Returns:
    217     The drawn samples of shape `[batch_size, num_samples]`.
    218   """
    219   with ops.name_scope(name, "stateless_multinomial", [logits, seed]):
    220     return stateless_multinomial_categorical_impl(logits, num_samples,
    221                                                   output_dtype, seed)
    222 
    223 
    224 @tf_export("random.stateless_categorical")
    225 def stateless_categorical(logits,
    226                           num_samples,
    227                           seed,
    228                           dtype=dtypes.int64,
    229                           name=None):
    230   """Draws deterministic pseudorandom samples from a categorical distribution.
    231 
    232   This is a stateless version of `tf.categorical`: if run twice with the
    233   same seeds, it will produce the same pseudorandom numbers.  The output is
    234   consistent across multiple runs on the same hardware (and between CPU
    235   and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
    236   hardware.
    237 
    238   Example:
    239 
    240   ```python
    241   # samples has shape [1, 5], where each value is either 0 or 1 with equal
    242   # probability.
    243   samples = tf.random.stateless_categorical(
    244       tf.log([[10., 10.]]), 5, seed=[7, 17])
    245   ```
    246 
    247   Args:
    248     logits: 2-D Tensor with shape `[batch_size, num_classes]`.  Each slice
    249       `[i, :]` represents the unnormalized log-probabilities for all classes.
    250     num_samples: 0-D.  Number of independent samples to draw for each row slice.
    251     seed: A shape [2] integer Tensor of seeds to the random number generator.
    252     dtype: integer type to use for the output. Defaults to int64.
    253     name: Optional name for the operation.
    254 
    255   Returns:
    256     The drawn samples of shape `[batch_size, num_samples]`.
    257   """
    258   with ops.name_scope(name, "stateless_categorical", [logits, seed]):
    259     return stateless_multinomial_categorical_impl(logits, num_samples, dtype,
    260                                                   seed)
    261 
    262 
    263 def stateless_multinomial_categorical_impl(logits, num_samples, dtype, seed):
    264   """Implementation for stateless multinomial/categorical ops (v1/v2)."""
    265   logits = ops.convert_to_tensor(logits, name="logits")
    266   return gen_stateless_random_ops.stateless_multinomial(
    267       logits, num_samples, seed, output_dtype=dtype)
    268