Home | History | Annotate | Download | only in framework
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 """Tensor utility functions."""
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.python.framework import dtypes
     24 from tensorflow.python.framework import ops
     25 from tensorflow.python.framework import sparse_tensor
     26 from tensorflow.python.framework import tensor_util
     27 from tensorflow.python.ops import array_ops
     28 from tensorflow.python.ops import check_ops
     29 from tensorflow.python.ops import control_flow_ops
     30 from tensorflow.python.ops import math_ops
     31 from tensorflow.python.util.deprecation import deprecated
     32 
     33 
# Names exported by `from <this module> import *`; this list defines the
# module's public API.
__all__ = [
    'assert_same_float_dtype',
    'assert_scalar',
    'assert_scalar_int',
    'convert_to_tensor_or_sparse_tensor',
    'is_tensor',
    'reduce_sum_n',
    'remove_squeezable_dimensions',
    'with_shape',
    'with_same_shape']


# Temporary for backwards compatibility: these names are plain aliases of the
# implementations in core TensorFlow modules (`tensor_util`, `check_ops`,
# `sparse_tensor`), re-exported here so existing callers of this module keep
# working.
is_tensor = tensor_util.is_tensor
assert_same_float_dtype = check_ops.assert_same_float_dtype
assert_scalar = check_ops.assert_scalar

