Home | History | Annotate | Download | only in ops
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Operations often used for initializing tensors.
     16 
     17 All variable initializers returned by functions in this file should have the
     18 following signature:
     19 
     20 def _initializer(shape, dtype=dtypes.float32, partition_info=None):
     21   Args:
     22     shape: List of `int` representing the shape of the output `Tensor`. Some
     23       initializers may also be able to accept a `Tensor`.
     24     dtype: (Optional) Type of the output `Tensor`.
     25     partition_info: (Optional) variable_scope._PartitionInfo object holding
     26       additional information about how the variable is partitioned. May be
     27       `None` if the variable is not partitioned.
     28   Returns:
     29     A `Tensor` of type `dtype` and `shape`.
     30 """
     31 from __future__ import absolute_import
     32 from __future__ import division
     33 from __future__ import print_function
     34 
     35 import math
     36 
     37 import numpy as np
     38 
     39 from tensorflow.python.framework import constant_op
     40 from tensorflow.python.framework import dtypes
     41 from tensorflow.python.ops import array_ops
     42 from tensorflow.python.ops import linalg_ops
     43 from tensorflow.python.ops import math_ops
     44 from tensorflow.python.ops import random_ops
     45 from tensorflow.python.ops import random_ops
     46 from tensorflow.python.util.deprecation import deprecated
     47 from tensorflow.python.util.tf_export import tf_export
     48 
     49 
     50 @tf_export("keras.initializers.Initializer")
     51 class Initializer(object):
     52   """Initializer base class: all initializers inherit from this class.
     53   """
     54 
     55   def __call__(self, shape, dtype=None, partition_info=None):
     56     raise NotImplementedError
     57 
     58   def get_config(self):
     59     """Returns the configuration of the initializer as a JSON-serializable dict.
     60 
     61     Returns:
     62       A JSON-serializable Python dict.
     63     """
     64     return {}
     65 
     66   @classmethod
     67   def from_config(cls, config):
     68     """Instantiates an initializer from a configuration dictionary.
     69 
     70     Example:
     71 
     72     ```python
     73     initializer = RandomUniform(-1, 1)
     74     config = initializer.get_config()
     75     initializer = RandomUniform.from_config(config)
     76     ```
     77 
     78     Args:
     79       config: A Python dictionary.
     80         It will typically be the output of `get_config`.
     81 
     82     Returns:
     83       An Initializer instance.
     84     """
     85     return cls(**config)
     86 
     87 
     88 @tf_export("keras.initializers.Zeros", "initializers.zeros",
     89            "zeros_initializer")
     90 class Zeros(Initializer):
     91   """Initializer that generates tensors initialized to 0."""
     92 
     93   def __init__(self, dtype=dtypes.float32):
     94     self.dtype = dtypes.as_dtype(dtype)
     95 
     96   def __call__(self, shape, dtype=None, partition_info=None):
     97     if dtype is None:
     98       dtype = self.dtype
     99     return array_ops.zeros(shape, dtype)
    100 
    101   def get_config(self):
    102     return {"dtype": self.dtype.name}
    103 
    104 
    105 @tf_export("keras.initializers.Ones", "initializers.ones", "ones_initializer")
    106 class Ones(Initializer):
    107   """Initializer that generates tensors initialized to 1."""
    108 
    109   def __init__(self, dtype=dtypes.float32):
    110     self.dtype = dtypes.as_dtype(dtype)
    111 
    112   def __call__(self, shape, dtype=None, partition_info=None):
    113     if dtype is None:
    114       dtype = self.dtype
    115     return array_ops.ones(shape, dtype)
    116 
    117   def get_config(self):
    118     return {"dtype": self.dtype.name}
    119 
    120 
    121 @tf_export("keras.initializers.Constant", "initializers.constant",
    122            "constant_initializer")
    123 class Constant(Initializer):
    124   """Initializer that generates tensors with constant values.
    125 
    126   The resulting tensor is populated with values of type `dtype`, as
    127   specified by arguments `value` following the desired `shape` of the
    128   new tensor (see examples below).
    129 
    130   The argument `value` can be a constant value, or a list of values of type
    131   `dtype`. If `value` is a list, then the length of the list must be less
    132   than or equal to the number of elements implied by the desired shape of the
    133   tensor. In the case where the total number of elements in `value` is less
    134   than the number of elements required by the tensor shape, the last element
    135   in `value` will be used to fill the remaining entries. If the total number of
    136   elements in `value` is greater than the number of elements required by the
    137   tensor shape, the initializer will raise a `ValueError`.
    138 
    139   Args:
    140     value: A Python scalar, list or tuple of values, or a N-dimensional numpy
    141       array. All elements of the initialized variable will be set to the
    142       corresponding value in the `value` argument.
    143     dtype: The data type.
    144     verify_shape: Boolean that enables verification of the shape of `value`. If
    145       `True`, the initializer will throw an error if the shape of `value` is not
    146       compatible with the shape of the initialized tensor.
    147 
    148   Raises:
    149     TypeError: If the input `value` is not one of the expected types.
    150 
    151   Examples:
    152     The following example can be rewritten using a numpy.ndarray instead
    153     of the `value` list, even reshaped, as shown in the two commented lines
    154     below the `value` list initialization.
    155 
    156   ```python
    157     >>> import numpy as np
    158     >>> import tensorflow as tf
    159 
    160     >>> value = [0, 1, 2, 3, 4, 5, 6, 7]
    161     >>> # value = np.array(value)
    162     >>> # value = value.reshape([2, 4])
    163     >>> init = tf.constant_initializer(value)
    164 
    165     >>> print('fitting shape:')
    166     >>> with tf.Session():
    167     >>>   x = tf.get_variable('x', shape=[2, 4], initializer=init)
    168     >>>   x.initializer.run()
    169     >>>   print(x.eval())
    170 
    171     fitting shape:
    172     [[ 0.  1.  2.  3.]
    173      [ 4.  5.  6.  7.]]
    174 
    175     >>> print('larger shape:')
    176     >>> with tf.Session():
    177     >>>   x = tf.get_variable('x', shape=[3, 4], initializer=init)
    178     >>>   x.initializer.run()
    179     >>>   print(x.eval())
    180 
    181     larger shape:
    182     [[ 0.  1.  2.  3.]
    183      [ 4.  5.  6.  7.]
    184      [ 7.  7.  7.  7.]]
    185 
    186     >>> print('smaller shape:')
    187     >>> with tf.Session():
    188     >>>   x = tf.get_variable('x', shape=[2, 3], initializer=init)
    189 
    190     ValueError: Too many elements provided. Needed at most 6, but received 8
    191 
    192     >>> print('shape verification:')
    193     >>> init_verify = tf.constant_initializer(value, verify_shape=True)
    194     >>> with tf.Session():
    195     >>>   x = tf.get_variable('x', shape=[3, 4], initializer=init_verify)
    196 
    197     TypeError: Expected Tensor's shape: (3, 4), got (8,).
    198   ```
    199   """
    200 
    201   def __init__(self, value=0, dtype=dtypes.float32, verify_shape=False):
    202     if not (np.isscalar(value) or isinstance(value, (list, tuple, np.ndarray))):
    203       raise TypeError(
    204           "Invalid type for initial value: %s (expected Python scalar, list or "
    205           "tuple of values, or numpy.ndarray)." % type(value))
    206 
    207     self.value = value
    208     self.dtype = dtypes.as_dtype(dtype)
    209     self._verify_shape = verify_shape
    210 
    211   def __call__(self, shape, dtype=None, partition_info=None, verify_shape=None):
    212     if dtype is None:
    213       dtype = self.dtype
    214     if verify_shape is None:
    215       verify_shape = self._verify_shape
    216     return constant_op.constant(
    217         self.value, dtype=dtype, shape=shape, verify_shape=verify_shape)
    218 
    219   def get_config(self):
    220     # We don't include `verify_shape` for compatibility with Keras.
    221     # `verify_shape` should be passed as an argument to `__call__` rather
    222     # than as a constructor argument: conceptually it isn't a property
    223     # of the initializer.
    224     return {"value": self.value, "dtype": self.dtype.name}
    225 
    226 
    227 @tf_export("keras.initializers.RandomUniform", "initializers.random_uniform",
    228            "random_uniform_initializer")
    229 class RandomUniform(Initializer):
    230   """Initializer that generates tensors with a uniform distribution.
    231 
    232   Args:
    233     minval: A python scalar or a scalar tensor. Lower bound of the range
    234       of random values to generate.
    235     maxval: A python scalar or a scalar tensor. Upper bound of the range
    236       of random values to generate.  Defaults to 1 for float types.
    237     seed: A Python integer. Used to create random seeds. See
    238       @{tf.set_random_seed}
    239       for behavior.
    240     dtype: The data type.
    241   """
    242 
    243   def __init__(self, minval=0, maxval=None, seed=None, dtype=dtypes.float32):
    244     self.minval = minval
    245     self.maxval = maxval
    246     self.seed = seed
    247     self.dtype = dtypes.as_dtype(dtype)
    248 
    249   def __call__(self, shape, dtype=None, partition_info=None):
    250     if dtype is None:
    251       dtype = self.dtype
    252     return random_ops.random_uniform(
    253         shape, self.minval, self.maxval, dtype, seed=self.seed)
    254 
    255   def get_config(self):
    256     return {
    257         "minval": self.minval,
    258         "maxval": self.maxval,
    259         "seed": self.seed,
    260         "dtype": self.dtype.name
    261     }
    262 
    263 
    264 @tf_export("keras.initializers.RandomNormal", "initializers.random_normal",
    265            "random_normal_initializer")
    266 class RandomNormal(Initializer):
    267   """Initializer that generates tensors with a normal distribution.
    268 
    269   Args:
    270     mean: a python scalar or a scalar tensor. Mean of the random values
    271       to generate.
    272     stddev: a python scalar or a scalar tensor. Standard deviation of the
    273       random values to generate.
    274     seed: A Python integer. Used to create random seeds. See
    275       @{tf.set_random_seed}
    276       for behavior.
    277     dtype: The data type. Only floating point types are supported.
    278   """
    279 
    280   def __init__(self, mean=0.0, stddev=1.0, seed=None, dtype=dtypes.float32):
    281     self.mean = mean
    282     self.stddev = stddev
    283     self.seed = seed
    284     self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
    285 
    286   def __call__(self, shape, dtype=None, partition_info=None):
    287     if dtype is None:
    288       dtype = self.dtype
    289     return random_ops.random_normal(
    290         shape, self.mean, self.stddev, dtype, seed=self.seed)
    291 
    292   def get_config(self):
    293     return {
    294         "mean": self.mean,
    295         "stddev": self.stddev,
    296         "seed": self.seed,
    297         "dtype": self.dtype.name
    298     }
    299 
    300 
    301 @tf_export("keras.initializers.TruncatedNormal",
    302            "initializers.truncated_normal", "truncated_normal_initializer")
    303 class TruncatedNormal(Initializer):
    304   """Initializer that generates a truncated normal distribution.
    305 
    306   These values are similar to values from a `random_normal_initializer`
    307   except that values more than two standard deviations from the mean
    308   are discarded and re-drawn. This is the recommended initializer for
    309   neural network weights and filters.
    310 
    311   Args:
    312     mean: a python scalar or a scalar tensor. Mean of the random values
    313       to generate.
    314     stddev: a python scalar or a scalar tensor. Standard deviation of the
    315       random values to generate.
    316     seed: A Python integer. Used to create random seeds. See
    317       @{tf.set_random_seed}
    318       for behavior.
    319     dtype: The data type. Only floating point types are supported.
    320   """
    321 
    322   def __init__(self, mean=0.0, stddev=1.0, seed=None, dtype=dtypes.float32):
    323     self.mean = mean
    324     self.stddev = stddev
    325     self.seed = seed
    326     self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
    327 
    328   def __call__(self, shape, dtype=None, partition_info=None):
    329     if dtype is None:
    330       dtype = self.dtype
    331     return random_ops.truncated_normal(
    332         shape, self.mean, self.stddev, dtype, seed=self.seed)
    333 
    334   def get_config(self):
    335     return {
    336         "mean": self.mean,
    337         "stddev": self.stddev,
    338         "seed": self.seed,
    339         "dtype": self.dtype.name
    340     }
    341 
    342 
    343 @tf_export("initializers.uniform_unit_scaling",
    344            "uniform_unit_scaling_initializer")
    345 class UniformUnitScaling(Initializer):
    346   """Initializer that generates tensors without scaling variance.
    347 
    348   When initializing a deep network, it is in principle advantageous to keep
    349   the scale of the input variance constant, so it does not explode or diminish
    350   by reaching the final layer. If the input is `x` and the operation `x * W`,
    351   and we want to initialize `W` uniformly at random, we need to pick `W` from
    352 
    353       [-sqrt(3) / sqrt(dim), sqrt(3) / sqrt(dim)]
    354 
    355   to keep the scale intact, where `dim = W.shape[0]` (the size of the input).
    356   A similar calculation for convolutional networks gives an analogous result
    357   with `dim` equal to the product of the first 3 dimensions.  When
    358   nonlinearities are present, we need to multiply this by a constant `factor`.
    359   See [Sussillo et al., 2014](https://arxiv.org/abs/1412.6558)
    360   ([pdf](http://arxiv.org/pdf/1412.6558.pdf)) for deeper motivation, experiments
    361   and the calculation of constants. In section 2.3 there, the constants were
    362   numerically computed: for a linear layer it's 1.0, relu: ~1.43, tanh: ~1.15.
    363 
    364   Args:
    365     factor: Float.  A multiplicative factor by which the values will be scaled.
    366     seed: A Python integer. Used to create random seeds. See
    367       @{tf.set_random_seed}
    368       for behavior.
    369     dtype: The data type. Only floating point types are supported.
    370   """
    371 
    372   @deprecated(None,
    373               "Use tf.initializers.variance_scaling instead with distribution="
    374               "uniform to get equivalent behavior.")
    375   def __init__(self, factor=1.0, seed=None, dtype=dtypes.float32):
    376     self.factor = factor
    377     self.seed = seed
    378     self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
    379 
    380   def __call__(self, shape, dtype=None, partition_info=None):
    381     if dtype is None:
    382       dtype = self.dtype
    383     scale_shape = shape
    384     if partition_info is not None:
    385       scale_shape = partition_info.full_shape
    386 
    387     input_size = 1.0
    388     # Estimating input size is not possible to do perfectly, but we try.
    389     # The estimate, obtained by multiplying all dimensions but the last one,
    390     # is the right thing for matrix multiply and convolutions (see above).
    391     for dim in scale_shape[:-1]:
    392       input_size *= float(dim)
    393     # Avoid errors when initializing zero-size tensors.
    394     input_size = max(input_size, 1.0)
    395     max_val = math.sqrt(3 / input_size) * self.factor
    396     return random_ops.random_uniform(
    397         shape, -max_val, max_val, dtype, seed=self.seed)
    398 
    399   def get_config(self):
    400     return {"factor": self.factor, "seed": self.seed, "dtype": self.dtype.name}
    401 
    402 
    403 @tf_export("keras.initializers.VarianceScaling",
    404            "initializers.variance_scaling", "variance_scaling_initializer")
    405 class VarianceScaling(Initializer):
    406   """Initializer capable of adapting its scale to the shape of weights tensors.
    407 
    408   With `distribution="normal"`, samples are drawn from a truncated normal
    409   distribution centered on zero, with `stddev = sqrt(scale / n)`
    410   where n is:
    411     - number of input units in the weight tensor, if mode = "fan_in"
    412     - number of output units, if mode = "fan_out"
    413     - average of the numbers of input and output units, if mode = "fan_avg"
    414 
    415   With `distribution="uniform"`, samples are drawn from a uniform distribution
    416   within [-limit, limit], with `limit = sqrt(3 * scale / n)`.
    417 
    418   Args:
    419     scale: Scaling factor (positive float).
    420     mode: One of "fan_in", "fan_out", "fan_avg".
    421     distribution: Random distribution to use. One of "normal", "uniform".
    422     seed: A Python integer. Used to create random seeds. See
    423       @{tf.set_random_seed}
    424       for behavior.
    425     dtype: The data type. Only floating point types are supported.
    426 
    427   Raises:
    428     ValueError: In case of an invalid value for the "scale", mode" or
    429       "distribution" arguments.
    430   """
    431 
    432   def __init__(self,
    433                scale=1.0,
    434                mode="fan_in",
    435                distribution="normal",
    436                seed=None,
    437                dtype=dtypes.float32):
    438     if scale <= 0.:
    439       raise ValueError("`scale` must be positive float.")
    440     if mode not in {"fan_in", "fan_out", "fan_avg"}:
    441       raise ValueError("Invalid `mode` argument:", mode)
    442     distribution = distribution.lower()
    443     if distribution not in {"normal", "uniform"}:
    444       raise ValueError("Invalid `distribution` argument:", distribution)
    445     self.scale = scale
    446     self.mode = mode
    447     self.distribution = distribution
    448     self.seed = seed
    449     self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
    450 
    451   def __call__(self, shape, dtype=None, partition_info=None):
    452     if dtype is None:
    453       dtype = self.dtype
    454     scale = self.scale
    455     scale_shape = shape
    456     if partition_info is not None:
    457       scale_shape = partition_info.full_shape
    458     fan_in, fan_out = _compute_fans(scale_shape)
    459     if self.mode == "fan_in":
    460       scale /= max(1., fan_in)
    461     elif self.mode == "fan_out":
    462       scale /= max(1., fan_out)
    463     else:
    464       scale /= max(1., (fan_in + fan_out) / 2.)
    465     if self.distribution == "normal":
    466       stddev = math.sqrt(scale)
    467       return random_ops.truncated_normal(
    468           shape, 0.0, stddev, dtype, seed=self.seed)
    469     else:
    470       limit = math.sqrt(3.0 * scale)
    471       return random_ops.random_uniform(
    472           shape, -limit, limit, dtype, seed=self.seed)
    473 
    474   def get_config(self):
    475     return {
    476         "scale": self.scale,
    477         "mode": self.mode,
    478         "distribution": self.distribution,
    479         "seed": self.seed,
    480         "dtype": self.dtype.name
    481     }
    482 
    483 
    484 @tf_export("keras.initializers.Orthogonal", "initializers.orthogonal",
    485            "orthogonal_initializer")
    486 class Orthogonal(Initializer):
    487   """Initializer that generates an orthogonal matrix.
    488 
    489   If the shape of the tensor to initialize is two-dimensional, it is initialized
    490   with an orthogonal matrix obtained from the QR decomposition of a matrix of
    491   uniform random numbers. If the matrix has fewer rows than columns then the
    492   output will have orthogonal rows. Otherwise, the output will have orthogonal
    493   columns.
    494 
    495   If the shape of the tensor to initialize is more than two-dimensional,
    496   a matrix of shape `(shape[0] * ... * shape[n - 2], shape[n - 1])`
    497   is initialized, where `n` is the length of the shape vector.
    498   The matrix is subsequently reshaped to give a tensor of the desired shape.
    499 
    500   Args:
    501     gain: multiplicative factor to apply to the orthogonal matrix
    502     dtype: The type of the output.
    503     seed: A Python integer. Used to create random seeds. See
    504       @{tf.set_random_seed}
    505       for behavior.
    506   """
    507 
    508   def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32):
    509     self.gain = gain
    510     self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
    511     self.seed = seed
    512 
    513   def __call__(self, shape, dtype=None, partition_info=None):
    514     if dtype is None:
    515       dtype = self.dtype
    516     # Check the shape
    517     if len(shape) < 2:
    518       raise ValueError("The tensor to initialize must be "
    519                        "at least two-dimensional")
    520     # Flatten the input shape with the last dimension remaining
    521     # its original shape so it works for conv2d
    522     num_rows = 1
    523     for dim in shape[:-1]:
    524       num_rows *= dim
    525     num_cols = shape[-1]
    526     flat_shape = (num_cols, num_rows) if num_rows < num_cols else (num_rows,
    527                                                                    num_cols)
    528 
    529     # Generate a random matrix
    530     a = random_ops.random_normal(flat_shape, dtype=dtype, seed=self.seed)
    531     # Compute the qr factorization
    532     q, r = linalg_ops.qr(a, full_matrices=False)
    533     # Make Q uniform
    534     d = array_ops.diag_part(r)
    535     ph = d / math_ops.abs(d)
    536     q *= ph
    537     if num_rows < num_cols:
    538       q = array_ops.matrix_transpose(q)
    539     return self.gain * array_ops.reshape(q, shape)
    540 
    541   def get_config(self):
    542     return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name}
    543 
    544 
    545 @tf_export("keras.initializers.Identity", "initializers.identity")
    546 class Identity(Initializer):
    547   """Initializer that generates the identity matrix.
    548 
    549   Only use for 2D matrices.
    550 
    551   Args:
    552     gain: Multiplicative factor to apply to the identity matrix.
    553     dtype: The type of the output.
    554   """
    555 
    556   def __init__(self, gain=1.0, dtype=dtypes.float32):
    557     self.gain = gain
    558     self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
    559 
    560   def __call__(self, shape, dtype=None, partition_info=None):
    561     full_shape = shape if partition_info is None else partition_info.full_shape
    562     if len(full_shape) != 2:
    563       raise ValueError(
    564           "Identity matrix initializer can only be used for 2D matrices.")
    565     if dtype is None:
    566       dtype = self.dtype
    567     initializer = linalg_ops.eye(*full_shape, dtype=dtype)
    568     if partition_info is not None:
    569       initializer = array_ops.slice(initializer, partition_info.var_offset,
    570                                     shape)
    571     return self.gain * initializer
    572 
    573   def get_config(self):
    574     return {"gain": self.gain, "dtype": self.dtype.name}
    575 
    576 # Aliases.
    577 
    578 # pylint: disable=invalid-name
    579 zeros_initializer = Zeros
    580 ones_initializer = Ones
    581 constant_initializer = Constant
    582 random_uniform_initializer = RandomUniform
    583 random_normal_initializer = RandomNormal
    584 truncated_normal_initializer = TruncatedNormal
    585 uniform_unit_scaling_initializer = UniformUnitScaling
    586 variance_scaling_initializer = VarianceScaling
    587 orthogonal_initializer = Orthogonal
    588 identity_initializer = Identity
    589 
    590 # pylint: enable=invalid-name
    591 
    592 
    593 @tf_export("glorot_uniform_initializer")
    594 def glorot_uniform_initializer(seed=None, dtype=dtypes.float32):
    595   """The Glorot uniform initializer, also called Xavier uniform initializer.
    596 
    597   It draws samples from a uniform distribution within [-limit, limit]
    598   where `limit` is `sqrt(6 / (fan_in + fan_out))`
    599   where `fan_in` is the number of input units in the weight tensor
    600   and `fan_out` is the number of output units in the weight tensor.
    601 
    602   Reference: http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
    603 
    604   Args:
    605     seed: A Python integer. Used to create random seeds. See
    606       @{tf.set_random_seed}
    607       for behavior.
    608     dtype: The data type. Only floating point types are supported.
    609 
    610   Returns:
    611     An initializer.
    612   """
    613   return variance_scaling_initializer(
    614       scale=1.0, mode="fan_avg", distribution="uniform", seed=seed, dtype=dtype)
    615 
    616 
    617 @tf_export("glorot_normal_initializer")
    618 def glorot_normal_initializer(seed=None, dtype=dtypes.float32):
    619   """The Glorot normal initializer, also called Xavier normal initializer.
    620 
    621   It draws samples from a truncated normal distribution centered on 0
    622   with `stddev = sqrt(2 / (fan_in + fan_out))`
    623   where `fan_in` is the number of input units in the weight tensor
    624   and `fan_out` is the number of output units in the weight tensor.
    625 
    626   Reference: http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
    627 
    628   Args:
    629     seed: A Python integer. Used to create random seeds. See
    630       @{tf.set_random_seed}
    631       for behavior.
    632     dtype: The data type. Only floating point types are supported.
    633 
    634   Returns:
    635     An initializer.
    636   """
    637   return variance_scaling_initializer(
    638       scale=1.0, mode="fan_avg", distribution="normal", seed=seed, dtype=dtype)
    639 
    640 
    641 # Utility functions.
    642 
    643 
    644 def _compute_fans(shape):
    645   """Computes the number of input and output units for a weight shape.
    646 
    647   Args:
    648     shape: Integer shape tuple or TF tensor shape.
    649 
    650   Returns:
    651     A tuple of scalars (fan_in, fan_out).
    652   """
    653   if len(shape) < 1:  # Just to avoid errors for constants.
    654     fan_in = fan_out = 1
    655   elif len(shape) == 1:
    656     fan_in = fan_out = shape[0]
    657   elif len(shape) == 2:
    658     fan_in = shape[0]
    659     fan_out = shape[1]
    660   else:
    661     # Assuming convolution kernels (2D, 3D, or more).
    662     # kernel shape: (..., input_depth, depth)
    663     receptive_field_size = 1.
    664     for dim in shape[:-2]:
    665       receptive_field_size *= dim
    666     fan_in = shape[-2] * receptive_field_size
    667     fan_out = shape[-1] * receptive_field_size
    668   return fan_in, fan_out
    669 
    670 
    671 def _assert_float_dtype(dtype):
    672   """Validate and return floating point type based on `dtype`.
    673 
    674   `dtype` must be a floating point type.
    675 
    676   Args:
    677     dtype: The data type to validate.
    678 
    679   Returns:
    680     Validated type.
    681 
    682   Raises:
    683     ValueError: if `dtype` is not a floating point type.
    684   """
    685   if not dtype.is_floating:
    686     raise ValueError("Expected floating point type, got %s." % dtype)
    687   return dtype
    688