1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Utilities to create TensorProtos.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import numpy as np 21 import six 22 23 from tensorflow.core.framework import tensor_pb2 24 from tensorflow.core.framework import tensor_shape_pb2 25 from tensorflow.python.eager import context 26 from tensorflow.python.framework import ops 27 from tensorflow.python.framework import tensor_shape 28 from tensorflow.python.util import compat 29 30 # Fallback in case fast_tensor_util is not properly compiled. 
# pylint: disable=g-import-not-at-top
try:
  from tensorflow.python.framework import fast_tensor_util
  _FAST_TENSOR_UTIL_AVAILABLE = True
except ImportError:
  _FAST_TENSOR_UTIL_AVAILABLE = False

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.util.tf_export import tf_export

# pylint: enable=g-import-not-at-top


def ExtractBitsFromFloat16(x):
  """Returns the raw 16-bit pattern of `x` converted to float16, as an int."""
  # `ndarray.item()` replaces `np.asscalar`, which newer numpy removed.
  return np.asarray(x, dtype=np.float16).view(np.uint16).item()


def SlowAppendFloat16ArrayToTensorProto(tensor_proto, proto_values):
  """Appends float16 values to `tensor_proto.half_val` as raw bit patterns."""
  tensor_proto.half_val.extend(
      [ExtractBitsFromFloat16(x) for x in proto_values])


def ExtractBitsFromBFloat16(x):
  """Returns the raw 16-bit pattern of `x` converted to bfloat16, as an int."""
  return np.asarray(
      x, dtype=dtypes.bfloat16.as_numpy_dtype).view(np.uint16).item()


def SlowAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values):
  """Appends bfloat16 values to `tensor_proto.half_val` as raw bit patterns."""
  tensor_proto.half_val.extend(
      [ExtractBitsFromBFloat16(x) for x in proto_values])


# _NP_TO_APPEND_FN maps numpy scalar types / dtypes to a function
# `fn(tensor_proto, proto_values)` that appends the flat numpy array
# `proto_values` to the matching repeated field of `tensor_proto`.
# Look entries up with GetFromNumpyDTypeDict, not with `dict.get`.
# `object` and `bool` are the builtins that the removed `np.object` and
# `np.bool` aliases referred to, so the keys are unchanged.
if _FAST_TENSOR_UTIL_AVAILABLE:
  _NP_TO_APPEND_FN = {
      dtypes.bfloat16.as_numpy_dtype:
          SlowAppendBFloat16ArrayToTensorProto,
      # TODO(sesse): We should have a
      # fast_tensor_util.AppendFloat16ArrayToTensorProto,
      # but it seems np.float16_t doesn't exist?
      np.float16:
          SlowAppendFloat16ArrayToTensorProto,
      np.float32:
          fast_tensor_util.AppendFloat32ArrayToTensorProto,
      np.float64:
          fast_tensor_util.AppendFloat64ArrayToTensorProto,
      np.int32:
          fast_tensor_util.AppendInt32ArrayToTensorProto,
      np.int64:
          fast_tensor_util.AppendInt64ArrayToTensorProto,
      np.uint8:
          fast_tensor_util.AppendUInt8ArrayToTensorProto,
      np.uint16:
          fast_tensor_util.AppendUInt16ArrayToTensorProto,
      np.uint32:
          fast_tensor_util.AppendUInt32ArrayToTensorProto,
      np.uint64:
          fast_tensor_util.AppendUInt64ArrayToTensorProto,
      np.int8:
          fast_tensor_util.AppendInt8ArrayToTensorProto,
      np.int16:
          fast_tensor_util.AppendInt16ArrayToTensorProto,
      np.complex64:
          fast_tensor_util.AppendComplex64ArrayToTensorProto,
      np.complex128:
          fast_tensor_util.AppendComplex128ArrayToTensorProto,
      object:
          fast_tensor_util.AppendObjectArrayToTensorProto,
      bool:
          fast_tensor_util.AppendBoolArrayToTensorProto,
      dtypes.qint8.as_numpy_dtype:
          fast_tensor_util.AppendInt8ArrayToTensorProto,
      dtypes.quint8.as_numpy_dtype:
          fast_tensor_util.AppendUInt8ArrayToTensorProto,
      # The 16-bit quantized types use the same 16-bit appenders as
      # np.int16 / np.uint16 above (previously mapped to the 8-bit ones).
      dtypes.qint16.as_numpy_dtype:
          fast_tensor_util.AppendInt16ArrayToTensorProto,
      dtypes.quint16.as_numpy_dtype:
          fast_tensor_util.AppendUInt16ArrayToTensorProto,
      dtypes.qint32.as_numpy_dtype:
          fast_tensor_util.AppendInt32ArrayToTensorProto,
  }
