1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 16 """Tensor utility functions.""" 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.python.framework import dtypes 24 from tensorflow.python.framework import ops 25 from tensorflow.python.framework import sparse_tensor 26 from tensorflow.python.framework import tensor_util 27 from tensorflow.python.ops import array_ops 28 from tensorflow.python.ops import check_ops 29 from tensorflow.python.ops import control_flow_ops 30 from tensorflow.python.ops import math_ops 31 from tensorflow.python.util.deprecation import deprecated 32 33 34 __all__ = [ 35 'assert_same_float_dtype', 36 'assert_scalar', 37 'assert_scalar_int', 38 'convert_to_tensor_or_sparse_tensor', 39 'is_tensor', 40 'reduce_sum_n', 41 'remove_squeezable_dimensions', 42 'with_shape', 43 'with_same_shape'] 44 45 46 # Temporary for backwards compatibility 47 is_tensor = tensor_util.is_tensor 48 assert_same_float_dtype = check_ops.assert_same_float_dtype 49 assert_scalar = check_ops.assert_scalar 50 51 convert_to_tensor_or_sparse_tensor = ( 52 sparse_tensor.convert_to_tensor_or_sparse_tensor) 53 54 55 def reduce_sum_n(tensors, name=None): 56 """Reduce tensors to a scalar sum. 57 58 This reduces each tensor in `tensors` to a scalar via `tf.reduce_sum`, then 59 adds them via `tf.add_n`. 60 61 Args: 62 tensors: List of tensors, all of the same numeric type. 63 name: Tensor name, and scope for all other ops. 64 65 Returns: 66 Total loss tensor, or None if no losses have been configured. 67 68 Raises: 69 ValueError: if `losses` is missing or empty. 70 """ 71 if not tensors: 72 raise ValueError('No tensors provided.') 73 with ops.name_scope(name, 'reduce_sum_n', tensors) as name_scope: 74 tensors = [ 75 math_ops.reduce_sum(t, name='%s/sum' % t.op.name) for t in tensors] 76 if len(tensors) == 1: 77 return tensors[0] 78 return math_ops.add_n(tensors, name=name_scope) 79 80 @deprecated( 81 None, "Please switch to remove_squeezable_dimensions from " 82 "tf.confusion_matrix. Note that the order of the inputs and outputs of " 83 "labels and predictions have also been switched.") 84 def remove_squeezable_dimensions(predictions, labels, name=None): 85 """Squeeze last dim if ranks of `predictions` and `labels` differ by 1. 86 87 This will use static shape if available. Otherwise, it will add graph 88 operations, which could result in a performance hit. 89 90 Args: 91 predictions: Predicted values, a `Tensor` of arbitrary dimensions. 92 labels: Label values, a `Tensor` whose dimensions match `predictions`. 93 name: Name of the op. 94 95 Returns: 96 Tuple of `predictions` and `labels`, possibly with last dim squeezed. 97 """ 98 with ops.name_scope(name, 'remove_squeezable_dimensions', 99 [predictions, labels]): 100 predictions = ops.convert_to_tensor(predictions) 101 labels = ops.convert_to_tensor(labels) 102 predictions_shape = predictions.get_shape() 103 predictions_rank = predictions_shape.ndims 104 labels_shape = labels.get_shape() 105 labels_rank = labels_shape.ndims 106 if (labels_rank is not None) and (predictions_rank is not None): 107 # Use static rank. 108 rank_diff = predictions_rank - labels_rank 109 if rank_diff == -1: 110 labels = array_ops.squeeze(labels, [-1]) 111 elif rank_diff == 1: 112 predictions = array_ops.squeeze(predictions, [-1]) 113 return predictions, labels 114 115 # Use dynamic rank. 116 rank_diff = array_ops.rank(predictions) - array_ops.rank(labels) 117 if (predictions_rank is None) or ( 118 predictions_shape.dims[-1].is_compatible_with(1)): 119 predictions = control_flow_ops.cond( 120 math_ops.equal(1, rank_diff), 121 lambda: array_ops.squeeze(predictions, [-1]), 122 lambda: predictions) 123 if (labels_rank is None) or ( 124 labels_shape.dims[-1].is_compatible_with(1)): 125 labels = control_flow_ops.cond( 126 math_ops.equal(-1, rank_diff), 127 lambda: array_ops.squeeze(labels, [-1]), 128 lambda: labels) 129 return predictions, labels 130 131 132 def _all_equal(tensor0, tensor1): 133 with ops.name_scope('all_equal', values=[tensor0, tensor1]) as scope: 134 return math_ops.reduce_all( 135 math_ops.equal(tensor0, tensor1, name='equal'), name=scope) 136 137 138 def _is_rank(expected_rank, actual_tensor): 139 """Returns whether actual_tensor's rank is expected_rank. 140 141 Args: 142 expected_rank: Integer defining the expected rank, or tensor of same. 143 actual_tensor: Tensor to test. 144 Returns: 145 New tensor. 146 """ 147 with ops.name_scope('is_rank', values=[actual_tensor]) as scope: 148 expected = ops.convert_to_tensor(expected_rank, name='expected') 149 actual = array_ops.rank(actual_tensor, name='actual') 150 return math_ops.equal(expected, actual, name=scope) 151 152 153 def _is_shape(expected_shape, actual_tensor, actual_shape=None): 154 """Returns whether actual_tensor's shape is expected_shape. 155 156 Args: 157 expected_shape: Integer list defining the expected shape, or tensor of same. 158 actual_tensor: Tensor to test. 159 actual_shape: Shape of actual_tensor, if we already have it. 160 Returns: 161 New tensor. 162 """ 163 with ops.name_scope('is_shape', values=[actual_tensor]) as scope: 164 is_rank = _is_rank(array_ops.size(expected_shape), actual_tensor) 165 if actual_shape is None: 166 actual_shape = array_ops.shape(actual_tensor, name='actual') 167 shape_equal = _all_equal( 168 ops.convert_to_tensor(expected_shape, name='expected'), 169 actual_shape) 170 return math_ops.logical_and(is_rank, shape_equal, name=scope) 171 172 173 def _assert_shape_op(expected_shape, actual_tensor): 174 """Asserts actual_tensor's shape is expected_shape. 175 176 Args: 177 expected_shape: List of integers defining the expected shape, or tensor of 178 same. 179 actual_tensor: Tensor to test. 180 Returns: 181 New assert tensor. 182 """ 183 with ops.name_scope('assert_shape', values=[actual_tensor]) as scope: 184 actual_shape = array_ops.shape(actual_tensor, name='actual') 185 is_shape = _is_shape(expected_shape, actual_tensor, actual_shape) 186 return control_flow_ops.Assert( 187 is_shape, [ 188 'Wrong shape for %s [expected] [actual].' % actual_tensor.name, 189 expected_shape, 190 actual_shape 191 ], name=scope) 192 193 194 def with_same_shape(expected_tensor, tensor): 195 """Assert tensors are the same shape, from the same graph. 196 197 Args: 198 expected_tensor: Tensor with expected shape. 199 tensor: Tensor of actual values. 200 Returns: 201 The original tensor argument, possibly with assert ops added. 202 """ 203 with ops.name_scope('%s/' % tensor.op.name, values=[expected_tensor, tensor]): 204 tensor_shape = expected_tensor.get_shape() 205 expected_shape = ( 206 tensor_shape.as_list() if tensor_shape.is_fully_defined() 207 else array_ops.shape(expected_tensor, name='expected_shape')) 208 return with_shape(expected_shape, tensor) 209 210 211 def with_shape(expected_shape, tensor): 212 """Asserts tensor has expected shape. 213 214 If tensor shape and expected_shape, are fully defined, assert they match. 215 Otherwise, add assert op that will validate the shape when tensor is 216 evaluated, and set shape on tensor. 217 218 Args: 219 expected_shape: Expected shape to assert, as a 1D array of ints, or tensor 220 of same. 221 tensor: Tensor whose shape we're validating. 222 Returns: 223 tensor, perhaps with a dependent assert operation. 224 Raises: 225 ValueError: if tensor has an invalid shape. 226 """ 227 if isinstance(tensor, sparse_tensor.SparseTensor): 228 raise ValueError('SparseTensor not supported.') 229 230 # Shape type must be 1D int32. 231 if tensor_util.is_tensor(expected_shape): 232 if expected_shape.dtype.base_dtype != dtypes.int32: 233 raise ValueError( 234 'Invalid dtype %s for shape %s expected of tensor %s.' % ( 235 expected_shape.dtype, expected_shape, tensor.name)) 236 if isinstance(expected_shape, (list, tuple)): 237 if not expected_shape: 238 expected_shape = np.asarray([], dtype=np.int32) 239 else: 240 np_expected_shape = np.asarray(expected_shape) 241 expected_shape = ( 242 np.asarray(expected_shape, dtype=np.int32) 243 if np_expected_shape.dtype == np.int64 else np_expected_shape) 244 if isinstance(expected_shape, np.ndarray): 245 if expected_shape.ndim > 1: 246 raise ValueError( 247 'Invalid rank %s for shape %s expected of tensor %s.' % ( 248 expected_shape.ndim, expected_shape, tensor.name)) 249 if expected_shape.dtype != np.int32: 250 raise ValueError( 251 'Invalid dtype %s for shape %s expected of tensor %s.' % ( 252 expected_shape.dtype, expected_shape, tensor.name)) 253 254 actual_shape = tensor.get_shape() 255 256 if (not actual_shape.is_fully_defined() 257 or tensor_util.is_tensor(expected_shape)): 258 with ops.name_scope('%s/' % tensor.op.name, values=[tensor]): 259 if (not tensor_util.is_tensor(expected_shape) 260 and (len(expected_shape) < 1)): 261 # TODO(irving): Remove scalar special case 262 return array_ops.reshape(tensor, []) 263 with ops.control_dependencies([_assert_shape_op(expected_shape, tensor)]): 264 result = array_ops.identity(tensor) 265 if not tensor_util.is_tensor(expected_shape): 266 result.set_shape(expected_shape) 267 return result 268 269 if (not tensor_util.is_tensor(expected_shape) and 270 not actual_shape.is_compatible_with(expected_shape)): 271 if (len(expected_shape) < 1) and actual_shape.is_compatible_with([1]): 272 # TODO(irving): Remove scalar special case. 273 with ops.name_scope('%s/' % tensor.op.name, values=[tensor]): 274 return array_ops.reshape(tensor, []) 275 raise ValueError('Invalid shape for tensor %s, expected %s, got %s.' % ( 276 tensor.name, expected_shape, actual_shape)) 277 278 return tensor 279 280 281 def assert_scalar_int(tensor, name=None): 282 """Assert `tensor` is 0-D, of type `tf.int32` or `tf.int64`. 283 284 Args: 285 tensor: `Tensor` to test. 286 name: Name of the op and of the new `Tensor` if one is created. 287 Returns: 288 `tensor`, for chaining. 289 Raises: 290 ValueError: if `tensor` is not 0-D, of integer type. 291 """ 292 with ops.name_scope(name, 'assert_scalar_int', [tensor]) as name_scope: 293 tensor = ops.convert_to_tensor(tensor) 294 data_type = tensor.dtype 295 if not data_type.base_dtype.is_integer: 296 raise ValueError('Expected integer type for %s, received type: %s.' 297 % (tensor.name, data_type)) 298 return check_ops.assert_scalar(tensor, name=name_scope) 299