Home | History | Annotate | Download | only in ops
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Basic arithmetic operators.
     16 
     17 See the @{$python/math_ops} guide.
     18 
     19 @@add
     20 @@subtract
     21 @@multiply
     22 @@scalar_mul
     23 @@div
     24 @@divide
     25 @@truediv
     26 @@floordiv
     27 @@realdiv
     28 @@truncatediv
     29 @@floor_div
     30 @@truncatemod
     31 @@floormod
     32 @@mod
     33 @@cross
     34 @@add_n
     35 @@abs
     36 @@negative
     37 @@sign
     38 @@reciprocal
     39 @@square
     40 @@round
     41 @@sqrt
     42 @@rsqrt
     43 @@pow
     44 @@exp
     45 @@expm1
     46 @@log
     47 @@log1p
     48 @@sinh
     49 @@cosh
     50 @@asinh
     51 @@acosh
     52 @@atanh
     53 @@ceil
     54 @@floor
     55 @@maximum
     56 @@minimum
     57 @@cos
     58 @@sin
     59 @@lbeta
     60 @@tan
     61 @@acos
     62 @@asin
     63 @@atan
     64 @@atan2
     65 @@lgamma
     66 @@digamma
     67 @@erf
     68 @@erfc
     69 @@squared_difference
     70 @@igamma
     71 @@igammac
     72 @@zeta
     73 @@polygamma
     74 @@betainc
     75 @@rint
     76 @@diag
     77 @@diag_part
     78 @@trace
     79 @@transpose
     80 @@eye
     81 @@matrix_diag
     82 @@matrix_diag_part
     83 @@matrix_band_part
     84 @@matrix_set_diag
     85 @@matrix_transpose
     86 @@matmul
     87 @@norm
     88 @@matrix_determinant
     89 @@matrix_inverse
     90 @@cholesky
     91 @@cholesky_solve
     92 @@matrix_exponential
     93 @@matrix_logarithm
     94 @@matrix_solve
     95 @@matrix_triangular_solve
     96 @@matrix_solve_ls
     97 @@qr
     98 @@self_adjoint_eig
     99 @@self_adjoint_eigvals
    100 @@svd
    101 @@tensordot
    102 @@complex
    103 @@conj
    104 @@imag
    105 @@angle
    106 @@real
    107 @@fft
    108 @@ifft
    109 @@fft2d
    110 @@ifft2d
    111 @@fft3d
    112 @@ifft3d
    113 @@reduce_sum
    114 @@reduce_prod
    115 @@reduce_min
    116 @@reduce_max
    117 @@reduce_mean
    118 @@reduce_all
    119 @@reduce_any
    120 @@reduce_logsumexp
    121 @@count_nonzero
    122 @@accumulate_n
    123 @@einsum
    124 @@bincount
    125 @@cumsum
    126 @@cumprod
    127 @@segment_sum
    128 @@segment_prod
    129 @@segment_min
    130 @@segment_max
    131 @@segment_mean
    132 @@unsorted_segment_sum
    133 @@unsorted_segment_max
    134 @@unsorted_segment_min
    135 @@unsorted_segment_prod
    136 @@unsorted_segment_sqrt_n
    137 @@sparse_segment_sum
    138 @@sparse_segment_mean
    139 @@sparse_segment_sqrt_n
    140 @@argmin
    141 @@argmax
    142 @@setdiff1d
    143 @@where
    144 @@unique
    145 @@edit_distance
    146 @@invert_permutation
    147 """
    148 from __future__ import absolute_import
    149 from __future__ import division
    150 from __future__ import print_function
    151 
    152 import numpy as np
    153 from six.moves import xrange  # pylint: disable=redefined-builtin
    154 
    155 from tensorflow.python.eager import context
    156 from tensorflow.python.framework import common_shapes
    157 from tensorflow.python.framework import constant_op
    158 from tensorflow.python.framework import dtypes
    159 from tensorflow.python.framework import graph_util
    160 from tensorflow.python.framework import ops
    161 from tensorflow.python.framework import sparse_tensor
    162 from tensorflow.python.framework import tensor_shape
    163 from tensorflow.python.ops import array_ops
    164 from tensorflow.python.ops import gen_control_flow_ops
    165 from tensorflow.python.ops import gen_data_flow_ops
    166 from tensorflow.python.ops import gen_math_ops
    167 from tensorflow.python.ops import gen_nn_ops
    168 from tensorflow.python.ops import gen_sparse_ops
    169 from tensorflow.python.ops import gen_spectral_ops
    170 from tensorflow.python.ops import gen_state_ops
    171 from tensorflow.python.ops import state_ops
    172 # go/tf-wildcard-import
    173 # pylint: disable=wildcard-import
    174 from tensorflow.python.ops.gen_math_ops import *
    175 # pylint: enable=wildcard-import
    176 from tensorflow.python.util import compat
    177 from tensorflow.python.util import deprecation
    178 from tensorflow.python.util.tf_export import tf_export
    179 
    180 # Aliases for some automatically-generated names.
    181 linspace = gen_math_ops.lin_space
    182 
    183 arg_max = deprecation.deprecated(None, "Use `argmax` instead")(arg_max)  # pylint: disable=used-before-assignment
    184 arg_min = deprecation.deprecated(None, "Use `argmin` instead")(arg_min)  # pylint: disable=used-before-assignment
    185 
    186 
    187 def _set_doc(doc):
    188 
    189   def _decorator(func):
    190     func.__doc__ = doc
    191     return func
    192 
    193   return _decorator
    194 
    195 
    196 # pylint: disable=redefined-builtin
    197 @tf_export("argmax")
    198 @deprecation.deprecated_args(None, "Use the `axis` argument instead",
    199                              "dimension")
    200 @_set_doc(
    201     gen_math_ops.arg_max.__doc__.replace("dimensions", "axes").replace(
    202         "dimension", "axis"))
    203 def argmax(input,
    204            axis=None,
    205            name=None,
    206            dimension=None,
    207            output_type=dtypes.int64):
    208   if dimension is not None:
    209     if axis is not None:
    210       raise ValueError("Cannot specify both 'axis' and 'dimension'")
    211     axis = dimension
    212   elif axis is None:
    213     axis = 0
    214   return gen_math_ops.arg_max(input, axis, name=name, output_type=output_type)
    215 
    216 
    217 @tf_export("argmin")
    218 @deprecation.deprecated_args(None, "Use the `axis` argument instead",
    219                              "dimension")
    220 @_set_doc(
    221     gen_math_ops.arg_min.__doc__.replace("dimensions", "axes").replace(
    222         "dimension", "axis"))
    223 def argmin(input,
    224            axis=None,
    225            name=None,
    226            dimension=None,
    227            output_type=dtypes.int64):
    228   if dimension is not None:
    229     if axis is not None:
    230       raise ValueError("Cannot specify both 'axis' and 'dimension'")
    231     axis = dimension
    232   elif axis is None:
    233     axis = 0
    234   return gen_math_ops.arg_min(input, axis, name=name, output_type=output_type)
    235 
    236 
    237 # pylint: enable=redefined-builtin
    238 
    239 
    240 # pylint: disable=anomalous-backslash-in-string,protected-access
    241 # pylint: disable=g-docstring-has-escape
    242 @tf_export("abs")
    243 def abs(x, name=None):  # pylint: disable=redefined-builtin
    244   r"""Computes the absolute value of a tensor.
    245 
    246   Given a tensor `x` of complex numbers, this operation returns a tensor of type
    247   `float32` or `float64` that is the absolute value of each element in `x`. All
    248   elements in `x` must be complex numbers of the form \\(a + bj\\). The
    249   absolute value is computed as \\( \sqrt{a^2 + b^2}\\).  For example:
    250   ```python
    251   x = tf.constant([[-2.25 + 4.75j], [-3.25 + 5.75j]])
    252   tf.abs(x)  # [5.25594902, 6.60492229]
    253   ```
    254 
    255   Args:
    256     x: A `Tensor` or `SparseTensor` of type `float32`, `float64`, `int32`,
    257       `int64`, `complex64` or `complex128`.
    258     name: A name for the operation (optional).
    259 
    260   Returns:
    261     A `Tensor` or `SparseTensor` the same size and type as `x` with absolute
    262       values.
    263     Note, for `complex64` or `complex128` input, the returned `Tensor` will be
    264       of type `float32` or `float64`, respectively.
    265   """
    266   with ops.name_scope(name, "Abs", [x]) as name:
    267     if isinstance(x, sparse_tensor.SparseTensor):
    268       if x.values.dtype.is_complex:
    269         x_abs = gen_math_ops._complex_abs(
    270             x.values, Tout=x.values.dtype.real_dtype, name=name)
    271         return sparse_tensor.SparseTensor(
    272             indices=x.indices, values=x_abs, dense_shape=x.dense_shape)
    273       x_abs = gen_math_ops._abs(x.values, name=name)
    274       return sparse_tensor.SparseTensor(
    275           indices=x.indices, values=x_abs, dense_shape=x.dense_shape)
    276     else:
    277       x = ops.convert_to_tensor(x, name="x")
    278       if x.dtype.is_complex:
    279         return gen_math_ops._complex_abs(x, Tout=x.dtype.real_dtype, name=name)
    280       return gen_math_ops._abs(x, name=name)
    281 
    282 
    283 # pylint: enable=g-docstring-has-escape
    284 
    285 
    286 # pylint: disable=redefined-builtin
    287 def _bucketize(input, boundaries, name=None):
    288   return gen_math_ops._bucketize(input=input, boundaries=boundaries, name=name)
    289 
    290 
    291 # pylint: enable=redefined-builtin
    292 
    293 
    294 class DivideDelegateWithName(object):
    295   """Use Python2/Python3 division delegation to implement divide for tensors."""
    296 
    297   def __init__(self, x, name):
    298     """Construct DivideDelegateWithName.
    299 
    300     Args:
    301       x: Tensor to use as left operand in operator overloads
    302       name: The name that is preferred for the op created.
    303     """
    304     self.x = x
    305     self.name = name
    306 
    307   def __truediv__(self, y):
    308     return _truediv_python3(self.x, y, self.name)
    309 
    310   def __floordiv__(self, y):
    311     return floordiv(self.x, y, self.name)
    312 
    313   def __div__(self, y):
    314     return _div_python2(self.x, y, self.name)
    315 
    316 
    317 @tf_export("divide")
    318 def divide(x, y, name=None):
    319   """Computes Python style division of `x` by `y`."""
    320 
    321   if name is not None:
    322     # Cannot use tensors operator overload, because it has no way to track
    323     # override names. Use a dummy class to track the runtime division behavior
    324     return DivideDelegateWithName(x, name) / y
    325   else:
    326     return x / y
    327 
    328 
    329 @tf_export("multiply")
    330 def multiply(x, y, name=None):
    331   return gen_math_ops._mul(x, y, name)
    332 
    333 
    334 multiply.__doc__ = gen_math_ops._mul.__doc__.replace("Mul", "`tf.multiply`")
    335 
    336 
    337 # TODO(aselle): put deprecation in after another round of global code changes
    338 @deprecation.deprecated(
    339     "2016-12-30",
    340     "`tf.mul(x, y)` is deprecated, please use `tf.multiply(x, y)` or `x * y`")
    341 def _mul(x, y, name=None):
    342   return gen_math_ops._mul(x, y, name)
    343 
    344 
    345 _mul.__doc__ = (
    346     gen_math_ops._mul.__doc__ + ("" if _mul.__doc__ is None else _mul.__doc__))
    347 
    348 
    349 @tf_export("subtract")
    350 def subtract(x, y, name=None):
    351   return gen_math_ops._sub(x, y, name)
    352 
    353 
    354 subtract.__doc__ = gen_math_ops._sub.__doc__.replace("`Sub`", "`tf.subtract`")
    355 
    356 
    357 # TODO(aselle): put deprecation in after another round of global code changes
    358 @deprecation.deprecated(
    359     "2016-12-30",
    360     "`tf.sub(x, y)` is deprecated, please use `tf.subtract(x, y)` or `x - y`")
    361 def _sub(x, y, name=None):
    362   return gen_math_ops._sub(x, y, name)
    363 
    364 
    365 _sub.__doc__ = (
    366     gen_math_ops._sub.__doc__ + ("" if _sub.__doc__ is None else _sub.__doc__))
    367 
    368 
    369 # pylint: disable=g-docstring-has-escape
    370 @tf_export("negative")
    371 def negative(x, name=None):
    372   """Computes numerical negative value element-wise.
    373 
    374   I.e., \\(y = -x\\).
    375 
    376   Args:
    377     x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
    378       `float32`, `float64`, `int32`, `int64`, `complex64`, `complex128`.
    379     name: A name for the operation (optional).
    380 
    381   Returns:
    382     A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
    383   """
    384   with ops.name_scope(name, "Neg", [x]) as name:
    385     if isinstance(x, sparse_tensor.SparseTensor):
    386       x_neg = gen_math_ops._neg(x.values, name=name)
    387       return sparse_tensor.SparseTensor(
    388           indices=x.indices, values=x_neg, dense_shape=x.dense_shape)
    389     else:
    390       return gen_math_ops._neg(x, name=name)
    391 
    392 
    393 # pylint: enable=g-docstring-has-escape
    394 
    395 
    396 # pylint: disable=g-docstring-has-escape
    397 @deprecation.deprecated(
    398     "2016-12-30",
    399     "`tf.neg(x)` is deprecated, please use `tf.negative(x)` or `-x`")
    400 def _neg(x, name=None):
    401   """Computes numerical negative value element-wise.
    402 
    403   I.e., \\(y = -x\\).
    404 
    405   Args:
    406     x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
    407       `float32`, `float64`, `int32`, `int64`, `complex64`, `complex128`.
    408     name: A name for the operation (optional).
    409 
    410   Returns:
    411     A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
    412   """
    413   return negative(x, name)
    414 
    415 
    416 # pylint: enable=g-docstring-has-escape
    417 
    418 
    419 @tf_export("sign")
    420 def sign(x, name=None):
    421   """Returns an element-wise indication of the sign of a number.
    422 
    423   `y = sign(x) = -1` if `x < 0`; 0 if `x == 0` or `tf.is_nan(x)`; 1 if `x > 0`.
    424 
    425   Zero is returned for NaN inputs.
    426 
    427   For complex numbers, `y = sign(x) = x / |x|` if `x != 0`, otherwise `y = 0`.
    428 
    429   Args:
    430     x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
    431       `float32`, `float64`, `int32`, `int64`, `complex64`, `complex128`.
    432     name: A name for the operation (optional).
    433 
    434   Returns:
    435     A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
    436 
    437   @compatibility(numpy)
    438   Equivalent to numpy.sign except for the behavior for input values of NaN.
    439   @end_compatibility
    440   """
    441   with ops.name_scope(name, "Sign", [x]) as name:
    442     if isinstance(x, sparse_tensor.SparseTensor):
    443       x_sign = gen_math_ops.sign(x.values, name=name)
    444       return sparse_tensor.SparseTensor(
    445           indices=x.indices, values=x_sign, dense_shape=x.dense_shape)
    446     else:
    447       return gen_math_ops.sign(x, name=name)
    448 
    449 
    450 @tf_export("square")
    451 def square(x, name=None):
    452   r"""Computes square of x element-wise.
    453 
    454   I.e., \\(y = x * x = x^2\\).
    455 
    456   Args:
    457     x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
    458       `float32`, `float64`, `int32`, `int64`, `complex64`, `complex128`.
    459     name: A name for the operation (optional).
    460 
    461   Returns:
    462     A `Tensor` or `SparseTensor`. Has the same type as `x`.
    463   """
    464   with ops.name_scope(name, "Square", [x]) as name:
    465     if isinstance(x, sparse_tensor.SparseTensor):
    466       x_square = gen_math_ops.square(x.values, name=name)
    467       return sparse_tensor.SparseTensor(
    468           indices=x.indices, values=x_square, dense_shape=x.dense_shape)
    469     else:
    470       return gen_math_ops.square(x, name=name)
    471 
    472 
    473 @tf_export("sqrt")
    474 def sqrt(x, name=None):
    475   r"""Computes square root of x element-wise.
    476 
    477   I.e., \\(y = \sqrt{x} = x^{1/2}\\).
    478 
    479   Args:
    480     x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
    481       `float32`, `float64`, `complex64`, `complex128`.
    482     name: A name for the operation (optional).
    483 
    484   Returns:
    485     A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
    486   """
    487   with ops.name_scope(name, "Sqrt", [x]) as name:
    488     if isinstance(x, sparse_tensor.SparseTensor):
    489       x_sqrt = gen_math_ops.sqrt(x.values, name=name)
    490       return sparse_tensor.SparseTensor(
    491           indices=x.indices, values=x_sqrt, dense_shape=x.dense_shape)
    492     else:
    493       return gen_math_ops.sqrt(x, name=name)
    494 
    495 
    496 @tf_export("erf")
    497 def erf(x, name=None):
    498   """Computes the Gauss error function of `x` element-wise.
    499 
    500   Args:
    501     x: A `Tensor` of `SparseTensor`. Must be one of the following types: `half`,
    502       `float32`, `float64`.
    503     name: A name for the operation (optional).
    504 
    505   Returns:
    506     A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
    507   """
    508   with ops.name_scope(name, "Erf", [x]) as name:
    509     if isinstance(x, sparse_tensor.SparseTensor):
    510       x_erf = gen_math_ops.erf(x.values, name=name)
    511       return sparse_tensor.SparseTensor(
    512           indices=x.indices, values=x_erf, dense_shape=x.dense_shape)
    513     else:
    514       return gen_math_ops.erf(x, name=name)
    515 
    516 
    517 @tf_export("scalar_mul")
    518 def scalar_mul(scalar, x):
    519   """Multiplies a scalar times a `Tensor` or `IndexedSlices` object.
    520 
    521   Intended for use in gradient code which might deal with `IndexedSlices`
    522   objects, which are easy to multiply by a scalar but more expensive to
    523   multiply with arbitrary tensors.
    524 
    525   Args:
    526     scalar: A 0-D scalar `Tensor`. Must have known shape.
    527     x: A `Tensor` or `IndexedSlices` to be scaled.
    528 
    529   Returns:
    530     `scalar * x` of the same type (`Tensor` or `IndexedSlices`) as `x`.
    531 
    532   Raises:
    533     ValueError: if scalar is not a 0-D `scalar`.
    534   """
    535   scalar = ops.convert_to_tensor(
    536       scalar, dtype=x.dtype.base_dtype, name="scalar")
    537   shape = scalar.get_shape()
    538   if shape.ndims == 0:
    539     if isinstance(x, ops.IndexedSlices):
    540       return ops.IndexedSlices(scalar * x.values, x.indices, x.dense_shape)
    541     else:
    542       return scalar * x
    543   else:
    544     raise ValueError("Only scalar multiply works, got shape %s" % shape)
    545 
    546 
    547 @tf_export("pow")
    548 def pow(x, y, name=None):  # pylint: disable=redefined-builtin
    549   r"""Computes the power of one value to another.
    550 
    551   Given a tensor `x` and a tensor `y`, this operation computes \\(x^y\\) for
    552   corresponding elements in `x` and `y`. For example:
    553 
    554   ```python
    555   x = tf.constant([[2, 2], [3, 3]])
    556   y = tf.constant([[8, 16], [2, 3]])
    557   tf.pow(x, y)  # [[256, 65536], [9, 27]]
    558   ```
    559 
    560   Args:
    561     x: A `Tensor` of type `float32`, `float64`, `int32`, `int64`, `complex64`,
    562      or `complex128`.
    563     y: A `Tensor` of type `float32`, `float64`, `int32`, `int64`, `complex64`,
    564      or `complex128`.
    565     name: A name for the operation (optional).
    566 
    567   Returns:
    568     A `Tensor`.
    569   """
    570   with ops.name_scope(name, "Pow", [x]) as name:
    571     return gen_math_ops._pow(x, y, name=name)
    572 
    573 
    574 # pylint: disable=redefined-builtin,redefined-outer-name
    575 @tf_export("complex")
    576 def complex(real, imag, name=None):
    577   r"""Converts two real numbers to a complex number.
    578 
    579   Given a tensor `real` representing the real part of a complex number, and a
    580   tensor `imag` representing the imaginary part of a complex number, this
    581   operation returns complex numbers elementwise of the form \\(a + bj\\), where
    582   *a* represents the `real` part and *b* represents the `imag` part.
    583 
    584   The input tensors `real` and `imag` must have the same shape.
    585 
    586   For example:
    587 
    588   ```python
    589   real = tf.constant([2.25, 3.25])
    590   imag = tf.constant([4.75, 5.75])
    591   tf.complex(real, imag)  # [[2.25 + 4.75j], [3.25 + 5.75j]]
    592   ```
    593 
    594   Args:
    595     real: A `Tensor`. Must be one of the following types: `float32`,
    596       `float64`.
    597     imag: A `Tensor`. Must have the same type as `real`.
    598     name: A name for the operation (optional).
    599 
    600   Returns:
    601     A `Tensor` of type `complex64` or `complex128`.
    602   """
    603   real = ops.convert_to_tensor(real, name="real")
    604   imag = ops.convert_to_tensor(imag, name="imag")
    605   with ops.name_scope(name, "Complex", [real, imag]) as name:
    606     input_types = (real.dtype, imag.dtype)
    607     if input_types == (dtypes.float64, dtypes.float64):
    608       Tout = dtypes.complex128
    609     elif input_types == (dtypes.float32, dtypes.float32):
    610       Tout = dtypes.complex64
    611     else:
    612       raise TypeError("real and imag have incorrect types: "
    613                       "{} {}".format(real.dtype.name, imag.dtype.name))
    614     return gen_math_ops._complex(real, imag, Tout=Tout, name=name)
    615 
    616 
    617 @tf_export("real")
    618 def real(input, name=None):
    619   r"""Returns the real part of a complex (or real) tensor.
    620 
    621   Given a tensor `input`, this operation returns a tensor of type `float` that
    622   is the real part of each element in `input` considered as a complex number.
    623 
    624   For example:
    625 
    626   ```python
    627   x = tf.constant([-2.25 + 4.75j, 3.25 + 5.75j])
    628   tf.real(x)  # [-2.25, 3.25]
    629   ```
    630 
    631   If `input` is already real, it is returned unchanged.
    632 
    633   Args:
    634     input: A `Tensor`. Must have numeric type.
    635     name: A name for the operation (optional).
    636 
    637   Returns:
    638     A `Tensor` of type `float32` or `float64`.
    639   """
    640   with ops.name_scope(name, "Real", [input]) as name:
    641     if input.dtype.is_complex:
    642       real_dtype = input.dtype.real_dtype
    643       return gen_math_ops.real(input, Tout=real_dtype, name=name)
    644     else:
    645       return input
    646 
    647 
    648 @tf_export("imag")
    649 def imag(input, name=None):
    650   r"""Returns the imaginary part of a complex (or real) tensor.
    651 
    652   Given a tensor `input`, this operation returns a tensor of type `float` that
    653   is the imaginary part of each element in `input` considered as a complex
    654   number. If `input` is real, a tensor of all zeros is returned.
    655 
    656   For example:
    657 
    658   ```python
    659   x = tf.constant([-2.25 + 4.75j, 3.25 + 5.75j])
    660   tf.imag(x)  # [4.75, 5.75]
    661   ```
    662 
    663   Args:
    664     input: A `Tensor`. Must be one of the following types: `float`, `double`,
    665       `complex64`, `complex128`.
    666     name: A name for the operation (optional).
    667 
    668   Returns:
    669     A `Tensor` of type `float32` or `float64`.
    670   """
    671   with ops.name_scope(name, "Imag", [input]) as name:
    672     if input.dtype.is_complex:
    673       return gen_math_ops.imag(input, Tout=input.dtype.real_dtype, name=name)
    674     else:
    675       return array_ops.zeros_like(input)
    676 
    677 
    678 @tf_export("angle")
    679 def angle(input, name=None):
    680   r"""Returns the element-wise argument of a complex (or real) tensor.
    681 
    682   Given a tensor `input`, this operation returns a tensor of type `float` that
    683   is the argument of each element in `input` considered as a complex number.
    684 
    685   The elements in `input` are considered to be complex numbers of the form
    686   \\(a + bj\\), where *a* is the real part and *b* is the imaginary part.
    687   If `input` is real then *b* is zero by definition.
    688 
    689   The argument returned by this function is of the form \\(atan2(b, a)\\).
    690   If `input` is real, a tensor of all zeros is returned.
    691 
    692   For example:
    693 
    694   ```
    695   # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
    696   tf.angle(input) ==> [2.0132, 1.056]
    697   ```
    698 
    699   Args:
    700     input: A `Tensor`. Must be one of the following types: `float`, `double`,
    701       `complex64`, `complex128`.
    702     name: A name for the operation (optional).
    703 
    704   Returns:
    705     A `Tensor` of type `float32` or `float64`.
    706   """
    707   with ops.name_scope(name, "Angle", [input]) as name:
    708     if input.dtype.is_complex:
    709       return gen_math_ops.angle(input, Tout=input.dtype.real_dtype, name=name)
    710     else:
    711       return array_ops.zeros_like(input)
    712 
    713 
    714 # pylint: enable=redefined-outer-name,redefined-builtin
    715 
    716 
    717 @tf_export("round")
    718 def round(x, name=None):  # pylint: disable=redefined-builtin
    719   """Rounds the values of a tensor to the nearest integer, element-wise.
    720 
    721   Rounds half to even.  Also known as bankers rounding. If you want to round
    722   according to the current system rounding mode use tf::cint.
    723   For example:
    724 
    725   ```python
    726   x = tf.constant([0.9, 2.5, 2.3, 1.5, -4.5])
    727   tf.round(x)  # [ 1.0, 2.0, 2.0, 2.0, -4.0 ]
    728   ```
    729 
    730   Args:
    731     x: A `Tensor` of type `float32` or `float64`.
    732     name: A name for the operation (optional).
    733 
    734   Returns:
    735     A `Tensor` of same shape and type as `x`.
    736   """
    737   x = ops.convert_to_tensor(x, name="x")
    738   if x.dtype.is_integer:
    739     return x
    740   else:
    741     return gen_math_ops.round(x, name=name)
    742 
    743 
    744 @tf_export("cast")
    745 def cast(x, dtype, name=None):
    746   """Casts a tensor to a new type.
    747 
    748   The operation casts `x` (in case of `Tensor`) or `x.values`
    749   (in case of `SparseTensor`) to `dtype`.
    750 
    751   For example:
    752 
    753   ```python
    754   x = tf.constant([1.8, 2.2], dtype=tf.float32)
    755   tf.cast(x, tf.int32)  # [1, 2], dtype=tf.int32
    756   ```
    757 
    758   Args:
    759     x: A `Tensor` or `SparseTensor`.
    760     dtype: The destination type.
    761     name: A name for the operation (optional).
    762 
    763   Returns:
    764     A `Tensor` or `SparseTensor` with same shape as `x`.
    765 
    766   Raises:
    767     TypeError: If `x` cannot be cast to the `dtype`.
    768   """
    769   base_type = dtypes.as_dtype(dtype).base_dtype
    770   with ops.name_scope(name, "Cast", [x]) as name:
    771     if isinstance(x, sparse_tensor.SparseTensor):
    772       values_cast = cast(x.values, base_type, name=name)
    773       return sparse_tensor.SparseTensor(x.indices, values_cast, x.dense_shape)
    774     else:
    775       # TODO(josh11b): If x is not already a Tensor, we could return
    776       # ops.convert_to_tensor(x, dtype=dtype, ...)  here, but that
    777       # allows some conversions that cast() can't do, e.g. casting numbers to
    778       # strings.
    779       x = ops.convert_to_tensor(x, name="x")
    780       if x.dtype.base_dtype == base_type:
    781         return x
    782       return gen_math_ops.cast(x, base_type, name=name)
    783 
    784 
    785 @tf_export("saturate_cast")
    786 def saturate_cast(value, dtype, name=None):
    787   """Performs a safe saturating cast of `value` to `dtype`.
    788 
    789   This function casts the input to `dtype` without applying any scaling.  If
    790   there is a danger that values would over or underflow in the cast, this op
    791   applies the appropriate clamping before the cast.
    792 
    793   Args:
    794     value: A `Tensor`.
    795     dtype: The desired output `DType`.
    796     name: A name for the operation (optional).
    797 
    798   Returns:
    799     `value` safely cast to `dtype`.
    800   """
    801   # When casting to a type with smaller representable range, clamp.
    802   # Note that this covers casting to unsigned types as well.
    803   with ops.name_scope(name, "saturate_cast", [value]) as name:
    804     value = ops.convert_to_tensor(value, name="value")
    805     dtype = dtypes.as_dtype(dtype).base_dtype
    806     if value.dtype.min < dtype.min:
    807       value = gen_math_ops.maximum(value,
    808                                    ops.convert_to_tensor(
    809                                        dtype.min, dtype=value.dtype,
    810                                        name="min"))
    811     if value.dtype.max > dtype.max:
    812       value = gen_math_ops.minimum(value,
    813                                    ops.convert_to_tensor(
    814                                        dtype.max, dtype=value.dtype,
    815                                        name="max"))
    816     return cast(value, dtype, name=name)
    817 
    818 
    819 @tf_export("to_float")
    820 def to_float(x, name="ToFloat"):
    821   """Casts a tensor to type `float32`.
    822 
    823   Args:
    824     x: A `Tensor` or `SparseTensor`.
    825     name: A name for the operation (optional).
    826 
    827   Returns:
    828     A `Tensor` or `SparseTensor` with same shape as `x` with type `float32`.
    829 
    830   Raises:
    831     TypeError: If `x` cannot be cast to the `float32`.
    832   """
    833   return cast(x, dtypes.float32, name=name)
    834 
    835 
    836 @tf_export("to_double")
    837 def to_double(x, name="ToDouble"):
    838   """Casts a tensor to type `float64`.
    839 
    840   Args:
    841     x: A `Tensor` or `SparseTensor`.
    842     name: A name for the operation (optional).
    843 
    844   Returns:
    845     A `Tensor` or `SparseTensor` with same shape as `x` with type `float64`.
    846 
    847   Raises:
    848     TypeError: If `x` cannot be cast to the `float64`.
    849   """
    850   return cast(x, dtypes.float64, name=name)
    851 
    852 
    853 @tf_export("to_int32")
    854 def to_int32(x, name="ToInt32"):
    855   """Casts a tensor to type `int32`.
    856 
    857   Args:
    858     x: A `Tensor` or `SparseTensor`.
    859     name: A name for the operation (optional).
    860 
    861   Returns:
    862     A `Tensor` or `SparseTensor` with same shape as `x` with type `int32`.
    863 
    864   Raises:
    865     TypeError: If `x` cannot be cast to the `int32`.
    866   """
    867   return cast(x, dtypes.int32, name=name)
    868 
    869 
    870 @tf_export("to_int64")
    871 def to_int64(x, name="ToInt64"):
    872   """Casts a tensor to type `int64`.
    873 
    874   Args:
    875     x: A `Tensor` or `SparseTensor`.
    876     name: A name for the operation (optional).
    877 
    878   Returns:
    879     A `Tensor` or `SparseTensor` with same shape as `x` with type `int64`.
    880 
    881   Raises:
    882     TypeError: If `x` cannot be cast to the `int64`.
    883   """
    884   return cast(x, dtypes.int64, name=name)
    885 
    886 
    887 @tf_export("to_bfloat16")
    888 def to_bfloat16(x, name="ToBFloat16"):
    889   """Casts a tensor to type `bfloat16`.
    890 
    891   Args:
    892     x: A `Tensor` or `SparseTensor`.
    893     name: A name for the operation (optional).
    894 
    895   Returns:
    896     A `Tensor` or `SparseTensor` with same shape as `x` with type `bfloat16`.
    897 
    898   Raises:
    899     TypeError: If `x` cannot be cast to the `bfloat16`.
    900   """
    901   return cast(x, dtypes.bfloat16, name=name)
    902 
    903 
# Unary operator overloads for `Tensor`.
# `-t` dispatches to the generated negation op.
ops.Tensor._override_operator("__neg__", gen_math_ops._neg)
# `abs(t)` dispatches to the module-level name `abs`.
# NOTE(review): presumably this module's tensor `abs`, which shadows the
# Python builtin here; confirm.
ops.Tensor._override_operator("__abs__", abs)
# __invert__ corresponds to the ~ operator.  Here we follow the numpy
# convention: ~ marks an elementwise bit-wise inverse.  It is implemented via
# logical_not, so it is only meaningful for boolean tensors and will throw a
# TypeError if used on nonboolean arrays.
ops.Tensor._override_operator("__invert__", gen_math_ops.logical_not)
    910 
    911 
    912 def _OverrideBinaryOperatorHelper(func, op_name, clazz_object=ops.Tensor):
    913   """Register operators with different tensor and scalar versions.
    914 
    915   If `clazz_object` is `SparseTensor`, assumes `func` takes `(sp_indices,
    916   sp_values, sp_shape, dense)` and outputs `(new_sp_values)`.
    917 
    918   Args:
    919     func: the operator
    920     op_name: name of the operator being overridden
    921     clazz_object: class to override for.  Either `Tensor` or `SparseTensor`.
    922   """
    923 
    924   def binary_op_wrapper(x, y):
    925     with ops.name_scope(None, op_name, [x, y]) as name:
    926       if not isinstance(y, sparse_tensor.SparseTensor):
    927         try:
    928           y = ops.convert_to_tensor(y, dtype=x.dtype.base_dtype, name="y")
    929         except TypeError:
    930           # If the RHS is not a tensor, it might be a tensor aware object
    931           # that can implement the operator with knowledge of itself
    932           # and the tensor.
    933           if hasattr(type(y), "__r%s__" % op_name):
    934             return NotImplemented
    935           else:
    936             raise
    937       return func(x, y, name=name)
    938 
    939   def binary_op_wrapper_sparse(sp_x, y):
    940     with ops.name_scope(None, op_name, [sp_x, y]) as name:
    941       y = ops.convert_to_tensor(y, dtype=sp_x.dtype.base_dtype, name="y")
    942       return sparse_tensor.SparseTensor(sp_x.indices,
    943                                         func(
    944                                             sp_x.indices,
    945                                             sp_x.values,
    946                                             sp_x.dense_shape,
    947                                             y,
    948                                             name=name), sp_x.dense_shape)
    949 
    950   def r_binary_op_wrapper(y, x):
    951     with ops.name_scope(None, op_name, [x, y]) as name:
    952       x = ops.convert_to_tensor(x, dtype=y.dtype.base_dtype, name="x")
    953       return func(x, y, name=name)
    954 
    955   # Propagate func.__doc__ to the wrappers
    956   try:
    957     doc = func.__doc__
    958   except AttributeError:
    959     doc = None
    960   binary_op_wrapper.__doc__ = doc
    961   r_binary_op_wrapper.__doc__ = doc
    962   binary_op_wrapper_sparse.__doc__ = doc
    963 
    964   if clazz_object is ops.Tensor:
    965     clazz_object._override_operator("__%s__" % op_name, binary_op_wrapper)
    966     del binary_op_wrapper
    967     clazz_object._override_operator("__r%s__" % op_name, r_binary_op_wrapper)
    968     del r_binary_op_wrapper
    969   else:
    970     clazz_object._override_operator("__%s__" % op_name,
    971                                     binary_op_wrapper_sparse)
    972     del binary_op_wrapper_sparse
    973 
    974 
# Conversion table for __truediv__, keyed by the base dtype shared by both
# operands.
#   * A dtype entry gives the floating type both operands are cast to before
#     dividing.
#   * None entries mean no conversion is required.
#   * Dtypes missing from the table are rejected with a TypeError by the
#     truediv helpers (e.g. bool).
# Small integer types widen to float32, while int32/int64 widen to float64.
# Presumably this limits precision loss, since float32 cannot represent every
# int32; the `truediv` docstring says this matches NumPy.
_TRUEDIV_TABLE = {
    dtypes.uint8: dtypes.float32,
    dtypes.int8: dtypes.float32,
    dtypes.uint16: dtypes.float32,
    dtypes.int16: dtypes.float32,
    dtypes.int32: dtypes.float64,
    dtypes.int64: dtypes.float64,
    dtypes.bfloat16: None,
    dtypes.float16: None,
    dtypes.float32: None,
    dtypes.float64: None,
    dtypes.complex64: None,
    dtypes.complex128: None,
}
    990 
    991 
    992 # NOTE: the support of "sparse (true)div dense" is currently not baked in into
    993 # "tf.(true_)div()".  Until such an API decision is made, the supported usage is
    994 # to explicitly use the "/" operator to invoke either truediv or div.
    995 def _sparse_dense_truediv(sp_indices, sp_values, sp_shape, y, name=None):
    996   """Internal helper function for 'sp_t / dense_t'."""
    997   with ops.name_scope(name, "truediv",
    998                       [sp_indices, sp_values, sp_shape, y]) as name:
    999     sp_values = ops.convert_to_tensor(sp_values, name="sp_values")
   1000     y = ops.convert_to_tensor(y, name="y")
   1001     x_dtype = sp_values.dtype.base_dtype
   1002     y_dtype = y.dtype.base_dtype
   1003     if x_dtype != y_dtype:
   1004       raise TypeError("x and y must have the same dtype, got %r != %r" %
   1005                       (x_dtype, y_dtype))
   1006     try:
   1007       dtype = _TRUEDIV_TABLE[x_dtype]
   1008     except KeyError:
   1009       raise TypeError("Invalid dtype %r in __truediv__" % x_dtype)
   1010     if dtype is not None:
   1011       sp_values = cast(sp_values, dtype)
   1012       y = cast(y, dtype)
   1013     return gen_sparse_ops.sparse_dense_cwise_div(
   1014         sp_indices, sp_values, sp_shape, y, name=name)
   1015 
   1016 
   1017 def _truediv_python3(x, y, name=None):
   1018   with ops.name_scope(name, "truediv", [x, y]) as name:
   1019     x = ops.convert_to_tensor(x, name="x")
   1020     y = ops.convert_to_tensor(y, name="y")
   1021     x_dtype = x.dtype.base_dtype
   1022     y_dtype = y.dtype.base_dtype
   1023     if x_dtype != y_dtype:
   1024       raise TypeError("x and y must have the same dtype, got %r != %r" %
   1025                       (x_dtype, y_dtype))
   1026     try:
   1027       dtype = _TRUEDIV_TABLE[x_dtype]
   1028     except KeyError:
   1029       raise TypeError("Invalid dtype %r in __truediv__" % x_dtype)
   1030     if dtype is not None:
   1031       x = cast(x, dtype)
   1032       y = cast(y, dtype)
   1033     return gen_math_ops._real_div(x, y, name=name)
   1034 
   1035 
   1036 def _div_python2(x, y, name=None):
   1037   """Divide two values using Python 2 semantics. Used for Tensor.__div__.
   1038 
   1039   Args:
   1040     x: `Tensor` numerator of real numeric type.
   1041     y: `Tensor` denominator of real numeric type.
   1042     name: A name for the operation (optional).
   1043   Returns:
   1044     `x / y` returns the quotient of x and y.
   1045   """
   1046 
   1047   with ops.name_scope(name, "div", [x, y]) as name:
   1048     x = ops.convert_to_tensor(x, name="x")
   1049     y = ops.convert_to_tensor(y, name="y", dtype=x.dtype.base_dtype)
   1050     x_dtype = x.dtype.base_dtype
   1051     y_dtype = y.dtype.base_dtype
   1052     if x_dtype != y_dtype:
   1053       raise TypeError("x and y must have the same dtype, got %r != %r" %
   1054                       (x_dtype, y_dtype))
   1055     if x_dtype.is_floating or x_dtype.is_complex:
   1056       return gen_math_ops._real_div(x, y, name=name)
   1057     else:
   1058       return gen_math_ops._floor_div(x, y, name=name)
   1059 
   1060 
   1061 @tf_export("truediv")
   1062 def truediv(x, y, name=None):
   1063   """Divides x / y elementwise (using Python 3 division operator semantics).
   1064 
   1065   NOTE: Prefer using the Tensor operator or tf.divide which obey Python
   1066   division operator semantics.
   1067 
   1068   This function forces Python 3 division operator semantics where all integer
   1069   arguments are cast to floating types first.   This op is generated by normal
   1070   `x / y` division in Python 3 and in Python 2.7 with
   1071   `from __future__ import division`.  If you want integer division that rounds
   1072   down, use `x // y` or `tf.floordiv`.
   1073 
   1074   `x` and `y` must have the same numeric type.  If the inputs are floating
   1075   point, the output will have the same type.  If the inputs are integral, the
   1076   inputs are cast to `float32` for `int8` and `int16` and `float64` for `int32`
   1077   and `int64` (matching the behavior of Numpy).
   1078 
   1079   Args:
   1080     x: `Tensor` numerator of numeric type.
   1081     y: `Tensor` denominator of numeric type.
   1082     name: A name for the operation (optional).
   1083 
   1084   Returns:
   1085     `x / y` evaluated in floating point.
   1086 
   1087   Raises:
   1088     TypeError: If `x` and `y` have different dtypes.
   1089   """
   1090   return _truediv_python3(x, y, name)
   1091 
   1092 
   1093 @tf_export("div")
   1094 def div(x, y, name=None):
   1095   """Divides x / y elementwise (using Python 2 division operator semantics).
   1096 
   1097   NOTE: Prefer using the Tensor division operator or tf.divide which obey Python
   1098   division operator semantics.
   1099 
   1100   This function divides `x` and `y`, forcing Python 2.7 semantics. That is,
   1101   if one of `x` or `y` is a float, then the result will be a float.
   1102   Otherwise, the output will be an integer type. Flooring semantics are used
   1103   for integer division.
   1104 
   1105   Args:
   1106     x: `Tensor` numerator of real numeric type.
   1107     y: `Tensor` denominator of real numeric type.
   1108     name: A name for the operation (optional).
   1109   Returns:
   1110     `x / y` returns the quotient of x and y.
   1111   """
   1112   return _div_python2(x, y, name)
   1113 
   1114 
# TODO(aselle): This should be removed
# `mod` is a plain alias of the generated floor-mod op. It is the same
# function that `floormod` is bound to further down in this module and that is
# registered as `Tensor.__mod__`.
mod = gen_math_ops._floor_mod
   1117 
   1118 
   1119 # TODO(aselle): Deprecate this once all internal functionality uses
   1120 # tf.truncatediv
   1121 @tf_export("floordiv")
   1122 def floordiv(x, y, name=None):
   1123   """Divides `x / y` elementwise, rounding toward the most negative integer.
   1124 
   1125   The same as `tf.div(x,y)` for integers, but uses `tf.floor(tf.div(x,y))` for
   1126   floating point arguments so that the result is always an integer (though
   1127   possibly an integer represented as floating point).  This op is generated by
   1128   `x // y` floor division in Python 3 and in Python 2.7 with
   1129   `from __future__ import division`.
   1130 
   1131   Note that for efficiency, `floordiv` uses C semantics for negative numbers
   1132   (unlike Python and Numpy).
   1133 
   1134   `x` and `y` must have the same type, and the result will have the same type
   1135   as well.
   1136 
   1137   Args:
   1138     x: `Tensor` numerator of real numeric type.
   1139     y: `Tensor` denominator of real numeric type.
   1140     name: A name for the operation (optional).
   1141 
   1142   Returns:
   1143     `x / y` rounded down (except possibly towards zero for negative integers).
   1144 
   1145   Raises:
   1146     TypeError: If the inputs are complex.
   1147   """
   1148   with ops.name_scope(name, "floordiv", [x, y]) as name:
   1149     return gen_math_ops._floor_div(x, y, name=name)
   1150 
   1151 
# Thin aliases that expose the generated division/modulo ops directly.
# Unlike `truediv`/`div`, they do no dtype checks or casts. Unlike `floordiv`,
# they open no name scope of their own.
realdiv = gen_math_ops._real_div
truncatediv = gen_math_ops._truncate_div
# TODO(aselle): Rename this to floordiv when we can.
floor_div = gen_math_ops._floor_div
truncatemod = gen_math_ops._truncate_mod
floormod = gen_math_ops._floor_mod
   1158 
   1159 
   1160 def _mul_dispatch(x, y, name=None):
   1161   """Dispatches cwise mul for "Dense*Dense" and "Dense*Sparse"."""
   1162   is_tensor_y = isinstance(y, ops.Tensor)
   1163   if is_tensor_y:
   1164     return gen_math_ops._mul(x, y, name=name)
   1165   else:
   1166     assert isinstance(y, sparse_tensor.SparseTensor)  # Case: Dense * Sparse.
   1167     new_vals = gen_sparse_ops.sparse_dense_cwise_mul(y.indices, y.values,
   1168                                                      y.dense_shape, x, name)
   1169     return sparse_tensor.SparseTensor(y.indices, new_vals, y.dense_shape)
   1170 
   1171 
# NOTE(aselle): When integer division is added for sparse_dense_cwise,
# div, truediv, and floordiv should be delegated appropriately for
# Python semantics, analogous to dense cwise tensor operations.
# SparseTensor operators (forward direction only: `sp <op> dense`).
# "div" uses the raw sparse-dense division op with no dtype handling, while
# "truediv" goes through `_sparse_dense_truediv`.
_OverrideBinaryOperatorHelper(gen_sparse_ops.sparse_dense_cwise_div, "div",
                              sparse_tensor.SparseTensor)
_OverrideBinaryOperatorHelper(_sparse_dense_truediv, "truediv",
                              sparse_tensor.SparseTensor)
_OverrideBinaryOperatorHelper(gen_sparse_ops.sparse_dense_cwise_mul, "mul",
                              sparse_tensor.SparseTensor)

# Dense `Tensor` arithmetic operators; each also gets its reflected (__r*__)
# variant.
_OverrideBinaryOperatorHelper(gen_math_ops.add, "add")
_OverrideBinaryOperatorHelper(gen_math_ops._sub, "sub")
_OverrideBinaryOperatorHelper(_mul_dispatch, "mul")
_OverrideBinaryOperatorHelper(_div_python2, "div")
_OverrideBinaryOperatorHelper(_truediv_python3, "truediv")
_OverrideBinaryOperatorHelper(floordiv, "floordiv")
_OverrideBinaryOperatorHelper(gen_math_ops._floor_mod, "mod")
# NOTE(review): `pow` is the module-level name; presumably this module's
# tensor `pow` (shadowing the Python builtin). Confirm.
_OverrideBinaryOperatorHelper(pow, "pow")
   1190 
   1191 
   1192 @tf_export("logical_xor")
   1193 def logical_xor(x, y, name="LogicalXor"):
   1194   """x ^ y = (x | y) & ~(x & y)."""
   1195   # TODO(alemi) Make this a cwise op if people end up relying on it.
   1196   return gen_math_ops.logical_and(
   1197       gen_math_ops.logical_or(x, y),
   1198       gen_math_ops.logical_not(gen_math_ops.logical_and(x, y)),
   1199       name=name)
   1200 
   1201 
# Python's `&`, `|` and `^` on tensors (and their reflected forms) map to the
# elementwise logical ops, so they are presumably only meaningful for boolean
# tensors.
_OverrideBinaryOperatorHelper(gen_math_ops.logical_and, "and")
_OverrideBinaryOperatorHelper(gen_math_ops.logical_or, "or")
_OverrideBinaryOperatorHelper(logical_xor, "xor")

# Ordering comparisons dispatch directly to the elementwise comparison ops.
# Only `<`, `<=`, `>` and `>=` are overridden here.
ops.Tensor._override_operator("__lt__", gen_math_ops.less)
ops.Tensor._override_operator("__le__", gen_math_ops.less_equal)
ops.Tensor._override_operator("__gt__", gen_math_ops.greater)
ops.Tensor._override_operator("__ge__", gen_math_ops.greater_equal)
   1210 
   1211 
   1212 @tf_export("range")
   1213 def range(start, limit=None, delta=1, dtype=None, name="range"):  # pylint: disable=redefined-builtin
   1214   """Creates a sequence of numbers.
   1215 
   1216   Creates a sequence of numbers that begins at `start` and extends by
   1217   increments of `delta` up to but not including `limit`.
   1218 
   1219   The dtype of the resulting tensor is inferred from the inputs unless
   1220   it is provided explicitly.
   1221 
   1222   Like the Python builtin `range`, `start` defaults to 0, so that
   1223   `range(n) = range(0, n)`.
   1224 
   1225   For example:
   1226 
   1227   ```python
   1228   start = 3
   1229   limit = 18
   1230   delta = 3
   1231   tf.range(start, limit, delta)  # [3, 6, 9, 12, 15]
   1232 
   1233   start = 3
   1234   limit = 1
   1235   delta = -0.5
   1236   tf.range(start, limit, delta)  # [3, 2.5, 2, 1.5]
   1237 
   1238   limit = 5
   1239   tf.range(limit)  # [0, 1, 2, 3, 4]
   1240   ```
   1241 
   1242   Args:
   1243     start: A 0-D `Tensor` (scalar). Acts as first entry in the range if
   1244       `limit` is not None; otherwise, acts as range limit and first entry
   1245       defaults to 0.
   1246     limit: A 0-D `Tensor` (scalar). Upper limit of sequence,
   1247       exclusive. If None, defaults to the value of `start` while the first
   1248       entry of the range defaults to 0.
   1249     delta: A 0-D `Tensor` (scalar). Number that increments
   1250       `start`. Defaults to 1.
   1251     dtype: The type of the elements of the resulting tensor.
   1252     name: A name for the operation. Defaults to "range".
   1253 
   1254   Returns:
   1255     An 1-D `Tensor` of type `dtype`.
   1256 
   1257   @compatibility(numpy)
   1258   Equivalent to np.arange
   1259   @end_compatibility
   1260   """
   1261   if limit is None:
   1262     start, limit = 0, start
   1263 
   1264   with ops.name_scope(name, "Range", [start, limit, delta]) as name:
   1265     start = ops.convert_to_tensor(start, dtype=dtype, name="start")
   1266     limit = ops.convert_to_tensor(limit, dtype=dtype, name="limit")
   1267     delta = ops.convert_to_tensor(delta, dtype=dtype, name="delta")
   1268 
   1269     # infer dtype if not explicitly provided
   1270     if dtype is None:
   1271       dtype_hierarchy = [
   1272           dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64
   1273       ]
   1274       assert all(arg.dtype in dtype_hierarchy for arg in [start, limit, delta])
   1275       inferred_dtype = max(
   1276           [arg.dtype for arg in [start, limit, delta]],
   1277           key=dtype_hierarchy.index)
   1278 
   1279       start = cast(start, inferred_dtype)
   1280       limit = cast(limit, inferred_dtype)
   1281       delta = cast(delta, inferred_dtype)
   1282 
   1283     return gen_math_ops._range(start, limit, delta, name=name)
   1284 
   1285 
   1286 # Reduction operations
   1287 def _ReductionDims(x, axis, reduction_indices):
   1288   """Returns range(0, rank(x)) if reduction_indices is None."""
   1289   # TODO(aselle): Remove this after deprecation
   1290   if reduction_indices is not None:
   1291     if axis is not None:
   1292       raise ValueError("Can't specify both axis' and 'reduction_indices'.")
   1293     axis = reduction_indices
   1294   if axis is not None:
   1295     return axis
   1296   else:
   1297     # Fast path: avoid creating Rank and Range ops if ndims is known.
   1298     if isinstance(x, ops.Tensor) and x.get_shape().ndims is not None:
   1299       return constant_op.constant(
   1300           np.arange(x.get_shape().ndims), dtype=dtypes.int32)
   1301     if (isinstance(x, sparse_tensor.SparseTensor) and
   1302         x.dense_shape.get_shape().is_fully_defined()):
   1303       rank = x.dense_shape.get_shape()[0].value  # sparse.dense_shape is 1-D.
   1304       return constant_op.constant(np.arange(rank), dtype=dtypes.int32)
   1305 
   1306     # Otherwise, we rely on Range and Rank to do the right thing at run-time.
   1307     return range(0, array_ops.rank(x))
   1308 
   1309 
   1310 def _may_reduce_to_scalar(keepdims, axis, reduction_indices, output):
   1311   """Set a reduction's output's shape to be a scalar if we are certain."""
   1312   if (not output.shape.is_fully_defined()) and (not keepdims) and (
   1313       axis is None) and (reduction_indices is None):
   1314     output.set_shape(())
   1315   return output
   1316 
   1317 
   1318 @tf_export("reduce_sum")
   1319 @deprecation.deprecated_args(
   1320     None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
   1321 def reduce_sum(input_tensor,
   1322                axis=None,
   1323                keepdims=None,
   1324                name=None,
   1325                reduction_indices=None,
   1326                keep_dims=None):
   1327   """Computes the sum of elements across dimensions of a tensor.
   1328 
   1329   Reduces `input_tensor` along the dimensions given in `axis`.
   1330   Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
   1331   entry in `axis`. If `keepdims` is true, the reduced dimensions
   1332   are retained with length 1.
   1333 
   1334   If `axis` has no entries, all dimensions are reduced, and a
   1335   tensor with a single element is returned.
   1336 
   1337   For example:
   1338 
   1339   ```python
   1340   x = tf.constant([[1, 1, 1], [1, 1, 1]])
   1341   tf.reduce_sum(x)  # 6
   1342   tf.reduce_sum(x, 0)  # [2, 2, 2]
   1343   tf.reduce_sum(x, 1)  # [3, 3]
   1344   tf.reduce_sum(x, 1, keepdims=True)  # [[3], [3]]
   1345   tf.reduce_sum(x, [0, 1])  # 6
   1346   ```
   1347 
   1348   Args:
   1349     input_tensor: The tensor to reduce. Should have numeric type.
   1350     axis: The dimensions to reduce. If `None` (the default),
   1351       reduces all dimensions. Must be in the range
   1352       `[-rank(input_tensor), rank(input_tensor))`.
   1353     keepdims: If true, retains reduced dimensions with length 1.
   1354     name: A name for the operation (optional).
   1355     reduction_indices: The old (deprecated) name for axis.
   1356     keep_dims: Deprecated alias for `keepdims`.
   1357 
   1358   Returns:
   1359     The reduced tensor.
   1360 
   1361   @compatibility(numpy)
   1362   Equivalent to np.sum
   1363   @end_compatibility
   1364   """
   1365   keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
   1366                                                     "keep_dims", keep_dims)
   1367   if keepdims is None:
   1368     keepdims = False
   1369 
   1370   return _may_reduce_to_scalar(keepdims, axis, reduction_indices,
   1371                                gen_math_ops._sum(
   1372                                    input_tensor,
   1373                                    _ReductionDims(input_tensor, axis,
   1374                                                   reduction_indices),
   1375                                    keepdims,
   1376                                    name=name))
   1377 
   1378 
   1379 @tf_export("count_nonzero")
   1380 @deprecation.deprecated_args(
   1381     None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
   1382 def count_nonzero(input_tensor,
   1383                   axis=None,
   1384                   keepdims=None,
   1385                   dtype=dtypes.int64,
   1386                   name=None,
   1387                   reduction_indices=None,
   1388                   keep_dims=None):
   1389   """Computes number of nonzero elements across dimensions of a tensor.
   1390 
   1391   Reduces `input_tensor` along the dimensions given in `axis`.
   1392   Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
   1393   entry in `axis`. If `keepdims` is true, the reduced dimensions
   1394   are retained with length 1.
   1395 
   1396   If `axis` has no entries, all dimensions are reduced, and a
   1397   tensor with a single element is returned.
   1398 
   1399   **NOTE** Floating point comparison to zero is done by exact floating point
   1400   equality check.  Small values are **not** rounded to zero for purposes of
   1401   the nonzero check.
   1402 
   1403   For example:
   1404 
   1405   ```python
   1406   x = tf.constant([[0, 1, 0], [1, 1, 0]])
   1407   tf.count_nonzero(x)  # 3
   1408   tf.count_nonzero(x, 0)  # [1, 2, 0]
   1409   tf.count_nonzero(x, 1)  # [1, 2]
   1410   tf.count_nonzero(x, 1, keepdims=True)  # [[1], [2]]
   1411   tf.count_nonzero(x, [0, 1])  # 3
   1412   ```
   1413 
   1414   Args:
   1415     input_tensor: The tensor to reduce. Should be of numeric type, or `bool`.
   1416     axis: The dimensions to reduce. If `None` (the default),
   1417       reduces all dimensions. Must be in the range
   1418       `[-rank(input_tensor), rank(input_tensor))`.
   1419     keepdims: If true, retains reduced dimensions with length 1.
   1420     dtype: The output dtype; defaults to `tf.int64`.
   1421     name: A name for the operation (optional).
   1422     reduction_indices: The old (deprecated) name for axis.
   1423     keep_dims: Deprecated alias for `keepdims`.
   1424 
   1425   Returns:
   1426     The reduced tensor (number of nonzero values).
   1427   """
   1428   keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
   1429                                                     "keep_dims", keep_dims)
   1430   if keepdims is None:
   1431     keepdims = False
   1432 
   1433   with ops.name_scope(name, "count_nonzero", [input_tensor]):
   1434     input_tensor = ops.convert_to_tensor(input_tensor, name="input_tensor")
   1435     zero = input_tensor.dtype.as_numpy_dtype()
   1436     return cast(
   1437         reduce_sum(
   1438             # int64 reduction happens on GPU
   1439             to_int64(gen_math_ops.not_equal(input_tensor, zero)),
   1440             axis=axis,
   1441             keepdims=keepdims,
   1442             reduction_indices=reduction_indices),
   1443         dtype=dtype)
   1444 
   1445 
   1446 @tf_export("reduce_mean")
   1447 @deprecation.deprecated_args(
   1448     None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
   1449 def reduce_mean(input_tensor,
   1450                 axis=None,
   1451                 keepdims=None,
   1452                 name=None,
   1453                 reduction_indices=None,
   1454                 keep_dims=None):
   1455   """Computes the mean of elements across dimensions of a tensor.
   1456 
   1457   Reduces `input_tensor` along the dimensions given in `axis`.
   1458   Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
   1459   entry in `axis`. If `keepdims` is true, the reduced dimensions
   1460   are retained with length 1.
   1461 
   1462   If `axis` has no entries, all dimensions are reduced, and a
   1463   tensor with a single element is returned.
   1464 
   1465   For example:
   1466 
   1467   ```python
   1468   x = tf.constant([[1., 1.], [2., 2.]])
   1469   tf.reduce_mean(x)  # 1.5
   1470   tf.reduce_mean(x, 0)  # [1.5, 1.5]
   1471   tf.reduce_mean(x, 1)  # [1.,  2.]
   1472   ```
   1473 
   1474   Args:
   1475     input_tensor: The tensor to reduce. Should have numeric type.
   1476     axis: The dimensions to reduce. If `None` (the default),
   1477       reduces all dimensions. Must be in the range
   1478       `[-rank(input_tensor), rank(input_tensor)]`.
   1479     keepdims: If true, retains reduced dimensions with length 1.
   1480     name: A name for the operation (optional).
   1481     reduction_indices: The old (deprecated) name for axis.
   1482     keep_dims: Deprecated alias for `keepdims`.
   1483 
   1484   Returns:
   1485     The reduced tensor.
   1486 
   1487   @compatibility(numpy)
   1488   Equivalent to np.mean
   1489 
   1490   Please note that `np.mean` has a `dtype` parameter that could be used to
   1491   specify the output type. By default this is `dtype=float64`. On the other
   1492   hand, `tf.reduce_mean` has an aggressive type inference from `input_tensor`,
   1493   for example:
   1494 
   1495   ```python
   1496   x = tf.constant([1, 0, 1, 0])
   1497   tf.reduce_mean(x)  # 0
   1498   y = tf.constant([1., 0., 1., 0.])
   1499   tf.reduce_mean(y)  # 0.5
   1500   ```
   1501 
   1502   @end_compatibility
   1503   """
   1504   keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
   1505                                                     "keep_dims", keep_dims)
   1506 
   1507   if keepdims is None:
   1508     keepdims = False
   1509   return _may_reduce_to_scalar(keepdims, axis, reduction_indices,
   1510                                gen_math_ops._mean(
   1511                                    input_tensor,
   1512                                    _ReductionDims(input_tensor, axis,
   1513                                                   reduction_indices),
   1514                                    keepdims,
   1515                                    name=name))
   1516 
   1517 
   1518 @tf_export("reduce_prod")
   1519 @deprecation.deprecated_args(
   1520     None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
   1521 def reduce_prod(input_tensor,
   1522                 axis=None,
   1523                 keepdims=None,
   1524                 name=None,
   1525                 reduction_indices=None,
   1526                 keep_dims=None):
   1527   """Computes the product of elements across dimensions of a tensor.
   1528 
   1529   Reduces `input_tensor` along the dimensions given in `axis`.
   1530   Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
   1531   entry in `axis`. If `keepdims` is true, the reduced dimensions
   1532   are retained with length 1.
   1533 
   1534   If `axis` has no entries, all dimensions are reduced, and a
   1535   tensor with a single element is returned.
   1536 
   1537   Args:
   1538     input_tensor: The tensor to reduce. Should have numeric type.
   1539     axis: The dimensions to reduce. If `None` (the default),
   1540       reduces all dimensions. Must be in the range
   1541       `[-rank(input_tensor), rank(input_tensor))`.
   1542     keepdims: If true, retains reduced dimensions with length 1.
   1543     name: A name for the operation (optional).
   1544     reduction_indices: The old (deprecated) name for axis.
   1545     keep_dims: Deprecated alias for `keepdims`.
   1546 
   1547   Returns:
   1548     The reduced tensor.
   1549 
   1550   @compatibility(numpy)
   1551   Equivalent to np.prod
   1552   @end_compatibility
   1553   """
   1554   keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
   1555                                                     "keep_dims", keep_dims)
   1556 
   1557   if keepdims is None:
   1558     keepdims = False
   1559   return _may_reduce_to_scalar(keepdims, axis, reduction_indices,
   1560                                gen_math_ops._prod(
   1561                                    input_tensor,
   1562                                    _ReductionDims(input_tensor, axis,
   1563                                                   reduction_indices),
   1564                                    keepdims,
   1565                                    name=name))
   1566 
   1567 
   1568 @tf_export("reduce_min")
   1569 @deprecation.deprecated_args(
   1570     None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
   1571 def reduce_min(input_tensor,
   1572                axis=None,
   1573                keepdims=None,
   1574                name=None,
   1575                reduction_indices=None,
   1576                keep_dims=None):
   1577   """Computes the minimum of elements across dimensions of a tensor.
   1578 
   1579   Reduces `input_tensor` along the dimensions given in `axis`.
   1580   Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
   1581   entry in `axis`. If `keepdims` is true, the reduced dimensions
   1582   are retained with length 1.
   1583 
   1584   If `axis` has no entries, all dimensions are reduced, and a
   1585   tensor with a single element is returned.
   1586 
   1587   Args:
   1588     input_tensor: The tensor to reduce. Should have numeric type.
   1589     axis: The dimensions to reduce. If `None` (the default),
   1590       reduces all dimensions. Must be in the range
   1591       `[-rank(input_tensor), rank(input_tensor))`.
   1592     keepdims: If true, retains reduced dimensions with length 1.
   1593     name: A name for the operation (optional).
   1594     reduction_indices: The old (deprecated) name for axis.
   1595     keep_dims: Deprecated alias for `keepdims`.
   1596 
   1597   Returns:
   1598     The reduced tensor.
   1599 
   1600   @compatibility(numpy)
   1601   Equivalent to np.min
   1602   @end_compatibility
   1603   """
   1604   keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
   1605                                                     "keep_dims", keep_dims)
   1606   if keepdims is None:
   1607     keepdims = False
   1608   return _may_reduce_to_scalar(keepdims, axis, reduction_indices,
   1609                                gen_math_ops._min(
   1610                                    input_tensor,
   1611                                    _ReductionDims(input_tensor, axis,
   1612                                                   reduction_indices),
   1613                                    keepdims,
   1614                                    name=name))
   1615 
   1616 
   1617 @tf_export("reduce_max")
   1618 @deprecation.deprecated_args(
   1619     None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
   1620 def reduce_max(input_tensor,
   1621                axis=None,
   1622                keepdims=None,
   1623                name=None,
   1624                reduction_indices=None,
   1625                keep_dims=None):
   1626   """Computes the maximum of elements across dimensions of a tensor.
   1627 
   1628   Reduces `input_tensor` along the dimensions given in `axis`.
   1629   Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
   1630   entry in `axis`. If `keepdims` is true, the reduced dimensions
   1631   are retained with length 1.
   1632 
   1633   If `axis` has no entries, all dimensions are reduced, and a
   1634   tensor with a single element is returned.
   1635 
   1636   Args:
   1637     input_tensor: The tensor to reduce. Should have numeric type.
   1638     axis: The dimensions to reduce. If `None` (the default),
   1639       reduces all dimensions. Must be in the range
   1640       `[-rank(input_tensor), rank(input_tensor))`.
   1641     keepdims: If true, retains reduced dimensions with length 1.
   1642     name: A name for the operation (optional).
   1643     reduction_indices: The old (deprecated) name for axis.
   1644     keep_dims: Deprecated alias for `keepdims`.
   1645 
   1646   Returns:
   1647     The reduced tensor.
   1648 
   1649   @compatibility(numpy)
   1650   Equivalent to np.max
   1651   @end_compatibility
   1652   """
   1653   keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
   1654                                                     "keep_dims", keep_dims)
   1655   if keepdims is None:
   1656     keepdims = False
   1657   return _may_reduce_to_scalar(keepdims, axis, reduction_indices,
   1658                                gen_math_ops._max(
   1659                                    input_tensor,
   1660                                    _ReductionDims(input_tensor, axis,
   1661                                                   reduction_indices),
   1662                                    keepdims,
   1663                                    name=name))
   1664 
   1665 
   1666 @tf_export("reduce_all")
   1667 @deprecation.deprecated_args(
   1668     None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
   1669 def reduce_all(input_tensor,
   1670                axis=None,
   1671                keepdims=None,
   1672                name=None,
   1673                reduction_indices=None,
   1674                keep_dims=None):
   1675   """Computes the "logical and" of elements across dimensions of a tensor.
   1676 
   1677   Reduces `input_tensor` along the dimensions given in `axis`.
   1678   Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
   1679   entry in `axis`. If `keepdims` is true, the reduced dimensions
   1680   are retained with length 1.
   1681 
   1682   If `axis` has no entries, all dimensions are reduced, and a
   1683   tensor with a single element is returned.
   1684 
   1685   For example:
   1686 
   1687   ```python
   1688   x = tf.constant([[True,  True], [False, False]])
   1689   tf.reduce_all(x)  # False
   1690   tf.reduce_all(x, 0)  # [False, False]
   1691   tf.reduce_all(x, 1)  # [True, False]
   1692   ```
   1693 
   1694   Args:
   1695     input_tensor: The boolean tensor to reduce.
   1696     axis: The dimensions to reduce. If `None` (the default),
   1697       reduces all dimensions. Must be in the range
   1698       `[-rank(input_tensor), rank(input_tensor))`.
   1699     keepdims: If true, retains reduced dimensions with length 1.
   1700     name: A name for the operation (optional).
   1701     reduction_indices: The old (deprecated) name for axis.
   1702     keep_dims: Deprecated alias for `keepdims`.
   1703 
   1704   Returns:
   1705     The reduced tensor.
   1706 
   1707   @compatibility(numpy)
   1708   Equivalent to np.all
   1709   @end_compatibility
   1710   """
   1711   keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
   1712                                                     "keep_dims", keep_dims)
   1713   if keepdims is None:
   1714     keepdims = False
   1715   return _may_reduce_to_scalar(keepdims, axis, reduction_indices,
   1716                                gen_math_ops._all(
   1717                                    input_tensor,
   1718                                    _ReductionDims(input_tensor, axis,
   1719                                                   reduction_indices),
   1720                                    keepdims,
   1721                                    name=name))
   1722 
   1723 
   1724 @tf_export("reduce_any")
   1725 @deprecation.deprecated_args(
   1726     None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
   1727 def reduce_any(input_tensor,
   1728                axis=None,
   1729                keepdims=None,
   1730                name=None,
   1731                reduction_indices=None,
   1732                keep_dims=None):
   1733   """Computes the "logical or" of elements across dimensions of a tensor.
   1734 
   1735   Reduces `input_tensor` along the dimensions given in `axis`.
   1736   Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
   1737   entry in `axis`. If `keepdims` is true, the reduced dimensions
   1738   are retained with length 1.
   1739 
   1740   If `axis` has no entries, all dimensions are reduced, and a
   1741   tensor with a single element is returned.
   1742 
   1743   For example:
   1744 
   1745   ```python
   1746   x = tf.constant([[True,  True], [False, False]])
   1747   tf.reduce_any(x)  # True
   1748   tf.reduce_any(x, 0)  # [True, True]
   1749   tf.reduce_any(x, 1)  # [True, False]
   1750   ```
   1751 
   1752   Args:
   1753     input_tensor: The boolean tensor to reduce.
   1754     axis: The dimensions to reduce. If `None` (the default),
   1755       reduces all dimensions. Must be in the range
   1756       `[-rank(input_tensor), rank(input_tensor))`.
   1757     keepdims: If true, retains reduced dimensions with length 1.
   1758     name: A name for the operation (optional).
   1759     reduction_indices: The old (deprecated) name for axis.
   1760     keep_dims: Deprecated alias for `keepdims`.
   1761 
   1762   Returns:
   1763     The reduced tensor.
   1764 
   1765   @compatibility(numpy)
   1766   Equivalent to np.any
   1767   @end_compatibility
   1768   """
   1769   keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
   1770                                                     "keep_dims", keep_dims)
   1771   if keepdims is None:
   1772     keepdims = False
   1773   return _may_reduce_to_scalar(keepdims, axis, reduction_indices,
   1774                                gen_math_ops._any(
   1775                                    input_tensor,
   1776                                    _ReductionDims(input_tensor, axis,
   1777                                                   reduction_indices),
   1778                                    keepdims,
   1779                                    name=name))
   1780 
   1781 
   1782 @tf_export("reduce_logsumexp")
   1783 @deprecation.deprecated_args(
   1784     None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
   1785 def reduce_logsumexp(input_tensor,
   1786                      axis=None,
   1787                      keepdims=None,
   1788                      name=None,
   1789                      reduction_indices=None,
   1790                      keep_dims=None):
   1791   """Computes log(sum(exp(elements across dimensions of a tensor))).
   1792 
   1793   Reduces `input_tensor` along the dimensions given in `axis`.
   1794   Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
   1795   entry in `axis`. If `keepdims` is true, the reduced dimensions
   1796   are retained with length 1.
   1797 
   1798   If `axis` has no entries, all dimensions are reduced, and a
   1799   tensor with a single element is returned.
   1800 
   1801   This function is more numerically stable than log(sum(exp(input))). It avoids
   1802   overflows caused by taking the exp of large inputs and underflows caused by
   1803   taking the log of small inputs.
   1804 
   1805   For example:
   1806 
   1807   ```python
   1808   x = tf.constant([[0., 0., 0.], [0., 0., 0.]])
   1809   tf.reduce_logsumexp(x)  # log(6)
   1810   tf.reduce_logsumexp(x, 0)  # [log(2), log(2), log(2)]
   1811   tf.reduce_logsumexp(x, 1)  # [log(3), log(3)]
   1812   tf.reduce_logsumexp(x, 1, keepdims=True)  # [[log(3)], [log(3)]]
   1813   tf.reduce_logsumexp(x, [0, 1])  # log(6)
   1814   ```
   1815 
   1816   Args:
   1817     input_tensor: The tensor to reduce. Should have numeric type.
   1818     axis: The dimensions to reduce. If `None` (the default),
   1819       reduces all dimensions. Must be in the range
   1820       `[-rank(input_tensor), rank(input_tensor))`.
   1821     keepdims: If true, retains reduced dimensions with length 1.
   1822     name: A name for the operation (optional).
   1823     reduction_indices: The old (deprecated) name for axis.
   1824     keep_dims: Deprecated alias for `keepdims`.
   1825 
   1826   Returns:
   1827     The reduced tensor.
   1828   """
   1829   keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
   1830                                                     "keep_dims", keep_dims)
   1831   if keepdims is None:
   1832     keepdims = False
   1833   with ops.name_scope(name, "ReduceLogSumExp", [input_tensor]) as name:
   1834     raw_max = reduce_max(
   1835         input_tensor,
   1836         axis=axis,
   1837         reduction_indices=reduction_indices,
   1838         keepdims=True)
   1839     my_max = array_ops.stop_gradient(
   1840         array_ops.where(
   1841             gen_math_ops.is_finite(raw_max), raw_max,
   1842             array_ops.zeros_like(raw_max)))
   1843     result = gen_math_ops.log(
   1844         reduce_sum(
   1845             gen_math_ops.exp(input_tensor - my_max),
   1846             axis,
   1847             keepdims=keepdims,
   1848             reduction_indices=reduction_indices))
   1849     if not keepdims:
   1850       my_max = array_ops.reshape(my_max, array_ops.shape(result))
   1851     result += my_max
   1852     return _may_reduce_to_scalar(keepdims, axis, reduction_indices, result)
   1853 
   1854 
   1855 @tf_export("trace", "linalg.trace")
   1856 def trace(x, name=None):
   1857   """Compute the trace of a tensor `x`.
   1858 
   1859   `trace(x)` returns the sum along the main diagonal of each inner-most matrix
   1860   in x. If x is of rank `k` with shape `[I, J, K, ..., L, M, N]`, then output
   1861   is a tensor of rank `k-2` with dimensions `[I, J, K, ..., L]` where
   1862 
   1863   `output[i, j, k, ..., l] = trace(x[i, j, i, ..., l, :, :])`
   1864 
   1865   For example:
   1866 
   1867   ```python
   1868   x = tf.constant([[1, 2], [3, 4]])
   1869   tf.trace(x)  # 5
   1870 
   1871   x = tf.constant([[1, 2, 3],
   1872                    [4, 5, 6],
   1873                    [7, 8, 9]])
   1874   tf.trace(x)  # 15
   1875 
   1876   x = tf.constant([[[1, 2, 3],
   1877                     [4, 5, 6],
   1878                     [7, 8, 9]],
   1879                    [[-1, -2, -3],
   1880                     [-4, -5, -6],
   1881                     [-7, -8, -9]]])
   1882   tf.trace(x)  # [15, -15]
   1883   ```
   1884 
   1885   Args:
   1886     x: tensor.
   1887     name: A name for the operation (optional).
   1888 
   1889   Returns:
   1890     The trace of input tensor.
   1891   """
   1892   with ops.name_scope(name, "Trace", [x]) as name:
   1893     x = ops.convert_to_tensor(x, name="x")
   1894     return reduce_sum(array_ops.matrix_diag_part(x), [-1], name=name)
   1895 
   1896 
   1897 @tf_export("matmul")
   1898 def matmul(a,
   1899            b,
   1900            transpose_a=False,
   1901            transpose_b=False,
   1902            adjoint_a=False,
   1903            adjoint_b=False,
   1904            a_is_sparse=False,
   1905            b_is_sparse=False,
   1906            name=None):
   1907   """Multiplies matrix `a` by matrix `b`, producing `a` * `b`.
   1908 
   1909   The inputs must, following any transpositions, be tensors of rank >= 2
   1910   where the inner 2 dimensions specify valid matrix multiplication arguments,
   1911   and any further outer dimensions match.
   1912 
   1913   Both matrices must be of the same type. The supported types are:
   1914   `float16`, `float32`, `float64`, `int32`, `complex64`, `complex128`.
   1915 
   1916   Either matrix can be transposed or adjointed (conjugated and transposed) on
   1917   the fly by setting one of the corresponding flag to `True`. These are `False`
   1918   by default.
   1919 
   1920   If one or both of the matrices contain a lot of zeros, a more efficient
   1921   multiplication algorithm can be used by setting the corresponding
   1922   `a_is_sparse` or `b_is_sparse` flag to `True`. These are `False` by default.
   1923   This optimization is only available for plain matrices (rank-2 tensors) with
   1924   datatypes `bfloat16` or `float32`.
   1925 
   1926   For example:
   1927 
   1928   ```python
   1929   # 2-D tensor `a`
   1930   # [[1, 2, 3],
   1931   #  [4, 5, 6]]
   1932   a = tf.constant([1, 2, 3, 4, 5, 6], shape=[2, 3])
   1933 
   1934   # 2-D tensor `b`
   1935   # [[ 7,  8],
   1936   #  [ 9, 10],
   1937   #  [11, 12]]
   1938   b = tf.constant([7, 8, 9, 10, 11, 12], shape=[3, 2])
   1939 
   1940   # `a` * `b`
   1941   # [[ 58,  64],
   1942   #  [139, 154]]
   1943   c = tf.matmul(a, b)
   1944 
   1945 
   1946   # 3-D tensor `a`
   1947   # [[[ 1,  2,  3],
   1948   #   [ 4,  5,  6]],
   1949   #  [[ 7,  8,  9],
   1950   #   [10, 11, 12]]]
   1951   a = tf.constant(np.arange(1, 13, dtype=np.int32),
   1952                   shape=[2, 2, 3])
   1953 
   1954   # 3-D tensor `b`
   1955   # [[[13, 14],
   1956   #   [15, 16],
   1957   #   [17, 18]],
   1958   #  [[19, 20],
   1959   #   [21, 22],
   1960   #   [23, 24]]]
   1961   b = tf.constant(np.arange(13, 25, dtype=np.int32),
   1962                   shape=[2, 3, 2])
   1963 
   1964   # `a` * `b`
   1965   # [[[ 94, 100],
   1966   #   [229, 244]],
   1967   #  [[508, 532],
   1968   #   [697, 730]]]
   1969   c = tf.matmul(a, b)
   1970 
   1971   # Since python >= 3.5 the @ operator is supported (see PEP 465).
   1972   # In TensorFlow, it simply calls the `tf.matmul()` function, so the
   1973   # following lines are equivalent:
   1974   d = a @ b @ [[10.], [11.]]
   1975   d = tf.matmul(tf.matmul(a, b), [[10.], [11.]])
   1976   ```
   1977 
   1978   Args:
   1979     a: `Tensor` of type `float16`, `float32`, `float64`, `int32`, `complex64`,
   1980       `complex128` and rank > 1.
   1981     b: `Tensor` with same type and rank as `a`.
   1982     transpose_a: If `True`, `a` is transposed before multiplication.
   1983     transpose_b: If `True`, `b` is transposed before multiplication.
   1984     adjoint_a: If `True`, `a` is conjugated and transposed before
   1985       multiplication.
   1986     adjoint_b: If `True`, `b` is conjugated and transposed before
   1987       multiplication.
   1988     a_is_sparse: If `True`, `a` is treated as a sparse matrix.
   1989     b_is_sparse: If `True`, `b` is treated as a sparse matrix.
   1990     name: Name for the operation (optional).
   1991 
   1992   Returns:
   1993     A `Tensor` of the same type as `a` and `b` where each inner-most matrix is
   1994     the product of the corresponding matrices in `a` and `b`, e.g. if all
   1995     transpose or adjoint attributes are `False`:
   1996 
   1997     `output`[..., i, j] = sum_k (`a`[..., i, k] * `b`[..., k, j]),
   1998     for all indices i, j.
   1999 
   2000     Note: This is matrix product, not element-wise product.
   2001 
   2002 
   2003   Raises:
   2004     ValueError: If transpose_a and adjoint_a, or transpose_b and adjoint_b
   2005       are both set to True.
   2006   """
   2007   with ops.name_scope(name, "MatMul", [a, b]) as name:
   2008     if transpose_a and adjoint_a:
   2009       raise ValueError("Only one of transpose_a and adjoint_a can be True.")
   2010     if transpose_b and adjoint_b:
   2011       raise ValueError("Only one of transpose_b and adjoint_b can be True.")
   2012 
   2013     a = ops.convert_to_tensor(a, name="a")
   2014     b = ops.convert_to_tensor(b, name="b")
   2015     # TODO(apassos) remove _shape_tuple here when it is not needed.
   2016     a_shape = a._shape_tuple()  # pylint: disable=protected-access
   2017     b_shape = b._shape_tuple()  # pylint: disable=protected-access
   2018     if (not a_is_sparse and
   2019         not b_is_sparse) and ((a_shape is None or len(a_shape) > 2) and
   2020                               (b_shape is None or len(b_shape) > 2)):
   2021       # BatchMatmul does not support transpose, so we conjugate the matrix and
   2022       # use adjoint instead. Conj() is a noop for real matrices.
   2023       if transpose_a:
   2024         a = conj(a)
   2025         adjoint_a = True
   2026       if transpose_b:
   2027         b = conj(b)
   2028         adjoint_b = True
   2029       return gen_math_ops._batch_mat_mul(
   2030           a, b, adj_x=adjoint_a, adj_y=adjoint_b, name=name)
   2031 
   2032     # Neither matmul nor sparse_matmul support adjoint, so we conjugate
   2033     # the matrix and use transpose instead. Conj() is a noop for real
   2034     # matrices.
   2035     if adjoint_a:
   2036       a = conj(a)
   2037       transpose_a = True
   2038     if adjoint_b:
   2039       b = conj(b)
   2040       transpose_b = True
   2041 
   2042     use_sparse_matmul = False
   2043     if a_is_sparse or b_is_sparse:
   2044       sparse_matmul_types = [dtypes.bfloat16, dtypes.float32]
   2045       use_sparse_matmul = (
   2046           a.dtype in sparse_matmul_types and b.dtype in sparse_matmul_types)
   2047     if a.dtype == dtypes.bfloat16 or b.dtype == dtypes.bfloat16:
   2048       # matmul currently doesn't handle bfloat16 inputs.
   2049       use_sparse_matmul = True
   2050     if use_sparse_matmul:
   2051       ret = sparse_matmul(
   2052           a,
   2053           b,
   2054           transpose_a=transpose_a,
   2055           transpose_b=transpose_b,
   2056           a_is_sparse=a_is_sparse,
   2057           b_is_sparse=b_is_sparse,
   2058           name=name)
   2059       # sparse_matmul always returns float32, even with
   2060       # bfloat16 inputs. This prevents us from configuring bfloat16 training.
   2061       # casting to bfloat16 also matches non-sparse matmul behavior better.
   2062       if a.dtype == dtypes.bfloat16 and b.dtype == dtypes.bfloat16:
   2063         ret = cast(ret, dtypes.bfloat16)
   2064       return ret
   2065     else:
   2066       return gen_math_ops._mat_mul(
   2067           a, b, transpose_a=transpose_a, transpose_b=transpose_b, name=name)
   2068 
   2069 
# Hook `matmul` up as the implementation of the `@` operator (Python >= 3.5,
# PEP 465), as described in the `matmul` docstring.
# NOTE(review): `_OverrideBinaryOperatorHelper` is defined elsewhere in this
# file; presumably it installs the forward and reflected operator methods on
# `Tensor` — confirm there.
_OverrideBinaryOperatorHelper(matmul, "matmul")

# Alias for the generated SparseMatMul op wrapper. `matmul` looks this global
# up at call time, so defining it after `matmul` is fine.
sparse_matmul = gen_math_ops._sparse_mat_mul
   2073 
   2074 
   2075 @ops.RegisterStatistics("MatMul", "flops")
   2076 def _calc_mat_mul_flops(graph, node):
   2077   """Calculates the compute resources needed for MatMul."""
   2078   transpose_a = node.attr["transpose_a"].b
   2079   a_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
   2080   a_shape.assert_is_fully_defined()
   2081   if transpose_a:
   2082     k = int(a_shape[0])
   2083   else:
   2084     k = int(a_shape[1])
   2085   output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
   2086   output_shape.assert_is_fully_defined()
   2087   output_count = np.prod(output_shape.as_list())
   2088   return ops.OpStats("flops", (k * output_count * 2))
   2089 
   2090 
   2091 def _as_indexed_slices(x, optimize=True):
   2092   """Convert 'x' to IndexedSlices.
   2093 
   2094   Convert a dense Tensor to a block-sparse IndexedSlices.
   2095 
   2096   Args:
   2097     x: Either a Tensor object, or an IndexedSlices object.
   2098     optimize: if true, attempt to optimize the conversion of 'x'.
   2099 
   2100   Returns:
   2101     An IndexedSlices object.
   2102 
   2103   Raises:
   2104     TypeError: If 'x' is not a Tensor or an IndexedSlices object.
   2105   """
   2106   # TODO(touts): op_scope
   2107   if not isinstance(x, (ops.Tensor, ops.IndexedSlices)):
   2108     raise TypeError("Not a Tensor or IndexedSlices: %s" % type(x))
   2109   if isinstance(x, ops.IndexedSlices):
   2110     return x
   2111   x_shape = array_ops.shape_internal(x, optimize=optimize)
   2112   return ops.IndexedSlices(x, range(0, x_shape[0]), x_shape)
   2113 
   2114 
   2115 def _as_indexed_slices_list(inputs, optimize=True):
   2116   """Convert all elements of 'inputs' to IndexedSlices.
   2117 
   2118   Additionally, homogenize the types of all the indices to
   2119   either int32 or int64.
   2120 
   2121   Args:
   2122     inputs: List containing either Tensor or IndexedSlices objects.
   2123     optimize: if true, attempt to optimize the conversion of each input.
   2124 
   2125   Returns:
   2126     A list of IndexedSlices objects.
   2127 
   2128   Raises:
   2129     TypeError: If 'inputs' is not a list or a tuple.
   2130   """
   2131   if not isinstance(inputs, (list, tuple)):
   2132     raise TypeError("Expected a list or tuple, not a %s" % type(inputs))
   2133   outputs = [_as_indexed_slices(i, optimize=optimize) for i in inputs]
   2134   with_int32_index = [
   2135       o.indices for o in outputs if o.indices.dtype == dtypes.int32
   2136   ]
   2137   if not with_int32_index or len(with_int32_index) == len(outputs):
   2138     return outputs
   2139   casted_outputs = []
   2140   for o in outputs:
   2141     if o.indices.dtype == dtypes.int32:
   2142       casted_outputs.append(
   2143           ops.IndexedSlices(o.values, cast(o.indices, dtypes.int64),
   2144                             o.dense_shape))
   2145     else:
   2146       casted_outputs.append(o)
   2147   return casted_outputs
   2148 
   2149 
   2150 @tf_export("add_n")
   2151 def add_n(inputs, name=None):
   2152   """Adds all input tensors element-wise.
   2153 
   2154   Args:
   2155     inputs: A list of `Tensor` objects, each with same shape and type.
   2156     name: A name for the operation (optional).
   2157 
   2158   Returns:
   2159     A `Tensor` of same shape and type as the elements of `inputs`.
   2160 
   2161   Raises:
   2162     ValueError: If `inputs` don't all have same shape and dtype or the shape
   2163     cannot be inferred.
   2164   """
   2165   if not inputs or not isinstance(inputs, (list, tuple)):
   2166     raise ValueError("inputs must be a list of at least one Tensor with the "
   2167                      "same dtype and shape")
   2168   inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs)
   2169   if not all(isinstance(x, ops.Tensor) for x in inputs):
   2170     raise ValueError("inputs must be a list of at least one Tensor with the "
   2171                      "same dtype and shape")
   2172 
   2173   if len(inputs) == 1:
   2174     if name:
   2175       return array_ops.identity(inputs[0], name=name)
   2176     return inputs[0]
   2177   return gen_math_ops._add_n(inputs, name=name)
   2178 
   2179 
   2180 @tf_export("accumulate_n")
   2181 def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
   2182   """Returns the element-wise sum of a list of tensors.
   2183 
   2184   Optionally, pass `shape` and `tensor_dtype` for shape and type checking,
   2185   otherwise, these are inferred.
   2186 
   2187   NOTE: This operation is not differentiable and cannot be used if inputs depend
   2188   on trainable variables. Please use `tf.add_n` for such cases.
   2189 
   2190   Aside from differentiability, `tf.accumulate_n` performs the same operation as
   2191   `tf.add_n`, but does not wait for all of its inputs to be ready before
   2192   beginning to sum. This can save memory if inputs are ready at different times,
   2193   since minimum temporary storage is proportional to the output size rather than
   2194   the inputs size.
   2195 
   2196   For example:
   2197 
   2198   ```python
   2199   a = tf.constant([[1, 2], [3, 4]])
   2200   b = tf.constant([[5, 0], [0, 6]])
   2201   tf.accumulate_n([a, b, a])  # [[7, 4], [6, 14]]
   2202 
   2203   # Explicitly pass shape and type
   2204   tf.accumulate_n([a, b, a], shape=[2, 2], tensor_dtype=tf.int32)  # [[7,  4],
   2205                                                                    #  [6, 14]]
   2206   ```
   2207 
   2208   Args:
   2209     inputs: A list of `Tensor` objects, each with same shape and type.
   2210     shape: Shape of elements of `inputs`.
   2211     tensor_dtype: The type of `inputs`.
   2212     name: A name for the operation (optional).
   2213 
   2214   Returns:
   2215     A `Tensor` of same shape and type as the elements of `inputs`.
   2216 
   2217   Raises:
   2218     ValueError: If `inputs` don't all have same shape and dtype or the shape
   2219     cannot be inferred.
   2220   """
   2221   if context.in_eager_mode():
   2222     # TODO(apassos) remove this once the lifetime of eager variables gets
   2223     # addressed.
   2224     raise ValueError("accumulate_n not supported in eager mode")
   2225   if not inputs or not isinstance(inputs, (list, tuple)):
   2226     raise ValueError("inputs must be a list of at least one Tensor with the "
   2227                      "same dtype and shape")
   2228   inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs)
   2229   if not all(isinstance(x, ops.Tensor) for x in inputs):
   2230     raise ValueError("inputs must be a list of at least one Tensor with the "
   2231                      "same dtype and shape")
   2232   if not all(x.dtype == inputs[0].dtype for x in inputs):
   2233     raise ValueError("inputs must be a list of at least one Tensor with the "
   2234                      "same dtype and shape")
   2235   if shape is not None:
   2236     shape = tensor_shape.as_shape(shape)
   2237   else:
   2238     shape = tensor_shape.unknown_shape()
   2239   for input_tensor in inputs:
   2240     if isinstance(input_tensor, ops.Tensor):
   2241       shape = shape.merge_with(input_tensor.get_shape())
   2242   if tensor_dtype is None:
   2243     tensor_dtype = inputs[0].dtype
   2244   if tensor_dtype != inputs[0].dtype:
   2245     raise TypeError("tensor_dtype is {}, but input is of type {}".format(
   2246         tensor_dtype, inputs[0].dtype))
   2247   if len(inputs) == 1:
   2248     return inputs[0]
   2249   with ops.name_scope(name, "AccumulateN", inputs) as name:
   2250     var = gen_state_ops._temporary_variable(
   2251         shape=tensor_shape.vector(0), dtype=tensor_dtype)
   2252     with ops.colocate_with(var):
   2253       zeros = array_ops.zeros_like(gen_control_flow_ops._merge(inputs)[0])
   2254       zeros.set_shape(shape)
   2255       ref = state_ops.assign(var, zeros, validate_shape=False)
   2256       update_ops = [
   2257           state_ops.assign_add(ref, input_tensor, use_locking=True)
   2258           for input_tensor in inputs
   2259       ]
   2260       with ops.control_dependencies(update_ops):
   2261         return gen_state_ops._destroy_temporary_variable(
   2262             ref, var_name=var.op.name, name=name)
   2263 
   2264 
   2265 @tf_export("nn.sigmoid", "sigmoid")
   2266 def sigmoid(x, name=None):
   2267   """Computes sigmoid of `x` element-wise.
   2268 
   2269   Specifically, `y = 1 / (1 + exp(-x))`.
   2270 
   2271   Args:
   2272     x: A Tensor with type `float16`, `float32`, `float64`, `complex64`,
   2273       or `complex128`.
   2274     name: A name for the operation (optional).
   2275 
   2276   Returns:
   2277     A Tensor with the same type as `x`.
   2278 
   2279   @compatibility(numpy)
   2280   Equivalent to np.scipy.special.expit
   2281   @end_compatibility
   2282   """
   2283   with ops.name_scope(name, "Sigmoid", [x]) as name:
   2284     x = ops.convert_to_tensor(x, name="x")
   2285     return gen_math_ops._sigmoid(x, name=name)
   2286 
   2287 
   2288 @tf_export("log_sigmoid")
   2289 def log_sigmoid(x, name=None):
   2290   """Computes log sigmoid of `x` element-wise.
   2291 
   2292   Specifically, `y = log(1 / (1 + exp(-x)))`.  For numerical stability,
   2293   we use `y = -tf.nn.softplus(-x)`.
   2294 
   2295   Args:
   2296     x: A Tensor with type `float32` or `float64`.
   2297     name: A name for the operation (optional).
   2298 
   2299   Returns:
   2300     A Tensor with the same type as `x`.
   2301   """
   2302   with ops.name_scope(name, "LogSigmoid", [x]) as name:
   2303     x = ops.convert_to_tensor(x, name="x")
   2304     return gen_math_ops._neg(gen_nn_ops.softplus(-x), name=name)
   2305 
   2306 
   2307 @tf_export("nn.tanh", "tanh")
   2308 def tanh(x, name=None):
   2309   """Computes hyperbolic tangent of `x` element-wise.
   2310 
   2311   Args:
   2312     x: A Tensor or SparseTensor with type `float16`, `float32`, `double`,
   2313       `complex64`, or `complex128`.
   2314     name: A name for the operation (optional).
   2315 
   2316   Returns:
   2317     A Tensor or SparseTensor respectively with the same type as `x`.
   2318   """
   2319   with ops.name_scope(name, "Tanh", [x]) as name:
   2320     if isinstance(x, sparse_tensor.SparseTensor):
   2321       x_tanh = gen_math_ops._tanh(x.values, name=name)
   2322       return sparse_tensor.SparseTensor(
   2323           indices=x.indices, values=x_tanh, dense_shape=x.dense_shape)
   2324     else:
   2325       return gen_math_ops._tanh(x, name=name)
   2326 
   2327 
   2328 @tf_export("bincount")
   2329 def bincount(arr,
   2330              weights=None,
   2331              minlength=None,
   2332              maxlength=None,
   2333              dtype=dtypes.int32):
   2334   """Counts the number of occurrences of each value in an integer array.
   2335 
   2336   If `minlength` and `maxlength` are not given, returns a vector with length
   2337   `tf.reduce_max(arr) + 1` if `arr` is non-empty, and length 0 otherwise.
   2338   If `weights` are non-None, then index `i` of the output stores the sum of the
   2339   value in `weights` at each index where the corresponding value in `arr` is
   2340   `i`.
   2341 
   2342   Args:
   2343     arr: An int32 tensor of non-negative values.
   2344     weights: If non-None, must be the same shape as arr. For each value in
   2345         `arr`, the bin will be incremented by the corresponding weight instead
   2346         of 1.
   2347     minlength: If given, ensures the output has length at least `minlength`,
   2348         padding with zeros at the end if necessary.
   2349     maxlength: If given, skips values in `arr` that are equal or greater than
   2350         `maxlength`, ensuring that the output has length at most `maxlength`.
   2351     dtype: If `weights` is None, determines the type of the output bins.
   2352 
   2353   Returns:
   2354     A vector with the same dtype as `weights` or the given `dtype`. The bin
   2355     values.
   2356   """
   2357   arr = ops.convert_to_tensor(arr, name="arr", dtype=dtypes.int32)
   2358   array_is_nonempty = reduce_prod(array_ops.shape(arr)) > 0
   2359   output_size = cast(array_is_nonempty, dtypes.int32) * (reduce_max(arr) + 1)
   2360   if minlength is not None:
   2361     minlength = ops.convert_to_tensor(
   2362         minlength, name="minlength", dtype=dtypes.int32)
   2363     output_size = gen_math_ops.maximum(minlength, output_size)
   2364   if maxlength is not None:
   2365     maxlength = ops.convert_to_tensor(
   2366         maxlength, name="maxlength", dtype=dtypes.int32)
   2367     output_size = gen_math_ops.minimum(maxlength, output_size)
   2368   if weights is not None:
   2369     weights = ops.convert_to_tensor(weights, name="weights")
   2370     return gen_math_ops.unsorted_segment_sum(weights, arr, output_size)
   2371   weights = constant_op.constant([], dtype)
   2372   return gen_math_ops.bincount(arr, output_size, weights)
   2373 
   2374 
   2375 @tf_export("cumsum")
   2376 def cumsum(x, axis=0, exclusive=False, reverse=False, name=None):
   2377   """Compute the cumulative sum of the tensor `x` along `axis`.
   2378 
   2379   By default, this op performs an inclusive cumsum, which means that the first
   2380   element of the input is identical to the first element of the output:
   2381 
   2382   ```python
   2383   tf.cumsum([a, b, c])  # [a, a + b, a + b + c]
   2384   ```
   2385 
   2386   By setting the `exclusive` kwarg to `True`, an exclusive cumsum is performed
   2387   instead:
   2388 
   2389   ```python
   2390   tf.cumsum([a, b, c], exclusive=True)  # [0, a, a + b]
   2391   ```
   2392 
   2393   By setting the `reverse` kwarg to `True`, the cumsum is performed in the
   2394   opposite direction:
   2395 
   2396   ```python
   2397   tf.cumsum([a, b, c], reverse=True)  # [a + b + c, b + c, c]
   2398   ```
   2399 
   2400   This is more efficient than using separate `tf.reverse` ops.
   2401 
   2402   The `reverse` and `exclusive` kwargs can also be combined:
   2403 
   2404   ```python
   2405   tf.cumsum([a, b, c], exclusive=True, reverse=True)  # [b + c, c, 0]
   2406   ```
   2407 
   2408   Args:
   2409     x: A `Tensor`. Must be one of the following types: `float32`, `float64`,
   2410        `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`,
   2411        `complex128`, `qint8`, `quint8`, `qint32`, `half`.
   2412     axis: A `Tensor` of type `int32` (default: 0). Must be in the range
   2413       `[-rank(x), rank(x))`.
   2414     exclusive: If `True`, perform exclusive cumsum.
   2415     reverse: A `bool` (default: False).
   2416     name: A name for the operation (optional).
   2417 
   2418   Returns:
   2419     A `Tensor`. Has the same type as `x`.
   2420   """
   2421   with ops.name_scope(name, "Cumsum", [x]) as name:
   2422     x = ops.convert_to_tensor(x, name="x")
   2423     return gen_math_ops.cumsum(
   2424         x, axis, exclusive=exclusive, reverse=reverse, name=name)
   2425 
   2426 
   2427 @tf_export("cumprod")
   2428 def cumprod(x, axis=0, exclusive=False, reverse=False, name=None):
   2429   """Compute the cumulative product of the tensor `x` along `axis`.
   2430 
   2431   By default, this op performs an inclusive cumprod, which means that the
   2432   first element of the input is identical to the first element of the output:
   2433 
   2434   ```python
   2435   tf.cumprod([a, b, c])  # [a, a * b, a * b * c]
   2436   ```
   2437 
   2438   By setting the `exclusive` kwarg to `True`, an exclusive cumprod is
   2439   performed
   2440   instead:
   2441 
   2442   ```python
   2443   tf.cumprod([a, b, c], exclusive=True)  # [1, a, a * b]
   2444   ```
   2445 
   2446   By setting the `reverse` kwarg to `True`, the cumprod is performed in the
   2447   opposite direction:
   2448 
   2449   ```python
   2450   tf.cumprod([a, b, c], reverse=True)  # [a * b * c, b * c, c]
   2451   ```
   2452 
   2453   This is more efficient than using separate `tf.reverse` ops.
   2454   The `reverse` and `exclusive` kwargs can also be combined:
   2455 
   2456   ```python
   2457   tf.cumprod([a, b, c], exclusive=True, reverse=True)  # [b * c, c, 1]
   2458   ```
   2459 
   2460   Args:
   2461     x: A `Tensor`. Must be one of the following types: `float32`, `float64`,
   2462        `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`,
   2463        `complex128`, `qint8`, `quint8`, `qint32`, `half`.
   2464     axis: A `Tensor` of type `int32` (default: 0). Must be in the range
   2465       `[-rank(x), rank(x))`.
   2466     exclusive: If `True`, perform exclusive cumprod.
   2467     reverse: A `bool` (default: False).
   2468     name: A name for the operation (optional).
   2469 
   2470   Returns:
   2471     A `Tensor`. Has the same type as `x`.
   2472   """
   2473   with ops.name_scope(name, "Cumprod", [x]) as name:
   2474     x = ops.convert_to_tensor(x, name="x")
   2475     return gen_math_ops.cumprod(
   2476         x, axis, exclusive=exclusive, reverse=reverse, name=name)
   2477 
   2478 
   2479 @tf_export("conj")
   2480 def conj(x, name=None):
   2481   r"""Returns the complex conjugate of a complex number.
   2482 
   2483   Given a tensor `input` of complex numbers, this operation returns a tensor of
   2484   complex numbers that are the complex conjugate of each element in `input`. The
   2485   complex numbers in `input` must be of the form \\(a + bj\\), where *a* is the
   2486   real part and *b* is the imaginary part.
   2487 
   2488   The complex conjugate returned by this operation is of the form \\(a - bj\\).
   2489 
   2490   For example:
   2491 
   2492       # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
   2493       tf.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j]
   2494 
   2495   If `x` is real, it is returned unchanged.
   2496 
   2497   Args:
   2498     x: `Tensor` to conjugate.  Must have numeric or variant type.
   2499     name: A name for the operation (optional).
   2500 
   2501   Returns:
   2502     A `Tensor` that is the conjugate of `x` (with the same type).
   2503 
   2504   Raises:
   2505     TypeError: If `x` is not a numeric tensor.
   2506   """
   2507   if isinstance(x, ops.Tensor):
   2508     dt = x.dtype
   2509     if dt.is_floating or dt.is_integer:
   2510       return x
   2511   with ops.name_scope(name, "Conj", [x]) as name:
   2512     x = ops.convert_to_tensor(x, name="x")
   2513     if x.dtype.is_complex or x.dtype == dtypes.variant:
   2514       return gen_math_ops._conj(x, name=name)
   2515     elif x.dtype.is_floating or x.dtype.is_integer:
   2516       return x
   2517     else:
   2518       raise TypeError(
   2519           "Expected numeric or variant tensor, got dtype %r" % x.dtype)
   2520 
   2521 
   2522 def _BroadcastShape(op):
   2523   """Common shape function for binary operators that broadcast their inputs."""
   2524   return [
   2525       common_shapes.broadcast_shape(op.inputs[0].get_shape(),
   2526                                     op.inputs[1].get_shape())
   2527   ]
   2528 
   2529 
   2530 def reduced_shape(input_shape, axes):
   2531   """Helper function for reduction ops.
   2532 
   2533   Args:
   2534     input_shape: 1-D Tensor, the shape of the Tensor being reduced.
   2535     axes: 1-D Tensor, the reduction axes.
   2536   Returns:
   2537     A 1-D Tensor, the output shape as if keepdims were set to True.
   2538   """
   2539   # Example:
   2540   # cast needed for SparseTensor reductions
   2541   input_shape = to_int32(input_shape)  # [2, 3, 5, 7]
   2542   axes = to_int32(axes)  # [1, 2]
   2543 
   2544   input_rank = array_ops.size(input_shape)  # 4
   2545   axes = (axes + input_rank) % input_rank
   2546   axes_shape = array_ops.shape(axes)  # [2]
   2547   return gen_data_flow_ops.dynamic_stitch(  # [2, 1, 1, 7]
   2548       [
   2549           range(input_rank),  # [0, 1, 2, 3]
   2550           axes
   2551       ],  # [1, 2]
   2552       [
   2553           input_shape,  # [2, 3, 5, 7]
   2554           array_ops.fill(axes_shape, 1)
   2555       ])  # [1, 1]
   2556 
   2557 
   2558 def _unsorted_segment_N(data, segment_ids, num_segments):
   2559   """ Helper function for unsorted_segment_mean/_sqrtN. Computes the number
   2560       of segment entries with 0-entries set to 1 to allow division by N.
   2561   """
   2562   # bincount doesn't support negative indices so we use unsorted_segment_sum
   2563   ones_tensor = array_ops.ones(segment_ids.shape, dtype=data.dtype)
   2564   N = gen_math_ops.unsorted_segment_sum(ones_tensor, segment_ids, num_segments)
   2565   # add dimensions for all non-reduced axes
   2566   ndims_output = data.shape.ndims - segment_ids.shape.ndims
   2567   broadcast_shape = [num_segments] + [1] * ndims_output
   2568   N = array_ops.reshape(N, broadcast_shape)
   2569   return gen_math_ops.maximum(N, 1)
   2570 
   2571 
   2572 @tf_export("unsorted_segment_mean")
   2573 def unsorted_segment_mean(data, segment_ids, num_segments, name=None):
   2574   r""" Computes the mean along segments of a tensor.
   2575 
   2576   Read @{$math_ops#segmentation$the section on segmentation} for an explanation
   2577   of segments.
   2578 
   2579   This operator is similar to the unsorted segment sum operator found
   2580   [here](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
   2581   Instead of computing the sum over segments, it computes the mean of all
   2582   entries belonging to a segment such that:
   2583 
   2584   \\(output_i = 1/N_i \sum data_j\\) where the sum is over `j` such
   2585   that `segment_ids[j] == i` with \\N_i\\ being the number of occurrences
   2586   of id \\i\\.
   2587 
   2588   If there is no entry for a given segment ID `i`, it outputs 0.
   2589 
   2590   segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
   2591   first dimension.
   2592 
   2593   output: Has same shape as data, except for dimension 0 which
   2594   has size `num_segments`.
   2595   """
   2596   with ops.name_scope(name, "UnsortedSegmentMean"):
   2597     data = ops.convert_to_tensor(data)
   2598     segment_ids = ops.convert_to_tensor(segment_ids)
   2599     N = _unsorted_segment_N(data, segment_ids, num_segments)
   2600     summed = gen_math_ops.unsorted_segment_sum(data, segment_ids, num_segments)
   2601     return summed / N
   2602 
   2603 
   2604 @tf_export("unsorted_segment_sqrt_n")
   2605 def unsorted_segment_sqrt_n(data, segment_ids, num_segments, name=None):
   2606   r"""Computes the sum along segments of a tensor divided by the sqrt(N).
   2607 
   2608   Read @{$math_ops#segmentation$the section on segmentation} for an explanation
   2609   of segments.
   2610 
   2611   This operator is similar to the unsorted segment sum operator found
   2612   [here](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
   2613   Additionally to computing the sum over segments, it divides the results by
   2614   sqrt(N).
   2615 
   2616   \\(output_i = 1/sqrt(N_i) \sum data_j\\) where the sum is over `j` such
   2617   that `segment_ids[j] == i` with \\N_i\\ being the number of occurrences
   2618   of id \\i\\.
   2619 
   2620   If there is no entry for a given segment ID `i`, it outputs 0.
   2621 
   2622   Note that this op only supports floating point and complex dtypes,
   2623   due to tf.sqrt only supporting these types.
   2624 
   2625   segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
   2626   first dimension.
   2627 
   2628   output: Has same shape as data, except for dimension 0 which
   2629   has size `num_segments`.
   2630   """
   2631   with ops.name_scope(name, "UnsortedSegmentSqrtN"):
   2632     data = ops.convert_to_tensor(data)
   2633     segment_ids = ops.convert_to_tensor(segment_ids)
   2634     N = _unsorted_segment_N(data, segment_ids, num_segments)
   2635     summed = gen_math_ops.unsorted_segment_sum(data, segment_ids, num_segments)
   2636     return summed / gen_math_ops.sqrt(N)
   2637 
   2638 
   2639 @tf_export("sparse_segment_sum")
   2640 def sparse_segment_sum(data, indices, segment_ids, name=None,
   2641                        num_segments=None):
   2642   r"""Computes the sum along sparse segments of a tensor.
   2643 
   2644   Read @{$math_ops#Segmentation$the section on segmentation} for an explanation
   2645   of segments.
   2646 
   2647   Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first
   2648   dimension, selecting a subset of dimension 0, specified by `indices`.
   2649   `segment_ids` is allowed to have missing ids, in which case the output will
   2650   be zeros at those indices. In those cases `num_segments` is used to determine
   2651   the size of the output.
   2652 
   2653   For example:
   2654 
   2655   ```python
   2656   c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
   2657 
   2658   # Select two rows, one segment.
   2659   tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0]))
   2660   # => [[0 0 0 0]]
   2661 
   2662   # Select two rows, two segment.
   2663   tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1]))
   2664   # => [[ 1  2  3  4]
   2665   #     [-1 -2 -3 -4]]
   2666 
   2667   # With missing segment ids.
   2668   tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 2]),
   2669                         num_segments=4)
   2670   # => [[ 1  2  3  4]
   2671   #     [ 0  0  0  0]
   2672   #     [-1 -2 -3 -4]
   2673   #     [ 0  0  0  0]]
   2674 
   2675   # Select all rows, two segments.
   2676   tf.sparse_segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1]))
   2677   # => [[0 0 0 0]
   2678   #     [5 6 7 8]]
   2679 
   2680   # Which is equivalent to:
   2681   tf.segment_sum(c, tf.constant([0, 0, 1]))
   2682   ```
   2683 
   2684   Args:
   2685     data: A `Tensor` with data that will be assembled in the output.
   2686     indices: A 1-D `Tensor` with indices into `data`. Has same rank as
   2687       `segment_ids`.
   2688     segment_ids: A 1-D `Tensor` with indices into the output `Tensor`.
   2689       Values should be sorted and can be repeated.
   2690     name: A name for the operation (optional).
   2691     num_segments: An optional int32 scalar. Indicates the size of the output
   2692       `Tensor`.
   2693 
   2694   Returns:
   2695     A `tensor` of the shape as data, except for dimension 0 which
   2696     has size `k`, the number of segments specified via `num_segments` or
   2697     inferred for the last element in `segments_ids`.
   2698   """
   2699   if num_segments is not None:
   2700     return gen_math_ops.sparse_segment_sum_with_num_segments(
   2701         data=data,
   2702         indices=indices,
   2703         segment_ids=segment_ids,
   2704         num_segments=num_segments,
   2705         name=name)
   2706   else:
   2707     return gen_math_ops.sparse_segment_sum(
   2708         data=data,
   2709         indices=indices,
   2710         segment_ids=segment_ids,
   2711         name=name)
   2712 
   2713 
   2714 @tf_export("sparse_segment_mean")
   2715 def sparse_segment_mean(data, indices, segment_ids, name=None,
   2716                         num_segments=None):
   2717   r"""Computes the mean along sparse segments of a tensor.
   2718 
   2719   Read @{$math_ops#Segmentation$the section on segmentation} for an explanation
   2720   of segments.
   2721 
   2722   Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first
   2723   dimension, selecting a subset of dimension 0, specified by `indices`.
   2724   `segment_ids` is allowed to have missing ids, in which case the output will
   2725   be zeros at those indices. In those cases `num_segments` is used to determine
   2726   the size of the output.
   2727 
   2728   Args:
   2729     data: A `Tensor` with data that will be assembled in the output.
   2730     indices: A 1-D `Tensor` with indices into `data`. Has same rank as
   2731       `segment_ids`.
   2732     segment_ids: A 1-D `Tensor` with indices into the output `Tensor`.
   2733       Values should be sorted and can be repeated.
   2734     name: A name for the operation (optional).
   2735     num_segments: An optional int32 scalar. Indicates the size of the output
   2736       `Tensor`.
   2737 
   2738   Returns:
   2739     A `tensor` of the shape as data, except for dimension 0 which
   2740     has size `k`, the number of segments specified via `num_segments` or
   2741     inferred for the last element in `segments_ids`.
   2742   """
   2743   if num_segments is not None:
   2744     return gen_math_ops.sparse_segment_mean_with_num_segments(
   2745         data=data,
   2746         indices=indices,
   2747         segment_ids=segment_ids,
   2748         num_segments=num_segments,
   2749         name=name)
   2750   else:
   2751     return gen_math_ops.sparse_segment_mean(
   2752         data=data,
   2753         indices=indices,
   2754         segment_ids=segment_ids,
   2755         name=name)
   2756 
   2757 
   2758 @tf_export("sparse_segment_sqrt_n")
   2759 def sparse_segment_sqrt_n(data, indices, segment_ids, name=None,
   2760                           num_segments=None):
   2761   r"""Computes the sum along sparse segments of a tensor divided by the sqrt(N).
   2762 
   2763   `N` is the size of the segment being reduced.
   2764 
   2765   Args:
   2766     data: A `Tensor` with data that will be assembled in the output.
   2767     indices: A 1-D `Tensor` with indices into `data`. Has same rank as
   2768       `segment_ids`.
   2769     segment_ids: A 1-D `Tensor` with indices into the output `Tensor`.
   2770       Values should be sorted and can be repeated.
   2771     name: A name for the operation (optional).
   2772     num_segments: An optional int32 scalar. Indicates the size of the output
   2773       `Tensor`.
   2774 
   2775   Returns:
   2776     A `tensor` of the shape as data, except for dimension 0 which
   2777     has size `k`, the number of segments specified via `num_segments` or
   2778     inferred for the last element in `segments_ids`.
   2779   """
   2780   if num_segments is not None:
   2781     return gen_math_ops.sparse_segment_sqrt_n_with_num_segments(
   2782         data=data,
   2783         indices=indices,
   2784         segment_ids=segment_ids,
   2785         num_segments=num_segments,
   2786         name=name)
   2787   else:
   2788     return gen_math_ops.sparse_segment_sqrt_n(
   2789         data=data,
   2790         indices=indices,
   2791         segment_ids=segment_ids,
   2792         name=name)
   2793 
   2794 
   2795 @tf_export("tensordot", "linalg.tensordot")
   2796 def tensordot(a, b, axes, name=None):
   2797   r"""Tensor contraction of a and b along specified axes.
   2798 
   2799   Tensordot (also known as tensor contraction) sums the product of elements
   2800   from `a` and `b` over the indices specified by `a_axes` and `b_axes`.
   2801   The lists `a_axes` and `b_axes` specify those pairs of axes along which to
   2802   contract the tensors. The axis `a_axes[i]` of `a` must have the same dimension
   2803   as axis `b_axes[i]` of `b` for all `i` in `range(0, len(a_axes))`. The lists
   2804   `a_axes` and `b_axes` must have identical length and consist of unique
   2805   integers that specify valid axes for each of the tensors.
   2806 
   2807   This operation corresponds to `numpy.tensordot(a, b, axes)`.
   2808 
   2809   Example 1: When `a` and `b` are matrices (order 2), the case `axes = 1`
   2810   is equivalent to matrix multiplication.
   2811 
   2812   Example 2: When `a` and `b` are matrices (order 2), the case
   2813   `axes = [[1], [0]]` is equivalent to matrix multiplication.
   2814 
   2815   Example 3: Suppose that \\(a_{ijk}\\) and \\(b_{lmn}\\) represent two
   2816   tensors of order 3. Then, `contract(a, b, [[0], [2]])` is the order 4 tensor
   2817   \\(c_{jklm}\\) whose entry
   2818   corresponding to the indices \\((j,k,l,m)\\) is given by:
   2819 
   2820   \\( c_{jklm} = \sum_i a_{ijk} b_{lmi} \\).
   2821 
   2822   In general, `order(c) = order(a) + order(b) - 2*len(axes[0])`.
   2823 
   2824   Args:
   2825     a: `Tensor` of type `float32` or `float64`.
   2826     b: `Tensor` with the same type as `a`.
   2827     axes: Either a scalar `N`, or a list or an `int32` `Tensor` of shape [2, k].
   2828      If axes is a scalar, sum over the last N axes of a and the first N axes
   2829      of b in order.
   2830      If axes is a list or `Tensor` the first and second row contain the set of
   2831      unique integers specifying axes along which the contraction is computed,
   2832      for `a` and `b`, respectively. The number of axes for `a` and `b` must
   2833      be equal.
   2834     name: A name for the operation (optional).
   2835 
   2836   Returns:
   2837     A `Tensor` with the same type as `a`.
   2838 
   2839   Raises:
   2840     ValueError: If the shapes of `a`, `b`, and `axes` are incompatible.
   2841     IndexError: If the values in axes exceed the rank of the corresponding
   2842       tensor.
   2843   """
   2844 
   2845   def _tensordot_reshape(a, axes, flipped=False):
   2846     """Helper method to perform transpose and reshape for contraction op.
   2847 
   2848     This method is helpful in reducing `math_ops.tensordot` to `math_ops.matmul`
   2849     using `array_ops.transpose` and `array_ops.reshape`. The method takes a
   2850     tensor and performs the correct transpose and reshape operation for a given
   2851     set of indices. It returns the reshaped tensor as well as a list of indices
   2852     necessary to reshape the tensor again after matrix multiplication.
   2853 
   2854     Args:
   2855       a: `Tensor`.
   2856       axes: List or `int32` `Tensor` of unique indices specifying valid axes of
   2857        `a`.
   2858       flipped: An optional `bool`. Defaults to `False`. If `True`, the method
   2859         assumes that `a` is the second argument in the contraction operation.
   2860 
   2861     Returns:
   2862       A tuple `(reshaped_a, free_dims, free_dims_static)` where `reshaped_a` is
   2863       the tensor `a` reshaped to allow contraction via `matmul`, `free_dims` is
   2864       either a list of integers or an `int32` `Tensor`, depending on whether
   2865       the shape of a is fully specified, and free_dims_static is either a list
   2866       of integers and None values, or None, representing the inferred
   2867       static shape of the free dimensions
   2868     """
   2869     if a.get_shape().is_fully_defined() and isinstance(axes, (list, tuple)):
   2870       shape_a = a.get_shape().as_list()
   2871       axes = [i if i >= 0 else i + len(shape_a) for i in axes]
   2872       free = [i for i in xrange(len(shape_a)) if i not in axes]
   2873       free_dims = [shape_a[i] for i in free]
   2874       prod_free = int(np.prod([shape_a[i] for i in free]))
   2875       prod_axes = int(np.prod([shape_a[i] for i in axes]))
   2876       perm = list(axes) + free if flipped else free + list(axes)
   2877       new_shape = [prod_axes, prod_free] if flipped else [prod_free, prod_axes]
   2878       reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape)
   2879       return reshaped_a, free_dims, free_dims
   2880     else:
   2881       if a.get_shape().ndims is not None and isinstance(axes, (list, tuple)):
   2882         shape_a = a.get_shape().as_list()
   2883         axes = [i if i >= 0 else i + len(shape_a) for i in axes]
   2884         free = [i for i in xrange(len(shape_a)) if i not in axes]
   2885         free_dims_static = [shape_a[i] for i in free]
   2886       else:
   2887         free_dims_static = None
   2888       shape_a = array_ops.shape(a)
   2889       rank_a = array_ops.rank(a)
   2890       axes = ops.convert_to_tensor(axes, dtype=dtypes.int32, name="axes")
   2891       axes = cast(axes >= 0, dtypes.int32) * axes + cast(
   2892           axes < 0, dtypes.int32) * (
   2893               axes + rank_a)
   2894       free, _ = array_ops.setdiff1d(range(rank_a), axes)
   2895       free_dims = array_ops.gather(shape_a, free)
   2896       axes_dims = array_ops.gather(shape_a, axes)
   2897       prod_free_dims = reduce_prod(free_dims)
   2898       prod_axes_dims = reduce_prod(axes_dims)
   2899       perm = array_ops.concat([axes_dims, free_dims], 0)
   2900       if flipped:
   2901         perm = array_ops.concat([axes, free], 0)
   2902         new_shape = array_ops.stack([prod_axes_dims, prod_free_dims])
   2903       else:
   2904         perm = array_ops.concat([free, axes], 0)
   2905         new_shape = array_ops.stack([prod_free_dims, prod_axes_dims])
   2906       reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape)
   2907       return reshaped_a, free_dims, free_dims_static
   2908 
   2909   def _tensordot_axes(a, axes):
   2910     """Generates two sets of contraction axes for the two tensor arguments."""
   2911     a_shape = a.get_shape()
   2912     if isinstance(axes, compat.integral_types):
   2913       if axes < 0:
   2914         raise ValueError("'axes' must be at least 0.")
   2915       if a_shape.ndims is not None:
   2916         if axes > a_shape.ndims:
   2917           raise ValueError("'axes' must not be larger than the number of "
   2918                            "dimensions of tensor %s." % a)
   2919         return (list(xrange(a_shape.ndims - axes, a_shape.ndims)),
   2920                 list(xrange(axes)))
   2921       else:
   2922         rank = array_ops.rank(a)
   2923         return (range(rank - axes, rank, dtype=dtypes.int32),
   2924                 range(axes, dtype=dtypes.int32))
   2925     elif isinstance(axes, (list, tuple)):
   2926       if len(axes) != 2:
   2927         raise ValueError("'axes' must be an integer or have length 2.")
   2928       a_axes = axes[0]
   2929       b_axes = axes[1]
   2930       if isinstance(a_axes, compat.integral_types) and \
   2931           isinstance(b_axes, compat.integral_types):
   2932         a_axes = [a_axes]
   2933         b_axes = [b_axes]
   2934       if len(a_axes) != len(b_axes):
   2935         raise ValueError(
   2936             "Different number of contraction axes 'a' and 'b', %s != %s." %
   2937             (len(a_axes), len(b_axes)))
   2938       return a_axes, b_axes
   2939     else:
   2940       axes = ops.convert_to_tensor(axes, name="axes", dtype=dtypes.int32)
   2941       return axes[0], axes[1]
   2942 
   2943   with ops.name_scope(name, "Tensordot", [a, b, axes]) as name:
   2944     a = ops.convert_to_tensor(a, name="a")
   2945     b = ops.convert_to_tensor(b, name="b")
   2946     a_axes, b_axes = _tensordot_axes(a, axes)
   2947     a_reshape, a_free_dims, a_free_dims_static = _tensordot_reshape(a, a_axes)
   2948     b_reshape, b_free_dims, b_free_dims_static = _tensordot_reshape(
   2949         b, b_axes, True)
   2950     ab_matmul = matmul(a_reshape, b_reshape)
   2951     if isinstance(a_free_dims, list) and isinstance(b_free_dims, list):
   2952       return array_ops.reshape(ab_matmul, a_free_dims + b_free_dims, name=name)
   2953     else:
   2954       a_free_dims = ops.convert_to_tensor(a_free_dims, dtype=dtypes.int32)
   2955       b_free_dims = ops.convert_to_tensor(b_free_dims, dtype=dtypes.int32)
   2956       product = array_ops.reshape(
   2957           ab_matmul, array_ops.concat([a_free_dims, b_free_dims], 0), name=name)
   2958       if a_free_dims_static is not None and b_free_dims_static is not None:
   2959         product.set_shape(a_free_dims_static + b_free_dims_static)
   2960       return product
   2961 
   2962 
   2963 # FFT ops were moved to tf.spectral. tf.fft symbols were part of the TensorFlow
   2964 # 1.0 API so we leave these here for backwards compatibility.
   2965 fft = gen_spectral_ops.fft
   2966 ifft = gen_spectral_ops.ifft
   2967 fft2d = gen_spectral_ops.fft2d
   2968 ifft2d = gen_spectral_ops.ifft2d
   2969 fft3d = gen_spectral_ops.fft3d
   2970 ifft3d = gen_spectral_ops.ifft3d
   2971