else:

  def SlowAppendFloat32ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.float_val.extend([x.item() for x in proto_values])

  def SlowAppendFloat64ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.double_val.extend([x.item() for x in proto_values])

  def SlowAppendIntArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.int_val.extend([x.item() for x in proto_values])

  def SlowAppendInt64ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.int64_val.extend([x.item() for x in proto_values])

  def SlowAppendQIntArrayToTensorProto(tensor_proto, proto_values):
    # Quantized elements are one-field structured values; take the field.
    tensor_proto.int_val.extend([x[0].item() for x in proto_values])

  def SlowAppendUInt32ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.uint32_val.extend([x.item() for x in proto_values])

  def SlowAppendUInt64ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.uint64_val.extend([x.item() for x in proto_values])

  def SlowAppendComplex64ArrayToTensorProto(tensor_proto, proto_values):
    # Stored as interleaved (real, imag) pairs.
    tensor_proto.scomplex_val.extend(
        [v.item() for x in proto_values for v in [x.real, x.imag]])

  def SlowAppendComplex128ArrayToTensorProto(tensor_proto, proto_values):
    # Stored as interleaved (real, imag) pairs.
    tensor_proto.dcomplex_val.extend(
        [v.item() for x in proto_values for v in [x.real, x.imag]])

  def SlowAppendObjectArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.string_val.extend([compat.as_bytes(x) for x in proto_values])

  def SlowAppendBoolArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.bool_val.extend([x.item() for x in proto_values])

  _NP_TO_APPEND_FN = {
      dtypes.bfloat16.as_numpy_dtype: SlowAppendBFloat16ArrayToTensorProto,
      np.float16: SlowAppendFloat16ArrayToTensorProto,
      np.float32: SlowAppendFloat32ArrayToTensorProto,
      np.float64: SlowAppendFloat64ArrayToTensorProto,
      np.int32: SlowAppendIntArrayToTensorProto,
      np.int64: SlowAppendInt64ArrayToTensorProto,
      np.uint8: SlowAppendIntArrayToTensorProto,
      np.uint16: SlowAppendIntArrayToTensorProto,
      np.uint32: SlowAppendUInt32ArrayToTensorProto,
      np.uint64: SlowAppendUInt64ArrayToTensorProto,
      np.int8: SlowAppendIntArrayToTensorProto,
      np.int16: SlowAppendIntArrayToTensorProto,
      np.complex64: SlowAppendComplex64ArrayToTensorProto,
      np.complex128: SlowAppendComplex128ArrayToTensorProto,
      object: SlowAppendObjectArrayToTensorProto,
      bool: SlowAppendBoolArrayToTensorProto,
      dtypes.qint8.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
      dtypes.quint8.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
      dtypes.qint16.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
      dtypes.quint16.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
      dtypes.qint32.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
  }


def GetFromNumpyDTypeDict(dtype_dict, dtype):
  """Looks up `dtype` in a dict keyed by numpy types, comparing with `==`.

  Args:
    dtype_dict: A dict whose keys are numpy scalar types or dtypes.
    dtype: The numpy dtype to look up.

  Returns:
    The value of the first key equal to `dtype`, or None if there is none.
  """
  # NOTE: dtype_dict.get(dtype) always returns None.
  for key, val in six.iteritems(dtype_dict):
    if key == dtype:
      return val
  return None


def GetNumpyAppendFn(dtype):
  """Returns the function appending a flat numpy array of `dtype` to a proto.

  Args:
    dtype: A numpy dtype.

  Returns:
    A function `fn(tensor_proto, proto_values)`, or None if `dtype` is not
    supported.
  """
  # numpy dtypes for strings are variable length ("S5", "U3", ...), so they
  # can not be compared with a single constant. Their `kind` is "S" (bytes)
  # or "U" (unicode) whatever the length; this is the same test as comparing
  # `dtype.type` with np.string_ / np.unicode_, which newer numpy removed.
  if dtype.kind in ("S", "U"):
    if _FAST_TENSOR_UTIL_AVAILABLE:
      return fast_tensor_util.AppendObjectArrayToTensorProto
    else:
      return SlowAppendObjectArrayToTensorProto
  return GetFromNumpyDTypeDict(_NP_TO_APPEND_FN, dtype)


def TensorShapeProtoToList(shape):
  """Convert a TensorShape to a list.

  Args:
    shape: A TensorShapeProto.

  Returns:
    List of integers representing the dimensions of the tensor.
  """
  return [dim.size for dim in shape.dim]


def _GetDenseDimensions(list_of_lists):
  """Returns the inferred dense dimensions of a list of lists.

  Only the first element is followed at every level, so ragged input is not
  detected here; callers compare the result with numpy's inferred shape.
  """
  if not isinstance(list_of_lists, (list, tuple)):
    return []
  elif not list_of_lists:
    return [0]
  else:
    return [len(list_of_lists)] + _GetDenseDimensions(list_of_lists[0])