convert_to_tensor_or_sparse_tensor = (
    sparse_tensor.convert_to_tensor_or_sparse_tensor)
     53 
     54 
     55 def reduce_sum_n(tensors, name=None):
     56   """Reduce tensors to a scalar sum.
     57 
     58   This reduces each tensor in `tensors` to a scalar via `tf.reduce_sum`, then
     59   adds them via `tf.add_n`.
     60 
     61   Args:
     62     tensors: List of tensors, all of the same numeric type.
     63     name: Tensor name, and scope for all other ops.
     64 
     65   Returns:
     66     Total loss tensor, or None if no losses have been configured.
     67 
     68   Raises:
     69     ValueError: if `losses` is missing or empty.
     70   """
     71   if not tensors:
     72     raise ValueError('No tensors provided.')
     73   with ops.name_scope(name, 'reduce_sum_n', tensors) as name_scope:
     74     tensors = [
     75         math_ops.reduce_sum(t, name='%s/sum' % t.op.name) for t in tensors]
     76     if len(tensors) == 1:
     77       return tensors[0]
     78     return math_ops.add_n(tensors, name=name_scope)
     79 
     80 @deprecated(
     81     None, "Please switch to remove_squeezable_dimensions from "
     82     "tf.confusion_matrix. Note that the order of the inputs and outputs of "
     83     "labels and predictions have also been switched.")
     84 def remove_squeezable_dimensions(predictions, labels, name=None):
     85   """Squeeze last dim if ranks of `predictions` and `labels` differ by 1.
     86 
     87   This will use static shape if available. Otherwise, it will add graph
     88   operations, which could result in a performance hit.
     89 
     90   Args:
     91     predictions: Predicted values, a `Tensor` of arbitrary dimensions.
     92     labels: Label values, a `Tensor` whose dimensions match `predictions`.
     93     name: Name of the op.
     94 
     95   Returns:
     96     Tuple of `predictions` and `labels`, possibly with last dim squeezed.
     97   """
     98   with ops.name_scope(name, 'remove_squeezable_dimensions',
     99                       [predictions, labels]):
    100     predictions = ops.convert_to_tensor(predictions)
    101     labels = ops.convert_to_tensor(labels)
    102     predictions_shape = predictions.get_shape()
    103     predictions_rank = predictions_shape.ndims
    104     labels_shape = labels.get_shape()
    105     labels_rank = labels_shape.ndims
    106     if (labels_rank is not None) and (predictions_rank is not None):
    107       # Use static rank.
    108       rank_diff = predictions_rank - labels_rank
    109       if rank_diff == -1:
    110         labels = array_ops.squeeze(labels, [-1])
    111       elif rank_diff == 1:
    112         predictions = array_ops.squeeze(predictions, [-1])
    113       return predictions, labels
    114 
    115     # Use dynamic rank.
    116     rank_diff = array_ops.rank(predictions) - array_ops.rank(labels)
    117     if (predictions_rank is None) or (
    118         predictions_shape.dims[-1].is_compatible_with(1)):
    119       predictions = control_flow_ops.cond(
    120           math_ops.equal(1, rank_diff),
    121           lambda: array_ops.squeeze(predictions, [-1]),
    122           lambda: predictions)
    123     if (labels_rank is None) or (
    124         labels_shape.dims[-1].is_compatible_with(1)):
    125       labels = control_flow_ops.cond(
    126           math_ops.equal(-1, rank_diff),
    127           lambda: array_ops.squeeze(labels, [-1]),
    128           lambda: labels)
    129     return predictions, labels
    130 
    131 
    132 def _all_equal(tensor0, tensor1):
    133   with ops.name_scope('all_equal', values=[tensor0, tensor1]) as scope:
    134     return math_ops.reduce_all(
    135         math_ops.equal(tensor0, tensor1, name='equal'), name=scope)
    136 
    137 
    138 def _is_rank(expected_rank, actual_tensor):
    139   """Returns whether actual_tensor's rank is expected_rank.
    140 
    141   Args:
    142     expected_rank: Integer defining the expected rank, or tensor of same.
    143     actual_tensor: Tensor to test.
    144   Returns:
    145     New tensor.
    146   """
    147   with ops.name_scope('is_rank', values=[actual_tensor]) as scope:
    148     expected = ops.convert_to_tensor(expected_rank, name='expected')
    149     actual = array_ops.rank(actual_tensor, name='actual')
    150     return math_ops.equal(expected, actual, name=scope)
    151 
    152 
    153 def _is_shape(expected_shape, actual_tensor, actual_shape=None):
    154   """Returns whether actual_tensor's shape is expected_shape.
    155 
    156   Args:
    157     expected_shape: Integer list defining the expected shape, or tensor of same.
    158     actual_tensor: Tensor to test.
    159     actual_shape: Shape of actual_tensor, if we already have it.
    160   Returns:
    161     New tensor.
    162   """
    163   with ops.name_scope('is_shape', values=[actual_tensor]) as scope:
    164     is_rank = _is_rank(array_ops.size(expected_shape), actual_tensor)
    165     if actual_shape is None:
    166       actual_shape = array_ops.shape(actual_tensor, name='actual')
    167     shape_equal = _all_equal(
    168         ops.convert_to_tensor(expected_shape, name='expected'),
    169         actual_shape)
    170     return math_ops.logical_and(is_rank, shape_equal, name=scope)
    171 
    172 
    173 def _assert_shape_op(expected_shape, actual_tensor):
    174   """Asserts actual_tensor's shape is expected_shape.
    175 
    176   Args:
    177     expected_shape: List of integers defining the expected shape, or tensor of
    178         same.
    179     actual_tensor: Tensor to test.
    180   Returns:
    181     New assert tensor.
    182   """
    183   with ops.name_scope('assert_shape', values=[actual_tensor]) as scope:
    184     actual_shape = array_ops.shape(actual_tensor, name='actual')
    185     is_shape = _is_shape(expected_shape, actual_tensor, actual_shape)
    186     return control_flow_ops.Assert(
    187         is_shape, [
    188             'Wrong shape for %s [expected] [actual].' % actual_tensor.name,
    189             expected_shape,
    190             actual_shape
    191         ], name=scope)
    192 
    193 
    194 def with_same_shape(expected_tensor, tensor):
    195   """Assert tensors are the same shape, from the same graph.
    196 
    197   Args:
    198     expected_tensor: Tensor with expected shape.
    199     tensor: Tensor of actual values.
    200   Returns:
    201     The original tensor argument, possibly with assert ops added.
    202   """
    203   with ops.name_scope('%s/' % tensor.op.name, values=[expected_tensor, tensor]):
    204     tensor_shape = expected_tensor.get_shape()
    205     expected_shape = (
    206         tensor_shape.as_list() if tensor_shape.is_fully_defined()
    207         else array_ops.shape(expected_tensor, name='expected_shape'))
    208     return with_shape(expected_shape, tensor)
    209 
    210 
    211 def with_shape(expected_shape, tensor):
    212   """Asserts tensor has expected shape.
    213 
    214   If tensor shape and expected_shape, are fully defined, assert they match.
    215   Otherwise, add assert op that will validate the shape when tensor is
    216   evaluated, and set shape on tensor.
    217 
    218   Args:
    219     expected_shape: Expected shape to assert, as a 1D array of ints, or tensor
    220         of same.
    221     tensor: Tensor whose shape we're validating.
    222   Returns:
    223     tensor, perhaps with a dependent assert operation.
    224   Raises:
    225     ValueError: if tensor has an invalid shape.
    226   """
    227   if isinstance(tensor, sparse_tensor.SparseTensor):
    228     raise ValueError('SparseTensor not supported.')
    229 
    230   # Shape type must be 1D int32.
    231   if tensor_util.is_tensor(expected_shape):
    232     if expected_shape.dtype.base_dtype != dtypes.int32:
    233       raise ValueError(
    234           'Invalid dtype %s for shape %s expected of tensor %s.' % (
    235               expected_shape.dtype, expected_shape, tensor.name))
    236   if isinstance(expected_shape, (list, tuple)):
    237     if not expected_shape:
    238       expected_shape = np.asarray([], dtype=np.int32)
    239     else:
    240       np_expected_shape = np.asarray(expected_shape)
    241       expected_shape = (
    242           np.asarray(expected_shape, dtype=np.int32)
    243           if np_expected_shape.dtype == np.int64 else np_expected_shape)
    244   if isinstance(expected_shape, np.ndarray):
    245     if expected_shape.ndim > 1:
    246       raise ValueError(
    247           'Invalid rank %s for shape %s expected of tensor %s.' % (
    248               expected_shape.ndim, expected_shape, tensor.name))
    249     if expected_shape.dtype != np.int32:
    250       raise ValueError(
    251           'Invalid dtype %s for shape %s expected of tensor %s.' % (
    252               expected_shape.dtype, expected_shape, tensor.name))
    253 
    254   actual_shape = tensor.get_shape()
    255 
    256   if (not actual_shape.is_fully_defined()
    257       or tensor_util.is_tensor(expected_shape)):
    258     with ops.name_scope('%s/' % tensor.op.name, values=[tensor]):
    259       if (not tensor_util.is_tensor(expected_shape)
    260           and (len(expected_shape) < 1)):
    261         # TODO(irving): Remove scalar special case
    262         return array_ops.reshape(tensor, [])
    263       with ops.control_dependencies([_assert_shape_op(expected_shape, tensor)]):
    264         result = array_ops.identity(tensor)
    265       if not tensor_util.is_tensor(expected_shape):
    266         result.set_shape(expected_shape)
    267       return result
    268 
    269   if (not tensor_util.is_tensor(expected_shape) and
    270       not actual_shape.is_compatible_with(expected_shape)):
    271     if (len(expected_shape) < 1) and actual_shape.is_compatible_with([1]):
    272       # TODO(irving): Remove scalar special case.
    273       with ops.name_scope('%s/' % tensor.op.name, values=[tensor]):
    274         return array_ops.reshape(tensor, [])
    275     raise ValueError('Invalid shape for tensor %s, expected %s, got %s.' % (
    276         tensor.name, expected_shape, actual_shape))
    277 
    278   return tensor
    279 
    280 
    281 def assert_scalar_int(tensor, name=None):
    282   """Assert `tensor` is 0-D, of type `tf.int32` or `tf.int64`.
    283 
    284   Args:
    285     tensor: `Tensor` to test.
    286     name: Name of the op and of the new `Tensor` if one is created.
    287   Returns:
    288     `tensor`, for chaining.
    289   Raises:
    290     ValueError: if `tensor` is not 0-D, of integer type.
    291   """
    292   with ops.name_scope(name, 'assert_scalar_int', [tensor]) as name_scope:
    293     tensor = ops.convert_to_tensor(tensor)
    294     data_type = tensor.dtype
    295     if not data_type.base_dtype.is_integer:
    296       raise ValueError('Expected integer type for %s, received type: %s.'
    297                        % (tensor.name, data_type))
    298     return check_ops.assert_scalar(tensor, name=name_scope)
    299