def _FlattenToStrings(nested_strings):
  """Yields the leaves of arbitrarily nested lists/tuples in order."""
  if isinstance(nested_strings, (list, tuple)):
    for inner in nested_strings:
      for flattened_string in _FlattenToStrings(inner):
        yield flattened_string
  else:
    yield nested_strings


# Dtypes whose multi-element values are serialized into the packed
# `tensor_content` bytes field instead of a repeated `*_val` field.
_TENSOR_CONTENT_TYPES = frozenset([
    dtypes.float32, dtypes.float64, dtypes.int32, dtypes.uint8, dtypes.int16,
    dtypes.int8, dtypes.int64, dtypes.qint8, dtypes.quint8, dtypes.qint16,
    dtypes.quint16, dtypes.qint32, dtypes.uint32, dtypes.uint64
])


class _Message(object):
  """A string whose repr() is the string itself, for error messages."""

  def __init__(self, message):
    self._message = message

  def __repr__(self):
    return self._message


def _FirstNotNone(l):
  """Returns the first non-None item of `l`, or None.

  A Tensor found first is reported as a "list containing Tensors" message.
  """
  for x in l:
    if x is not None:
      if isinstance(x, ops.Tensor):
        return _Message("list containing Tensors")
      else:
        return x
  return None


def _NotNone(v):
  """Returns `v`, or a printable "None" marker when `v` is None."""
  if v is None:
    return _Message("None")
  else:
    return v


# The _Filter* functions return None when `v` (possibly nested) is acceptable
# for the dtype, or the first offending value otherwise.


def _FilterTuple(v):
  if not isinstance(v, (list, tuple)):
    return v
  if isinstance(v, tuple):
    if not any(isinstance(x, (list, tuple)) for x in v):
      return None
  if isinstance(v, list):
    if not any(isinstance(x, (list, tuple)) for x in v):
      return _FirstNotNone(
          [None if isinstance(x, (list, tuple)) else x for x in v])
  return _FirstNotNone([_FilterTuple(x) for x in v])


def _FilterInt(v):
  if isinstance(v, (list, tuple)):
    return _FirstNotNone([_FilterInt(x) for x in v])
  return None if isinstance(
      v, (compat.integral_types, tensor_shape.Dimension)) else _NotNone(v)


def _FilterFloat(v):
  if isinstance(v, (list, tuple)):
    return _FirstNotNone([_FilterFloat(x) for x in v])
  return None if isinstance(v, compat.real_types) else _NotNone(v)


def _FilterComplex(v):
  if isinstance(v, (list, tuple)):
    return _FirstNotNone([_FilterComplex(x) for x in v])
  return None if isinstance(v, compat.complex_types) else _NotNone(v)


def _FilterStr(v):
  if isinstance(v, (list, tuple)):
    return _FirstNotNone([_FilterStr(x) for x in v])
  if isinstance(v, compat.bytes_or_text_types):
    return None
  else:
    return _NotNone(v)


def _FilterBool(v):
  if isinstance(v, (list, tuple)):
    return _FirstNotNone([_FilterBool(x) for x in v])
  return None if isinstance(v, bool) else _NotNone(v)


def _FilterNotTensor(v):
  if isinstance(v, (list, tuple)):
    return _FirstNotNone([_FilterNotTensor(x) for x in v])
  return str(v) if isinstance(v, ops.Tensor) else None


# Per-dtype validators for python values passed to make_tensor_proto.
# Dtypes not listed here only reject Tensors.
_TF_TO_IS_OK = {
    dtypes.bool: [_FilterBool],
    dtypes.complex128: [_FilterComplex],
    dtypes.complex64: [_FilterComplex],
    dtypes.float16: [_FilterFloat],
    dtypes.float32: [_FilterFloat],
    dtypes.float64: [_FilterFloat],
    dtypes.int16: [_FilterInt],
    dtypes.int32: [_FilterInt],
    dtypes.int64: [_FilterInt],
    dtypes.int8: [_FilterInt],
    dtypes.qint16: [_FilterInt, _FilterTuple],
    dtypes.qint32: [_FilterInt, _FilterTuple],
    dtypes.qint8: [_FilterInt, _FilterTuple],
    dtypes.quint16: [_FilterInt, _FilterTuple],
    dtypes.quint8: [_FilterInt, _FilterTuple],
    dtypes.string: [_FilterStr],
    dtypes.uint16: [_FilterInt],
    # Validated like the other unsigned integer types.
    dtypes.uint32: [_FilterInt],
    dtypes.uint64: [_FilterInt],
    dtypes.uint8: [_FilterInt],
}


def _AssertCompatible(values, dtype):
  """Raises TypeError if python `values` do not fit `dtype`."""
  fn_list = _TF_TO_IS_OK.get(dtype, [_FilterNotTensor])
  mismatch = _FirstNotNone([fn(values) for fn in fn_list])
  if mismatch is not None:
    if dtype is None:
      raise TypeError("List of Tensors when single Tensor expected")
    else:
      raise TypeError("Expected %s, got %s of type '%s' instead." %
                      (dtype.name, repr(mismatch), type(mismatch).__name__))


@tf_export("make_tensor_proto")
def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False):
  """Create a TensorProto.

  Args:
    values:         Values to put in the TensorProto.
    dtype:          Optional tensor_pb2 DataType value.
    shape:          List of integers representing the dimensions of tensor.
    verify_shape:   Boolean that enables verification of a shape of values.

  Returns:
    A `TensorProto`. Depending on the type, it may contain data in the
    "tensor_content" attribute, which is not directly useful to Python programs.
    To access the values you should convert the proto back to a numpy ndarray
    with `tensor_util.MakeNdarray(proto)`.

    If `values` is a `TensorProto`, it is immediately returned; `dtype` and
    `shape` are ignored.

  Raises:
    TypeError:  if unsupported types are provided, or if verify_shape is True
      and the shape of values does not equal the shape from the argument.
    ValueError: if arguments have inappropriate values (e.g. values is None,
      is not dense, or has more elements than shape allows).

  make_tensor_proto accepts "values" of a python scalar, a python list, a
  numpy ndarray, a numpy scalar, or an object exposing `__array__` or an
  `__array_interface__` dict.

  If "values" is a python scalar or a python list, make_tensor_proto
  first converts it to a numpy ndarray. If dtype is None, the
  conversion tries its best to infer the right numpy data
  type. Otherwise, the resulting numpy array has a compatible data
  type with the given dtype.

  In either case above, the numpy ndarray (either the caller provided
  or the auto converted) must have a type compatible with dtype.

  make_tensor_proto then converts the numpy array to a tensor proto.

  If "shape" is None, the resulting tensor proto represents the numpy
  array precisely.

  Otherwise, "shape" specifies the tensor's shape and the numpy array
  can not have more elements than what "shape" specifies.

  """
  if isinstance(values, tensor_pb2.TensorProto):
    return values

  if dtype:
    dtype = dtypes.as_dtype(dtype)

  is_quantized = (
      dtype in [
          dtypes.qint8, dtypes.quint8, dtypes.qint16, dtypes.quint16,
          dtypes.qint32
      ])

  # If dtype is provided, numpy conversions are forced to the matching numpy
  # type when one exists. numpy does not understand a TensorFlow DType
  # itself, so always pass this instead.
  if dtype and dtype.is_numpy_compatible:
    np_dt = dtype.as_numpy_dtype
  else:
    np_dt = None

  # We first convert value to a numpy array or scalar.
  if isinstance(values, (np.ndarray, np.generic)):
    if dtype:
      nparray = values.astype(dtype.as_numpy_dtype)
    else:
      nparray = values
  elif callable(getattr(values, "__array__", None)) or isinstance(
      getattr(values, "__array_interface__", None), dict):
    # If a class has the __array__ method, or __array_interface__ dict, then it
    # is possible to convert to numpy array.
    nparray = np.asarray(values, dtype=np_dt)

    # This is the preferred way to create an array from the object, so replace
    # the `values` with the array so that _FlattenToStrings is not run.
    values = nparray
  else:
    if values is None:
      raise ValueError("None values not supported.")
    # If shape is None, numpy.prod returns None when dtype is not set, but
    # raises exception when dtype is set to np.int64
    if shape is not None and np.prod(shape, dtype=np.int64) == 0:
      nparray = np.empty(shape, dtype=np_dt)
    else:
      _AssertCompatible(values, dtype)
      nparray = np.array(values, dtype=np_dt)
      # We need to pass in quantized values as tuples, so don't apply the shape
      # check to them.
      if (list(nparray.shape) != _GetDenseDimensions(values) and
          not is_quantized):
        raise ValueError("""Argument must be a dense tensor: %s"""
                         """ - got shape %s, but wanted %s.""" %
                         (values, list(nparray.shape),
                          _GetDenseDimensions(values)))

    # python/numpy default float type is float64. We prefer float32 instead.
    if (nparray.dtype == np.float64) and dtype is None:
      nparray = nparray.astype(np.float32)
    # python/numpy default int type is int64. We prefer int32 instead.
    elif (nparray.dtype == np.int64) and dtype is None:
      downcasted_array = nparray.astype(np.int32)
      # Do not down cast if it leads to precision loss.
      if np.array_equal(downcasted_array, nparray):
        nparray = downcasted_array

  # if dtype is provided, it must be compatible with what numpy
  # conversion says.
  numpy_dtype = dtypes.as_dtype(nparray.dtype)
  if numpy_dtype is None:
    raise TypeError("Unrecognized data type: %s" % nparray.dtype)

  # If dtype was specified and is a quantized type, we convert
  # numpy_dtype back into the quantized version.
  if is_quantized:
    numpy_dtype = dtype

  if dtype is not None and (not hasattr(dtype, "base_dtype") or
                            dtype.base_dtype != numpy_dtype.base_dtype):
    raise TypeError("Incompatible types: %s vs. %s. Value is %s" %
                    (dtype, nparray.dtype, values))

  # If shape is not given, get the shape from the numpy array.
  if shape is None:
    shape = nparray.shape
    is_same_size = True
    shape_size = nparray.size
  else:
    shape = [int(dim) for dim in shape]
    shape_size = np.prod(shape, dtype=np.int64)
    is_same_size = shape_size == nparray.size

    if verify_shape:
      if not nparray.shape == tuple(shape):
        raise TypeError("Expected Tensor's shape: %s, got %s." %
                        (tuple(shape), nparray.shape))

    if nparray.size > shape_size:
      raise ValueError(
          "Too many elements provided. Needed at most %d, but received %d" %
          (shape_size, nparray.size))

  tensor_proto = tensor_pb2.TensorProto(
      dtype=numpy_dtype.as_datatype_enum,
      tensor_shape=tensor_shape.as_shape(shape).as_proto())

  if is_same_size and numpy_dtype in _TENSOR_CONTENT_TYPES and shape_size > 1:
    if nparray.size * nparray.itemsize >= (1 << 31):
      raise ValueError(
          "Cannot create a tensor proto whose content is larger than 2GB.")
    # `tobytes` is the non-deprecated name of `tostring`; same bytes.
    tensor_proto.tensor_content = nparray.tobytes()
    return tensor_proto

  # If we were not given values as a numpy array, compute the proto_values
  # from the given values directly, to avoid numpy trimming nulls from the
  # strings. Since values could be a list of strings, or a multi-dimensional
  # list of lists that might or might not correspond to the given shape,
  # we flatten it conservatively.
  if numpy_dtype == dtypes.string and not isinstance(values, np.ndarray):
    proto_values = _FlattenToStrings(values)

    # At this point, values may be a list of objects that we could not
    # identify a common type for (hence it was inferred as
    # np.object/dtypes.string).  If we are unable to convert it to a
    # string, we raise a more helpful error message.
    #
    # Ideally, we'd be able to convert the elements of the list to a
    # common type, but this type inference requires some thinking and
    # so we defer it for now.
    try:
      str_values = [compat.as_bytes(x) for x in proto_values]
    except TypeError:
      raise TypeError("Failed to convert object of type %s to Tensor. "
                      "Contents: %s. Consider casting elements to a "
                      "supported type." % (type(values), values))
    tensor_proto.string_val.extend(str_values)
    return tensor_proto

  # TensorFlow expects C order (a.k.a., eigen row major).
  proto_values = nparray.ravel()

  append_fn = GetNumpyAppendFn(proto_values.dtype)
  if append_fn is None:
    raise TypeError(
        "Element type not supported in TensorProto: %s" % numpy_dtype.name)
  append_fn(tensor_proto, proto_values)

  return tensor_proto


def _ReshapeWithPadding(values, shape, num_elements, empty_fill=None):
  """Reshapes a flat array of values read from a TensorProto to `shape`.

  A TensorProto may hold fewer values than its shape has elements (for
  example make_tensor_proto([1, 2], shape=[4]) produces one). Missing
  elements are filled as follows (NOTE(review): intended to mirror the C++
  runtime's Tensor::FromProto; confirm):

  * exactly one stored value is repeated for every element (even when the
    shape has no elements);
  * no stored values yields zeros, or `empty_fill` when it is given;
  * otherwise the last stored value is repeated for the missing tail.

  Args:
    values: 1-D numpy array of the stored values.
    shape: List of integers, the target shape.
    num_elements: Number of elements in `shape`.
    empty_fill: Optional value used when `values` is empty.

  Returns:
    A numpy array of shape `shape` with the dtype of `values`.

  Raises:
    ValueError: if `values` has more than one element and more elements
      than `shape` holds (raised by `reshape`).
  """
  count = values.size
  if count == 1:
    values = np.repeat(values, num_elements)
  elif count == 0:
    if empty_fill is None:
      values = np.zeros(num_elements, dtype=values.dtype)
    else:
      values = np.full(num_elements, empty_fill, dtype=values.dtype)
  elif count < num_elements:
    values = np.concatenate(
        [values, np.repeat(values[-1:], num_elements - count)])
  return values.reshape(shape)


@tf_export("make_ndarray")
def MakeNdarray(tensor):
  """Create a numpy ndarray from a tensor.

  Create a numpy ndarray with the same shape and data as the tensor. When the
  proto stores fewer values than its shape has elements, the missing trailing
  elements repeat the last stored value (or are zero / empty bytes when no
  value is stored).

  Args:
    tensor: A TensorProto.

  Returns:
    A numpy array with the tensor contents.

  Raises:
    TypeError: if tensor has unsupported type.

  """
  shape = [d.size for d in tensor.tensor_shape.dim]
  num_elements = np.prod(shape, dtype=np.int64)
  tensor_dtype = dtypes.as_dtype(tensor.dtype)
  dtype = tensor_dtype.as_numpy_dtype

  if tensor.tensor_content:
    return (np.frombuffer(tensor.tensor_content, dtype=dtype).copy()
            .reshape(shape))

  if tensor_dtype in (dtypes.float16, dtypes.bfloat16):
    # half_val stores the raw 16-bit patterns of both float16 and bfloat16
    # (see the Slow*Float16 append functions above); reinterpret them as
    # `dtype`. Zero bits decode to 0.0 for both types.
    bits = np.fromiter(tensor.half_val, dtype=np.uint16)
    return _ReshapeWithPadding(bits, shape, num_elements).view(dtype)

  if tensor_dtype == dtypes.string:
    values = np.array(list(tensor.string_val), dtype=dtype)
    return _ReshapeWithPadding(values, shape, num_elements, empty_fill=b"")

  if tensor_dtype in (dtypes.complex64, dtypes.complex128):
    # Complex values are stored as interleaved (real, imag) pairs.
    if tensor_dtype == dtypes.complex64:
      flat = tensor.scomplex_val
    else:
      flat = tensor.dcomplex_val
    it = iter(flat)
    values = np.array(
        [complex(real, imag) for real, imag in zip(it, it)], dtype=dtype)
    return _ReshapeWithPadding(values, shape, num_elements)

  if tensor_dtype == dtypes.float32:
    field = tensor.float_val
  elif tensor_dtype == dtypes.float64:
    field = tensor.double_val
  elif tensor_dtype in [
      dtypes.int32, dtypes.uint8, dtypes.uint16, dtypes.int16, dtypes.int8,
      dtypes.qint32, dtypes.quint8, dtypes.qint8, dtypes.qint16, dtypes.quint16
  ]:
    field = tensor.int_val
  elif tensor_dtype == dtypes.int64:
    field = tensor.int64_val
  elif tensor_dtype == dtypes.uint32:
    field = tensor.uint32_val
  elif tensor_dtype == dtypes.uint64:
    field = tensor.uint64_val
  elif tensor_dtype == dtypes.bool:
    field = tensor.bool_val
  else:
    raise TypeError("Unsupported tensor type: %s" % tensor.dtype)

  return _ReshapeWithPadding(
      np.fromiter(field, dtype=dtype), shape, num_elements)


def ShapeEquals(tensor_proto, shape):
  """Returns True if "tensor_proto" has the given "shape".

  Args:
    tensor_proto: A TensorProto.
    shape: A tensor shape, expressed as a TensorShapeProto, list, or tuple.

  Returns:
    True if "tensor_proto" has the given "shape", otherwise False. Shapes of
    different rank are never equal.

  Raises:
    TypeError: If "tensor_proto" is not a TensorProto, or shape is not a
      TensorShapeProto, list, or tuple.
  """
  if not isinstance(tensor_proto, tensor_pb2.TensorProto):
    raise TypeError("tensor_proto is not a tensor_pb2.TensorProto object")
  if isinstance(shape, tensor_shape_pb2.TensorShapeProto):
    shape = [d.size for d in shape.dim]
  elif not isinstance(shape, (list, tuple)):
    raise TypeError("shape is not a list or tuple")
  tensor_shape_list = [d.size for d in tensor_proto.tensor_shape.dim]
  # `zip` alone stops at the shorter shape, which would report e.g. [2, 3]
  # equal to [2] or to []; compare the ranks first.
  return (len(tensor_shape_list) == len(shape) and
          all(x == y for x, y in zip(tensor_shape_list, shape)))


def _ConstantValue(tensor, partial):
  """Statically evaluates `tensor` for a fixed set of op types.

  Returns a numpy value, or None if the value cannot be computed. See
  `constant_value` for the meaning of `partial`.
  """
  # TODO(touts): Support Variables?
  if not isinstance(tensor, ops.Tensor):
    raise TypeError("tensor is not a Tensor")
  if tensor.op.type == "Const":
    return MakeNdarray(tensor.op.get_attr("value"))
  elif tensor.op.type == "Shape":
    input_shape = tensor.op.inputs[0].get_shape()
    if input_shape.is_fully_defined():
      return np.array(
          [dim.value for dim in input_shape.dims],
          dtype=tensor.dtype.as_numpy_dtype)
    else:
      return None
  elif tensor.op.type == "Size":
    input_shape = tensor.op.inputs[0].get_shape()
    if input_shape.is_fully_defined():
      # Use the op's output dtype, as the "Shape" case does, instead of
      # always producing int32.
      return np.prod([dim.value for dim in input_shape.dims],
                     dtype=tensor.dtype.as_numpy_dtype)
    else:
      return None
  elif tensor.op.type == "Rank":
    input_shape = tensor.op.inputs[0].get_shape()
    if input_shape.ndims is not None:
      # A 0-d int32 array holding the rank.
      return np.array(input_shape.ndims, dtype=np.int32)
    else:
      return None
  elif tensor.op.type == "Range":
    start = constant_value(tensor.op.inputs[0])
    if start is None:
      return None
    limit = constant_value(tensor.op.inputs[1])
    if limit is None:
      return None
    delta = constant_value(tensor.op.inputs[2])
    if delta is None:
      return None
    return np.arange(start, limit, delta, dtype=tensor.dtype.as_numpy_dtype)
  elif tensor.op.type == "Cast":
    pre_cast = constant_value(tensor.op.inputs[0])
    if pre_cast is None:
      return None
    cast_dtype = dtypes.as_dtype(tensor.op.get_attr("DstT"))
    return pre_cast.astype(cast_dtype.as_numpy_dtype)
  elif tensor.op.type == "Concat":
    # Concat takes the axis as its first input.
    dim = constant_value(tensor.op.inputs[0])
    if dim is None:
      return None
    values = []
    for x in tensor.op.inputs[1:]:
      value = constant_value(x)
      if value is None:
        return None
      values.append(value)
    return np.concatenate(values, axis=dim)
  elif tensor.op.type == "ConcatV2":
    # ConcatV2 takes the axis as its last input.
    dim = constant_value(tensor.op.inputs[-1])
    if dim is None:
      return None
    values = []
    for x in tensor.op.inputs[:-1]:
      value = constant_value(x)
      if value is None:
        return None
      values.append(value)
    return np.concatenate(values, axis=dim)
  elif tensor.op.type == "Pack":
    values = []
    # Some imported GraphDefs have Pack ops with zero inputs. Those are invalid
    # and shouldn't be produced, but to deal sensibly with them here we check
    # and return None.
    if not tensor.op.inputs:
      return None
    # We can't handle axis != 0 Packs at the moment.
    if tensor.op.get_attr("axis") != 0:
      return None
    for x in tensor.op.inputs:
      value = constant_value(x, partial)
      if value is None and not partial:
        return None
      values.append(value)
    return np.array(values)
  elif tensor.op.type == "Fill":
    fill_shape = tensor.shape
    fill_value = constant_value(tensor.op.inputs[1])
    if fill_shape.is_fully_defined() and fill_value is not None:
      return np.full(fill_shape.as_list(), fill_value, dtype=fill_value.dtype)
    else:
      return None
  elif tensor.op.type == "Equal":
    value1 = constant_value(tensor.op.inputs[0])
    if value1 is None:
      return None
    value2 = constant_value(tensor.op.inputs[1])
    if value2 is None:
      return None
    return np.equal(value1, value2)
  elif tensor.op.type == "NotEqual":
    value1 = constant_value(tensor.op.inputs[0])
    if value1 is None:
      return None
    value2 = constant_value(tensor.op.inputs[1])
    if value2 is None:
      return None
    return np.not_equal(value1, value2)
  else:
    return None


def constant_value(tensor, partial=False):  # pylint: disable=invalid-name
  """Returns the constant value of the given tensor, if efficiently calculable.

  This function attempts to partially evaluate the given tensor, and
  returns its value as a numpy ndarray if this succeeds.

  TODO(mrry): Consider whether this function should use a registration
  mechanism like gradients and ShapeFunctions, so that it is easily
  extensible.

  NOTE: If `constant_value(tensor)` returns a non-`None` result, it will no
  longer be possible to feed a different value for `tensor`. This allows the
  result of this function to influence the graph that is constructed, and
  permits static shape optimizations.

  Args:
    tensor: The Tensor to be evaluated.
    partial: If True, the returned numpy array is allowed to have partially
      evaluated values. Values that can't be evaluated will be None.

  Returns:
    A numpy ndarray containing the constant value of the given `tensor`,
    or None if it cannot be calculated.

  Raises:
    TypeError: if tensor is not an ops.Tensor.
  """
  if isinstance(tensor, ops.EagerTensor):
    return tensor.numpy()
  ret = _ConstantValue(tensor, partial)
  if ret is not None:
    # The caller may now depend on the constant value of `tensor`, so we
    # conservatively prevent it from being fed.
    tensor.graph.prevent_feeding(tensor)
  return ret


def constant_value_as_shape(tensor):  # pylint: disable=invalid-name
  """A version of `constant_value()` that returns a `TensorShape`.

  This version should be used when a constant tensor value is
  interpreted as a (possibly partial) shape, e.g. in the shape
  function for `tf.reshape()`. By explicitly requesting a
  `TensorShape` as the return value, it is possible to represent
  unknown dimensions; by contrast, `constant_value()` is
  all-or-nothing.

  Args:
    tensor: The rank-1 Tensor to be evaluated.

  Returns:
    A `TensorShape` based on the constant value of the given `tensor`.
  """
  if context.in_eager_mode():
    # -1 marks an unknown dimension.
    return tensor_shape.as_shape(
        [dim if dim != -1 else None for dim in tensor.numpy()])

  shape = tensor.get_shape().with_rank(1)
  if tensor.get_shape() == [0]:
    return tensor_shape.scalar()
  elif tensor.op.type == "Shape":
    return tensor.op.inputs[0].get_shape()
  elif tensor.op.type == "Pack":
    ret = tensor_shape.scalar()  # Empty list.
    # Since we expect rank 1 inputs, Pack's axis must be zero, otherwise it
    # would not be rank 1.
    assert tensor.op.get_attr("axis") == 0
    for pack_input in tensor.op.inputs:
      # `pack_input` must be a scalar. Attempt to evaluate it, and append it
      # to `ret`.
      pack_input_val = constant_value(pack_input)
      if pack_input_val is None or pack_input_val < 0:
        new_dim = tensor_shape.Dimension(None)
      else:
        new_dim = tensor_shape.Dimension(pack_input_val)
      ret = ret.concatenate([new_dim])
    return ret
  elif tensor.op.type == "Concat":
    # We assume that `tensor.op.inputs[0]` evaluates to 0, as this is
    # the only legal value when concatenating vectors, and it will
    # have been checked by a previous shape function.
    ret = tensor_shape.scalar()  # Empty list.
    for concat_input in tensor.op.inputs[1:]:
      # `concat_input` must be a vector. Attempt to evaluate it as a shape,
      # and concatenate it with `ret`.
      ret = ret.concatenate(constant_value_as_shape(concat_input))
    return ret
  elif tensor.op.type == "ConcatV2":
    # We assume that `tensor.op.inputs[-1]` evaluates to 0, as this is
    # the only legal value when concatenating vectors, and it will
    # have been checked by a previous shape function.
    ret = tensor_shape.scalar()  # Empty list.
    for concat_input in tensor.op.inputs[:-1]:
      # `concat_input` must be a vector. Attempt to evaluate it as a shape,
      # and concatenate it with `ret`.
      ret = ret.concatenate(constant_value_as_shape(concat_input))
    return ret
  elif tensor.op.type == "StridedSlice":
    try:
      begin = constant_value(tensor.op.inputs[1])
      end = constant_value(tensor.op.inputs[2])
      strides = constant_value(tensor.op.inputs[3])
      if begin is not None and end is not None and strides is not None:
        begin = begin[0]
        end = end[0]
        strides = strides[0]
        begin_mask = tensor.op.get_attr("begin_mask")
        if begin_mask == 1:
          begin = None
        end_mask = tensor.op.get_attr("end_mask")
        if end_mask == 1:
          end = None

        ellipsis_mask = tensor.op.get_attr("ellipsis_mask")
        new_axis_mask = tensor.op.get_attr("new_axis_mask")
        shrink_axis_mask = tensor.op.get_attr("shrink_axis_mask")
        valid_attributes = (not ellipsis_mask and not new_axis_mask and
                            not shrink_axis_mask and (not begin_mask or
                                                      (begin_mask == 1)) and
                            (not end_mask or (end_mask == 1)))
        if valid_attributes:  # additional inputs not supported
          prev = constant_value_as_shape(tensor.op.inputs[0])
          prev = prev[begin:end:strides]
          ret = tensor_shape.TensorShape(prev)
          return ret

    except ValueError:  # Could come from get_attr or slicing prev.
      pass
    except TypeError:  # Could come from slicing prev.
      pass

  # Fallback: a shape of the known length, refined with any constant value;
  # negative entries are treated as unknown dimensions.
  ret = tensor_shape.unknown_shape(shape[0].value)
  value = constant_value(tensor)
  if value is not None:
    ret = ret.merge_with(
        tensor_shape.TensorShape([d if d >= 0 else None for d in value]))
  return ret


def is_tensor(x):  # pylint: disable=invalid-name
  """Check whether `x` is of tensor type.

  Check whether an object is a tensor. Equivalent to
  `isinstance(x, [tf.Tensor, tf.SparseTensor, tf.Variable])`.

  Args:
    x: A python object to check.

  Returns:
    `True` if `x` is a tensor, `False` if not.
  """
  return isinstance(x, ops._TensorLike) or ops.is_dense_tensor_like(x)  # pylint: disable=protected-access