      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Utilities to create TensorProtos."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import numpy as np
     21 import six
     22 
     23 from tensorflow.core.framework import tensor_pb2
     24 from tensorflow.core.framework import tensor_shape_pb2
     25 from tensorflow.python.eager import context
     26 from tensorflow.python.framework import ops
     27 from tensorflow.python.framework import tensor_shape
     28 from tensorflow.python.util import compat
     29 
     30 # Fallback in case fast_tensor_util is not properly compiled.
     31 # pylint: disable=g-import-not-at-top
     32 try:
     33   from tensorflow.python.framework import fast_tensor_util
     34   _FAST_TENSOR_UTIL_AVAILABLE = True
     35 except ImportError:
     36   _FAST_TENSOR_UTIL_AVAILABLE = False
     37 
     38 from tensorflow.python.framework import dtypes
     39 from tensorflow.python.framework import ops
     40 from tensorflow.python.util.tf_export import tf_export
     41 
     42 # pylint: enable=g-import-not-at-top
     43 
     44 
     45 def ExtractBitsFromFloat16(x):
     46   return np.asscalar(np.asarray(x, dtype=np.float16).view(np.uint16))
     47 
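# Illustrative note (not in the original source): ExtractBitsFromFloat16(1.0)
# returns 0x3C00 (15360), the raw IEEE-754 binary16 bit pattern for 1.0, which
# is the form TensorProto.half_val stores.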
     48 
     49 def SlowAppendFloat16ArrayToTensorProto(tensor_proto, proto_values):
     50   tensor_proto.half_val.extend(
     51       [ExtractBitsFromFloat16(x) for x in proto_values])
     52 
     53 
     54 def ExtractBitsFromBFloat16(x):
     55   return np.asscalar(
     56       np.asarray(x, dtype=dtypes.bfloat16.as_numpy_dtype).view(np.uint16))
     57 
     58 
     59 def SlowAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values):
     60   tensor_proto.half_val.extend(
     61       [ExtractBitsFromBFloat16(x) for x in proto_values])
     62 
     63 
     64 if _FAST_TENSOR_UTIL_AVAILABLE:
     65   _NP_TO_APPEND_FN = {
     66       dtypes.bfloat16.as_numpy_dtype:
     67           SlowAppendBFloat16ArrayToTensorProto,
     68       # TODO(sesse): We should have a
     69       # fast_tensor_util.AppendFloat16ArrayToTensorProto,
     70       # but it seems np.float16_t doesn't exist?
     71       np.float16:
     72           SlowAppendFloat16ArrayToTensorProto,
     73       np.float32:
     74           fast_tensor_util.AppendFloat32ArrayToTensorProto,
     75       np.float64:
     76           fast_tensor_util.AppendFloat64ArrayToTensorProto,
     77       np.int32:
     78           fast_tensor_util.AppendInt32ArrayToTensorProto,
     79       np.int64:
     80           fast_tensor_util.AppendInt64ArrayToTensorProto,
     81       np.uint8:
     82           fast_tensor_util.AppendUInt8ArrayToTensorProto,
     83       np.uint16:
     84           fast_tensor_util.AppendUInt16ArrayToTensorProto,
     85       np.uint32:
     86           fast_tensor_util.AppendUInt32ArrayToTensorProto,
     87       np.uint64:
     88           fast_tensor_util.AppendUInt64ArrayToTensorProto,
     89       np.int8:
     90           fast_tensor_util.AppendInt8ArrayToTensorProto,
     91       np.int16:
     92           fast_tensor_util.AppendInt16ArrayToTensorProto,
     93       np.complex64:
     94           fast_tensor_util.AppendComplex64ArrayToTensorProto,
     95       np.complex128:
     96           fast_tensor_util.AppendComplex128ArrayToTensorProto,
     97       np.object:
     98           fast_tensor_util.AppendObjectArrayToTensorProto,
     99       np.bool:
    100           fast_tensor_util.AppendBoolArrayToTensorProto,
    101       dtypes.qint8.as_numpy_dtype:
    102           fast_tensor_util.AppendInt8ArrayToTensorProto,
    103       dtypes.quint8.as_numpy_dtype:
    104           fast_tensor_util.AppendUInt8ArrayToTensorProto,
    105       dtypes.qint16.as_numpy_dtype:
    106           fast_tensor_util.AppendInt8ArrayToTensorProto,
    107       dtypes.quint16.as_numpy_dtype:
    108           fast_tensor_util.AppendUInt8ArrayToTensorProto,
    109       dtypes.qint32.as_numpy_dtype:
    110           fast_tensor_util.AppendInt32ArrayToTensorProto,
    111       # NOTE(touts): Intentionally no way to feed a DT_BFLOAT16.
    112   }
    113 else:
    114 
    115   def SlowAppendFloat32ArrayToTensorProto(tensor_proto, proto_values):
    116     tensor_proto.float_val.extend([np.asscalar(x) for x in proto_values])
    117 
    118   def SlowAppendFloat64ArrayToTensorProto(tensor_proto, proto_values):
    119     tensor_proto.double_val.extend([np.asscalar(x) for x in proto_values])
    120 
    121   def SlowAppendIntArrayToTensorProto(tensor_proto, proto_values):
    122     tensor_proto.int_val.extend([np.asscalar(x) for x in proto_values])
    123 
    124   def SlowAppendInt64ArrayToTensorProto(tensor_proto, proto_values):
    125     tensor_proto.int64_val.extend([np.asscalar(x) for x in proto_values])
    126 
    127   def SlowAppendQIntArrayToTensorProto(tensor_proto, proto_values):
    128     tensor_proto.int_val.extend([np.asscalar(x[0]) for x in proto_values])
    129 
    130   def SlowAppendUInt32ArrayToTensorProto(tensor_proto, proto_values):
    131     tensor_proto.uint32_val.extend([np.asscalar(x) for x in proto_values])
    132 
    133   def SlowAppendUInt64ArrayToTensorProto(tensor_proto, proto_values):
    134     tensor_proto.uint64_val.extend([np.asscalar(x) for x in proto_values])
    135 
    136   def SlowAppendComplex64ArrayToTensorProto(tensor_proto, proto_values):
    137     tensor_proto.scomplex_val.extend(
    138         [np.asscalar(v) for x in proto_values for v in [x.real, x.imag]])
    139 
    140   def SlowAppendComplex128ArrayToTensorProto(tensor_proto, proto_values):
    141     tensor_proto.dcomplex_val.extend(
    142         [np.asscalar(v) for x in proto_values for v in [x.real, x.imag]])
    143 
    144   def SlowAppendObjectArrayToTensorProto(tensor_proto, proto_values):
    145     tensor_proto.string_val.extend([compat.as_bytes(x) for x in proto_values])
    146 
    147   def SlowAppendBoolArrayToTensorProto(tensor_proto, proto_values):
    148     tensor_proto.bool_val.extend([np.asscalar(x) for x in proto_values])
    149 
    150   _NP_TO_APPEND_FN = {
    151       dtypes.bfloat16.as_numpy_dtype: SlowAppendBFloat16ArrayToTensorProto,
    152       np.float16: SlowAppendFloat16ArrayToTensorProto,
    153       np.float32: SlowAppendFloat32ArrayToTensorProto,
    154       np.float64: SlowAppendFloat64ArrayToTensorProto,
    155       np.int32: SlowAppendIntArrayToTensorProto,
    156       np.int64: SlowAppendInt64ArrayToTensorProto,
    157       np.uint8: SlowAppendIntArrayToTensorProto,
    158       np.uint16: SlowAppendIntArrayToTensorProto,
    159       np.uint32: SlowAppendUInt32ArrayToTensorProto,
    160       np.uint64: SlowAppendUInt64ArrayToTensorProto,
    161       np.int8: SlowAppendIntArrayToTensorProto,
    162       np.int16: SlowAppendIntArrayToTensorProto,
    163       np.complex64: SlowAppendComplex64ArrayToTensorProto,
    164       np.complex128: SlowAppendComplex128ArrayToTensorProto,
    165       np.object: SlowAppendObjectArrayToTensorProto,
    166       np.bool: SlowAppendBoolArrayToTensorProto,
    167       dtypes.qint8.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
    168       dtypes.quint8.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
    169       dtypes.qint16.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
    170       dtypes.quint16.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
    171       dtypes.qint32.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
    172       # NOTE(touts): Intentionally no way to feed a DT_BFLOAT16.
    173   }
    174 
    175 
    176 def GetFromNumpyDTypeDict(dtype_dict, dtype):
    177   # NOTE: dtype_dict.get(dtype) always returns None.
    178   for key, val in six.iteritems(dtype_dict):
    179     if key == dtype:
    180       return val
    181   return None
    182 
    183 
    184 def GetNumpyAppendFn(dtype):
     185   # numpy dtypes for strings are variable length. We cannot compare
     186   # dtype against a single constant (np.string does not exist) to decide
     187   # whether dtype is a "string" type. We need to compare dtype.type to
     188   # be sure it is a string type.
    189   if dtype.type == np.string_ or dtype.type == np.unicode_:
    190     if _FAST_TENSOR_UTIL_AVAILABLE:
    191       return fast_tensor_util.AppendObjectArrayToTensorProto
    192     else:
    193       return SlowAppendObjectArrayToTensorProto
    194   return GetFromNumpyDTypeDict(_NP_TO_APPEND_FN, dtype)
    195 
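# Illustrative sketch (not part of the original module): this is how the
# append-fn dispatch above is used to fill the typed repeated fields of a
# TensorProto when the packed tensor_content path is not taken.
#
#   values = np.array([1.5, 2.5], dtype=np.float32)
#   proto = tensor_pb2.TensorProto(
#       dtype=dtypes.float32.as_datatype_enum,
#       tensor_shape=tensor_shape.as_shape(values.shape).as_proto())
#   append_fn = GetNumpyAppendFn(values.dtype)   # float32 append variant
#   append_fn(proto, values.ravel())             # proto.float_val -> [1.5, 2.5]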
    196 
    197 def TensorShapeProtoToList(shape):
    198   """Convert a TensorShape to a list.
    199 
    200   Args:
    201     shape: A TensorShapeProto.
    202 
    203   Returns:
    204     List of integers representing the dimensions of the tensor.
    205   """
    206   return [dim.size for dim in shape.dim]
    207 
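# Illustrative usage (not in the original source): for a proto built with
# tensor_shape.as_shape([2, 3]).as_proto(), TensorShapeProtoToList returns
# [2, 3].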
    208 
    209 def _GetDenseDimensions(list_of_lists):
    210   """Returns the inferred dense dimensions of a list of lists."""
    211   if not isinstance(list_of_lists, (list, tuple)):
    212     return []
    213   elif not list_of_lists:
    214     return [0]
    215   else:
    216     return [len(list_of_lists)] + _GetDenseDimensions(list_of_lists[0])
    217 
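# For example (illustrative, not in the original source):
#   _GetDenseDimensions([[1, 2, 3], [4, 5, 6]])  # -> [2, 3]
#   _GetDenseDimensions(7)                       # -> []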
    218 
    219 def _FlattenToStrings(nested_strings):
    220   if isinstance(nested_strings, (list, tuple)):
    221     for inner in nested_strings:
    222       for flattened_string in _FlattenToStrings(inner):
    223         yield flattened_string
    224   else:
    225     yield nested_strings
    226 
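# For example (illustrative, not in the original source):
#   list(_FlattenToStrings([["a", "b"], "c"]))   # -> ["a", "b", "c"]
#   list(_FlattenToStrings("scalar"))            # -> ["scalar"]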
    227 
    228 _TENSOR_CONTENT_TYPES = frozenset([
    229     dtypes.float32, dtypes.float64, dtypes.int32, dtypes.uint8, dtypes.int16,
    230     dtypes.int8, dtypes.int64, dtypes.qint8, dtypes.quint8, dtypes.qint16,
    231     dtypes.quint16, dtypes.qint32, dtypes.uint32, dtypes.uint64
    232 ])
    233 
    234 
    235 class _Message(object):
    236 
    237   def __init__(self, message):
    238     self._message = message
    239 
    240   def __repr__(self):
    241     return self._message
    242 
    243 
    244 def _FirstNotNone(l):
    245   for x in l:
    246     if x is not None:
    247       if isinstance(x, ops.Tensor):
    248         return _Message("list containing Tensors")
    249       else:
    250         return x
    251   return None
    252 
    253 
    254 def _NotNone(v):
    255   if v is None:
    256     return _Message("None")
    257   else:
    258     return v
    259 
    260 
    261 def _FilterTuple(v):
    262   if not isinstance(v, (list, tuple)):
    263     return v
    264   if isinstance(v, tuple):
    265     if not any(isinstance(x, (list, tuple)) for x in v):
    266       return None
    267   if isinstance(v, list):
    268     if not any(isinstance(x, (list, tuple)) for x in v):
    269       return _FirstNotNone(
    270           [None if isinstance(x, (list, tuple)) else x for x in v])
    271   return _FirstNotNone([_FilterTuple(x) for x in v])
    272 
    273 
    274 def _FilterInt(v):
    275   if isinstance(v, (list, tuple)):
    276     return _FirstNotNone([_FilterInt(x) for x in v])
    277   return None if isinstance(
    278       v, (compat.integral_types, tensor_shape.Dimension)) else _NotNone(v)
    279 
    280 
    281 def _FilterFloat(v):
    282   if isinstance(v, (list, tuple)):
    283     return _FirstNotNone([_FilterFloat(x) for x in v])
    284   return None if isinstance(v, compat.real_types) else _NotNone(v)
    285 
    286 
    287 def _FilterComplex(v):
    288   if isinstance(v, (list, tuple)):
    289     return _FirstNotNone([_FilterComplex(x) for x in v])
    290   return None if isinstance(v, compat.complex_types) else _NotNone(v)
    291 
    292 
    293 def _FilterStr(v):
    294   if isinstance(v, (list, tuple)):
    295     return _FirstNotNone([_FilterStr(x) for x in v])
    296   if isinstance(v, compat.bytes_or_text_types):
    297     return None
    298   else:
    299     return _NotNone(v)
    300 
    301 
    302 def _FilterBool(v):
    303   if isinstance(v, (list, tuple)):
    304     return _FirstNotNone([_FilterBool(x) for x in v])
    305   return None if isinstance(v, bool) else _NotNone(v)
    306 
    307 
    308 def _FilterNotTensor(v):
    309   if isinstance(v, (list, tuple)):
    310     return _FirstNotNone([_FilterNotTensor(x) for x in v])
    311   return str(v) if isinstance(v, ops.Tensor) else None
    312 
    313 
    314 _TF_TO_IS_OK = {
    315     dtypes.bool: [_FilterBool],
    316     dtypes.complex128: [_FilterComplex],
    317     dtypes.complex64: [_FilterComplex],
    318     dtypes.float16: [_FilterFloat],
    319     dtypes.float32: [_FilterFloat],
    320     dtypes.float64: [_FilterFloat],
    321     dtypes.int16: [_FilterInt],
    322     dtypes.int32: [_FilterInt],
    323     dtypes.int64: [_FilterInt],
    324     dtypes.int8: [_FilterInt],
    325     dtypes.qint16: [_FilterInt, _FilterTuple],
    326     dtypes.qint32: [_FilterInt, _FilterTuple],
    327     dtypes.qint8: [_FilterInt, _FilterTuple],
    328     dtypes.quint16: [_FilterInt, _FilterTuple],
    329     dtypes.quint8: [_FilterInt, _FilterTuple],
    330     dtypes.string: [_FilterStr],
    331     dtypes.uint16: [_FilterInt],
    332     dtypes.uint8: [_FilterInt],
    333 }
    334 
    335 
    336 def _AssertCompatible(values, dtype):
    337   fn_list = _TF_TO_IS_OK.get(dtype, [_FilterNotTensor])
    338   mismatch = _FirstNotNone([fn(values) for fn in fn_list])
    339   if mismatch is not None:
    340     if dtype is None:
    341       raise TypeError("List of Tensors when single Tensor expected")
    342     else:
    343       raise TypeError("Expected %s, got %s of type '%s' instead." %
    344                       (dtype.name, repr(mismatch), type(mismatch).__name__))
    345 
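# Illustrative behavior of the dtype checks above (not in the original source):
#   _AssertCompatible([1, 2, 3], dtypes.int32)    # passes: all ints
#   _AssertCompatible([1.0, 2.0], dtypes.int32)   # raises TypeError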
    346 
    347 @tf_export("make_tensor_proto")
    348 def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False):
    349   """Create a TensorProto.
    350 
    351   Args:
    352     values:         Values to put in the TensorProto.
    353     dtype:          Optional tensor_pb2 DataType value.
     354     shape:          List of integers representing the dimensions of the tensor.
     355     verify_shape:   Boolean that enables verification of the shape of values.
    356 
    357   Returns:
    358     A `TensorProto`. Depending on the type, it may contain data in the
    359     "tensor_content" attribute, which is not directly useful to Python programs.
    360     To access the values you should convert the proto back to a numpy ndarray
    361     with `tensor_util.MakeNdarray(proto)`.
    362 
    363     If `values` is a `TensorProto`, it is immediately returned; `dtype` and
    364     `shape` are ignored.
    365 
    366   Raises:
    367     TypeError:  if unsupported types are provided.
     368     ValueError: if arguments have inappropriate values or if verify_shape is
     369      True and the shape of values does not match the provided shape.
    370 
     371   make_tensor_proto accepts "values" as a python scalar, a python list, a
     372   numpy ndarray, or a numpy scalar.
    373 
    374   If "values" is a python scalar or a python list, make_tensor_proto
    375   first convert it to numpy ndarray. If dtype is None, the
    376   conversion tries its best to infer the right numpy data
    377   type. Otherwise, the resulting numpy array has a compatible data
    378   type with the given dtype.
    379 
     380   In either case above, the numpy ndarray (whether caller-provided or
     381   auto-converted) must have a type compatible with dtype.
    382 
    383   make_tensor_proto then converts the numpy array to a tensor proto.
    384 
    385   If "shape" is None, the resulting tensor proto represents the numpy
    386   array precisely.
    387 
    388   Otherwise, "shape" specifies the tensor's shape and the numpy array
    389   can not have more elements than what "shape" specifies.
    390 
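  For example (an illustrative sketch, not part of the original docstring; it
  only uses `make_tensor_proto` and `MakeNdarray` from this module):

  ```python
  proto = make_tensor_proto(np.array([[1, 2], [3, 4]], dtype=np.int32))
  # proto.dtype is DT_INT32, proto.tensor_shape describes a [2, 2] tensor,
  # and the data lives in proto.tensor_content.
  round_tripped = MakeNdarray(proto)  # back to the original numpy ndarray
  ```
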
    391   """
    392   if isinstance(values, tensor_pb2.TensorProto):
    393     return values
    394 
    395   if dtype:
    396     dtype = dtypes.as_dtype(dtype)
    397 
    398   is_quantized = (
    399       dtype in [
    400           dtypes.qint8, dtypes.quint8, dtypes.qint16, dtypes.quint16,
    401           dtypes.qint32
    402       ])
    403 
    404   # We first convert value to a numpy array or scalar.
    405   if isinstance(values, (np.ndarray, np.generic)):
    406     if dtype:
    407       nparray = values.astype(dtype.as_numpy_dtype)
    408     else:
    409       nparray = values
    410   elif callable(getattr(values, "__array__", None)) or isinstance(
    411       getattr(values, "__array_interface__", None), dict):
    412     # If a class has the __array__ method, or __array_interface__ dict, then it
    413     # is possible to convert to numpy array.
    414     nparray = np.asarray(values, dtype=dtype)
    415 
    416     # This is the preferred way to create an array from the object, so replace
    417     # the `values` with the array so that _FlattenToStrings is not run.
    418     values = nparray
    419   else:
    420     if values is None:
    421       raise ValueError("None values not supported.")
    422     # if dtype is provided, forces numpy array to be the type
    423     # provided if possible.
    424     if dtype and dtype.is_numpy_compatible:
    425       np_dt = dtype.as_numpy_dtype
    426     else:
    427       np_dt = None
     428     # If shape is None, numpy.prod returns None when dtype is not set, but
     429     # raises an exception when dtype is set to np.int64.
    430     if shape is not None and np.prod(shape, dtype=np.int64) == 0:
    431       nparray = np.empty(shape, dtype=np_dt)
    432     else:
    433       _AssertCompatible(values, dtype)
    434       nparray = np.array(values, dtype=np_dt)
     435       # We need to pass in quantized values as tuples, so don't apply the
     436       # shape check to them.
    437       if (list(nparray.shape) != _GetDenseDimensions(values) and
    438           not is_quantized):
    439         raise ValueError("""Argument must be a dense tensor: %s"""
    440                          """ - got shape %s, but wanted %s.""" %
    441                          (values, list(nparray.shape),
    442                           _GetDenseDimensions(values)))
    443 
    444     # python/numpy default float type is float64. We prefer float32 instead.
    445     if (nparray.dtype == np.float64) and dtype is None:
    446       nparray = nparray.astype(np.float32)
    447     # python/numpy default int type is int64. We prefer int32 instead.
    448     elif (nparray.dtype == np.int64) and dtype is None:
    449       downcasted_array = nparray.astype(np.int32)
    450       # Do not down cast if it leads to precision loss.
    451       if np.array_equal(downcasted_array, nparray):
    452         nparray = downcasted_array
    453 
    454   # if dtype is provided, it must be compatible with what numpy
    455   # conversion says.
    456   numpy_dtype = dtypes.as_dtype(nparray.dtype)
    457   if numpy_dtype is None:
    458     raise TypeError("Unrecognized data type: %s" % nparray.dtype)
    459 
    460   # If dtype was specified and is a quantized type, we convert
    461   # numpy_dtype back into the quantized version.
    462   if is_quantized:
    463     numpy_dtype = dtype
    464 
    465   if dtype is not None and (not hasattr(dtype, "base_dtype") or
    466                             dtype.base_dtype != numpy_dtype.base_dtype):
    467     raise TypeError("Incompatible types: %s vs. %s. Value is %s" %
    468                     (dtype, nparray.dtype, values))
    469 
    470   # If shape is not given, get the shape from the numpy array.
    471   if shape is None:
    472     shape = nparray.shape
    473     is_same_size = True
    474     shape_size = nparray.size
    475   else:
    476     shape = [int(dim) for dim in shape]
    477     shape_size = np.prod(shape, dtype=np.int64)
    478     is_same_size = shape_size == nparray.size
    479 
    480     if verify_shape:
    481       if not nparray.shape == tuple(shape):
    482         raise TypeError("Expected Tensor's shape: %s, got %s." %
    483                         (tuple(shape), nparray.shape))
    484 
    485     if nparray.size > shape_size:
    486       raise ValueError(
    487           "Too many elements provided. Needed at most %d, but received %d" %
    488           (shape_size, nparray.size))
    489 
    490   tensor_proto = tensor_pb2.TensorProto(
    491       dtype=numpy_dtype.as_datatype_enum,
    492       tensor_shape=tensor_shape.as_shape(shape).as_proto())
    493 
    494   if is_same_size and numpy_dtype in _TENSOR_CONTENT_TYPES and shape_size > 1:
    495     if nparray.size * nparray.itemsize >= (1 << 31):
    496       raise ValueError(
    497           "Cannot create a tensor proto whose content is larger than 2GB.")
    498     tensor_proto.tensor_content = nparray.tostring()
    499     return tensor_proto
    500 
    501   # If we were not given values as a numpy array, compute the proto_values
    502   # from the given values directly, to avoid numpy trimming nulls from the
    503   # strings. Since values could be a list of strings, or a multi-dimensional
    504   # list of lists that might or might not correspond to the given shape,
    505   # we flatten it conservatively.
    506   if numpy_dtype == dtypes.string and not isinstance(values, np.ndarray):
    507     proto_values = _FlattenToStrings(values)
    508 
     509     # At this point, values may be a list of objects for which we could
     510     # not identify a common type (hence it was inferred as
     511     # np.object/dtypes.string).  If we are unable to convert it to a
    512     # string, we raise a more helpful error message.
    513     #
    514     # Ideally, we'd be able to convert the elements of the list to a
    515     # common type, but this type inference requires some thinking and
    516     # so we defer it for now.
    517     try:
    518       str_values = [compat.as_bytes(x) for x in proto_values]
    519     except TypeError:
    520       raise TypeError("Failed to convert object of type %s to Tensor. "
    521                       "Contents: %s. Consider casting elements to a "
    522                       "supported type." % (type(values), values))
    523     tensor_proto.string_val.extend(str_values)
    524     return tensor_proto
    525 
    526   # TensorFlow expects C order (a.k.a., eigen row major).
    527   proto_values = nparray.ravel()
    528 
    529   append_fn = GetNumpyAppendFn(proto_values.dtype)
    530   if append_fn is None:
    531     raise TypeError(
    532         "Element type not supported in TensorProto: %s" % numpy_dtype.name)
    533   append_fn(tensor_proto, proto_values)
    534 
    535   return tensor_proto
    536 
    537 
    538 @tf_export("make_ndarray")
    539 def MakeNdarray(tensor):
    540   """Create a numpy ndarray from a tensor.
    541 
    542   Create a numpy ndarray with the same shape and data as the tensor.
    543 
    544   Args:
    545     tensor: A TensorProto.
    546 
    547   Returns:
    548     A numpy array with the tensor contents.
    549 
    550   Raises:
    551     TypeError: if tensor has unsupported type.
    552 
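  For example (an illustrative sketch, not part of the original docstring):

  ```python
  proto = make_tensor_proto([1.0, 2.0, 3.0])  # float32 TensorProto
  MakeNdarray(proto)  # -> np.array([1., 2., 3.], dtype=np.float32)
  ```
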
    553   """
    554   shape = [d.size for d in tensor.tensor_shape.dim]
    555   num_elements = np.prod(shape, dtype=np.int64)
    556   tensor_dtype = dtypes.as_dtype(tensor.dtype)
    557   dtype = tensor_dtype.as_numpy_dtype
    558 
    559   if tensor.tensor_content:
    560     return (np.frombuffer(tensor.tensor_content, dtype=dtype).copy()
    561             .reshape(shape))
    562   elif tensor_dtype == dtypes.float16:
    563     # the half_val field of the TensorProto stores the binary representation
    564     # of the fp16: we need to reinterpret this as a proper float16
    565     if len(tensor.half_val) == 1:
    566       tmp = np.array(tensor.half_val[0], dtype=np.uint16)
    567       tmp.dtype = np.float16
    568       return np.repeat(tmp, num_elements).reshape(shape)
    569     else:
    570       tmp = np.fromiter(tensor.half_val, dtype=np.uint16)
    571       tmp.dtype = np.float16
    572       return tmp.reshape(shape)
    573   elif tensor_dtype == dtypes.float32:
    574     if len(tensor.float_val) == 1:
    575       return np.repeat(
    576           np.array(tensor.float_val[0], dtype=dtype),
    577           num_elements).reshape(shape)
    578     else:
    579       return np.fromiter(tensor.float_val, dtype=dtype).reshape(shape)
    580   elif tensor_dtype == dtypes.float64:
    581     if len(tensor.double_val) == 1:
    582       return np.repeat(
    583           np.array(tensor.double_val[0], dtype=dtype),
    584           num_elements).reshape(shape)
    585     else:
    586       return np.fromiter(tensor.double_val, dtype=dtype).reshape(shape)
    587   elif tensor_dtype in [
    588       dtypes.int32, dtypes.uint8, dtypes.uint16, dtypes.int16, dtypes.int8,
    589       dtypes.qint32, dtypes.quint8, dtypes.qint8, dtypes.qint16, dtypes.quint16,
    590       dtypes.bfloat16
    591   ]:
    592     if len(tensor.int_val) == 1:
    593       return np.repeat(np.array(tensor.int_val[0], dtype=dtype),
    594                        num_elements).reshape(shape)
    595     else:
    596       return np.fromiter(tensor.int_val, dtype=dtype).reshape(shape)
    597   elif tensor_dtype == dtypes.int64:
    598     if len(tensor.int64_val) == 1:
    599       return np.repeat(
    600           np.array(tensor.int64_val[0], dtype=dtype),
    601           num_elements).reshape(shape)
    602     else:
    603       return np.fromiter(tensor.int64_val, dtype=dtype).reshape(shape)
    604   elif tensor_dtype == dtypes.string:
    605     if len(tensor.string_val) == 1:
    606       return np.repeat(
    607           np.array(tensor.string_val[0], dtype=dtype),
    608           num_elements).reshape(shape)
    609     else:
    610       return np.array(
    611           [x for x in tensor.string_val], dtype=dtype).reshape(shape)
    612   elif tensor_dtype == dtypes.complex64:
    613     it = iter(tensor.scomplex_val)
    614     if len(tensor.scomplex_val) == 2:
    615       return np.repeat(
    616           np.array(
    617               complex(tensor.scomplex_val[0], tensor.scomplex_val[1]),
    618               dtype=dtype), num_elements).reshape(shape)
    619     else:
    620       return np.array(
    621           [complex(x[0], x[1]) for x in zip(it, it)],
    622           dtype=dtype).reshape(shape)
    623   elif tensor_dtype == dtypes.complex128:
    624     it = iter(tensor.dcomplex_val)
    625     if len(tensor.dcomplex_val) == 2:
    626       return np.repeat(
    627           np.array(
    628               complex(tensor.dcomplex_val[0], tensor.dcomplex_val[1]),
    629               dtype=dtype), num_elements).reshape(shape)
    630     else:
    631       return np.array(
    632           [complex(x[0], x[1]) for x in zip(it, it)],
    633           dtype=dtype).reshape(shape)
    634   elif tensor_dtype == dtypes.bool:
    635     if len(tensor.bool_val) == 1:
    636       return np.repeat(np.array(tensor.bool_val[0], dtype=dtype),
    637                        num_elements).reshape(shape)
    638     else:
    639       return np.fromiter(tensor.bool_val, dtype=dtype).reshape(shape)
    640   else:
    641     raise TypeError("Unsupported tensor type: %s" % tensor.dtype)
    642 
    643 
    644 def ShapeEquals(tensor_proto, shape):
    645   """Returns True if "tensor_proto" has the given "shape".
    646 
    647   Args:
    648     tensor_proto: A TensorProto.
    649     shape: A tensor shape, expressed as a TensorShape, list, or tuple.
    650 
    651   Returns:
    652     True if "tensor_proto" has the given "shape", otherwise False.
    653 
    654   Raises:
    655     TypeError: If "tensor_proto" is not a TensorProto, or shape is not a
    656       TensorShape, list, or tuple.
    657   """
    658   if not isinstance(tensor_proto, tensor_pb2.TensorProto):
    659     raise TypeError("tensor_proto is not a tensor_pb2.TensorProto object")
    660   if isinstance(shape, tensor_shape_pb2.TensorShapeProto):
    661     shape = [d.size for d in shape.dim]
    662   elif not isinstance(shape, (list, tuple)):
    663     raise TypeError("shape is not a list or tuple")
    664   tensor_shape_list = [d.size for d in tensor_proto.tensor_shape.dim]
    665   return all(x == y for x, y in zip(tensor_shape_list, shape))
    666 
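# Illustrative usage (not in the original source):
#   proto = make_tensor_proto(np.zeros((2, 3), dtype=np.float32))
#   ShapeEquals(proto, [2, 3])   # True
#   ShapeEquals(proto, (2, 4))   # False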
    667 
    668 def _ConstantValue(tensor, partial):
    669   # TODO(touts): Support Variables?
    670   if not isinstance(tensor, ops.Tensor):
    671     raise TypeError("tensor is not a Tensor")
    672   if tensor.op.type == "Const":
    673     return MakeNdarray(tensor.op.get_attr("value"))
    674   elif tensor.op.type == "Shape":
    675     input_shape = tensor.op.inputs[0].get_shape()
    676     if input_shape.is_fully_defined():
    677       return np.array(
    678           [dim.value for dim in input_shape.dims],
    679           dtype=tensor.dtype.as_numpy_dtype)
    680     else:
    681       return None
    682   elif tensor.op.type == "Size":
    683     input_shape = tensor.op.inputs[0].get_shape()
    684     if input_shape.is_fully_defined():
    685       return np.prod([dim.value for dim in input_shape.dims], dtype=np.int32)
    686     else:
    687       return None
    688   elif tensor.op.type == "Rank":
    689     input_shape = tensor.op.inputs[0].get_shape()
    690     if input_shape.ndims is not None:
    691       return np.ndarray(
    692           shape=(),
    693           buffer=np.array([input_shape.ndims], dtype=np.int32),
    694           dtype=np.int32)
    695     else:
    696       return None
    697   elif tensor.op.type == "Range":
    698     start = constant_value(tensor.op.inputs[0])
    699     if start is None:
    700       return None
    701     limit = constant_value(tensor.op.inputs[1])
    702     if limit is None:
    703       return None
    704     delta = constant_value(tensor.op.inputs[2])
    705     if delta is None:
    706       return None
    707     return np.arange(start, limit, delta, dtype=tensor.dtype.as_numpy_dtype)
    708   elif tensor.op.type == "Cast":
    709     pre_cast = constant_value(tensor.op.inputs[0])
    710     if pre_cast is None:
    711       return None
    712     cast_dtype = dtypes.as_dtype(tensor.op.get_attr("DstT"))
    713     return pre_cast.astype(cast_dtype.as_numpy_dtype)
    714   elif tensor.op.type == "Concat":
    715     dim = constant_value(tensor.op.inputs[0])
    716     if dim is None:
    717       return None
    718     values = []
    719     for x in tensor.op.inputs[1:]:
    720       value = constant_value(x)
    721       if value is None:
    722         return None
    723       values.append(value)
    724     return np.concatenate(values, axis=dim)
    725   elif tensor.op.type == "ConcatV2":
    726     dim = constant_value(tensor.op.inputs[-1])
    727     if dim is None:
    728       return None
    729     values = []
    730     for x in tensor.op.inputs[:-1]:
    731       value = constant_value(x)
    732       if value is None:
    733         return None
    734       values.append(value)
    735     return np.concatenate(values, axis=dim)
    736   elif tensor.op.type == "Pack":
    737     values = []
    738     # Some imported GraphDefs have Pack ops with zero inputs. Those are invalid
    739     # and shouldn't be produced, but to deal sensibly with them here we check
    740     # and return None.
    741     if not tensor.op.inputs:
    742       return None
    743     # We can't handle axis != 0 Packs at the moment.
    744     if tensor.op.get_attr("axis") != 0:
    745       return None
    746     for x in tensor.op.inputs:
    747       value = constant_value(x, partial)
    748       if value is None and not partial:
    749         return None
    750       values.append(value)
    751     return np.array(values)
    752   elif tensor.op.type == "Fill":
    753     fill_shape = tensor.shape
    754     fill_value = constant_value(tensor.op.inputs[1])
    755     if fill_shape.is_fully_defined() and fill_value is not None:
    756       return np.full(fill_shape.as_list(), fill_value, dtype=fill_value.dtype)
    757     else:
    758       return None
    759   elif tensor.op.type == "Equal":
    760     value1 = constant_value(tensor.op.inputs[0])
    761     if value1 is None:
    762       return None
    763     value2 = constant_value(tensor.op.inputs[1])
    764     if value2 is None:
    765       return None
    766     return np.equal(value1, value2)
    767   elif tensor.op.type == "NotEqual":
    768     value1 = constant_value(tensor.op.inputs[0])
    769     if value1 is None:
    770       return None
    771     value2 = constant_value(tensor.op.inputs[1])
    772     if value2 is None:
    773       return None
    774     return np.not_equal(value1, value2)
    775   else:
    776     return None
    777 
    778 
    779 def constant_value(tensor, partial=False):  # pylint: disable=invalid-name
    780   """Returns the constant value of the given tensor, if efficiently calculable.
    781 
    782   This function attempts to partially evaluate the given tensor, and
    783   returns its value as a numpy ndarray if this succeeds.
    784 
    785   TODO(mrry): Consider whether this function should use a registration
    786   mechanism like gradients and ShapeFunctions, so that it is easily
    787   extensible.
    788 
    789   NOTE: If `constant_value(tensor)` returns a non-`None` result, it will no
    790   longer be possible to feed a different value for `tensor`. This allows the
    791   result of this function to influence the graph that is constructed, and
    792   permits static shape optimizations.
    793 
    794   Args:
    795     tensor: The Tensor to be evaluated.
    796     partial: If True, the returned numpy array is allowed to have partially
    797       evaluated values. Values that can't be evaluated will be None.
    798 
    799   Returns:
    800     A numpy ndarray containing the constant value of the given `tensor`,
    801     or None if it cannot be calculated.
    802 
    803   Raises:
    804     TypeError: if tensor is not an ops.Tensor.
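
  Example (an illustrative sketch, not part of the original docstring; assumes
  graph mode and `tf.constant`/`tf.shape`, which are not imported here):

  ```python
  c = tf.constant([1, 2, 3])
  constant_value(c)            # -> np.array([1, 2, 3], dtype=np.int32)
  constant_value(tf.shape(c))  # -> np.array([3], dtype=np.int32)
  ```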
    805   """
    806   if isinstance(tensor, ops.EagerTensor):
    807     return tensor.numpy()
    808   ret = _ConstantValue(tensor, partial)
    809   if ret is not None:
    810     # The caller may now depend on the constant value of `tensor`, so we
    811     # conservatively prevent it from being fed.
    812     tensor.graph.prevent_feeding(tensor)
    813   return ret
    814 
    815 
    816 def constant_value_as_shape(tensor):  # pylint: disable=invalid-name
    817   """A version of `constant_value()` that returns a `TensorShape`.
    818 
    819   This version should be used when a constant tensor value is
    820   interpreted as a (possibly partial) shape, e.g. in the shape
    821   function for `tf.reshape()`. By explicitly requesting a
    822   `TensorShape` as the return value, it is possible to represent
    823   unknown dimensions; by contrast, `constant_value()` is
    824   all-or-nothing.
    825 
    826   Args:
    827     tensor: The rank-1 Tensor to be evaluated.
    828 
    829   Returns:
    830     A `TensorShape` based on the constant value of the given `tensor`.
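
  Example (an illustrative sketch, not part of the original docstring; assumes
  `tf.stack` and `tf.placeholder`, which are not imported here):

  ```python
  shape_t = tf.stack([2, tf.placeholder(tf.int32, shape=[])])
  constant_value_as_shape(shape_t)  # -> TensorShape([2, None])
  ```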
    831   """
    832   if context.in_eager_mode():
    833     return tensor_shape.as_shape(
    834         [dim if dim != -1 else None for dim in tensor.numpy()])
    835 
    836   shape = tensor.get_shape().with_rank(1)
    837   if tensor.get_shape() == [0]:
    838     return tensor_shape.scalar()
    839   elif tensor.op.type == "Shape":
    840     return tensor.op.inputs[0].get_shape()
    841   elif tensor.op.type == "Pack":
    842     ret = tensor_shape.scalar()  # Empty list.
    843     # Since we expect rank 1 inputs, Pack's axis must be zero, otherwise it
    844     # would not be rank 1.
    845     assert tensor.op.get_attr("axis") == 0
    846     for pack_input in tensor.op.inputs:
    847       # `pack_input` must be a scalar. Attempt to evaluate it, and append it
    848       # to `ret`.
    849       pack_input_val = constant_value(pack_input)
    850       if pack_input_val is None or pack_input_val < 0:
    851         new_dim = tensor_shape.Dimension(None)
    852       else:
    853         new_dim = tensor_shape.Dimension(pack_input_val)
    854       ret = ret.concatenate([new_dim])
    855     return ret
    856   elif tensor.op.type == "Concat":
    857     # We assume that `tensor.op.inputs[0]` evaluates to 0, as this is
    858     # the only legal value when concatenating vectors, and it will
    859     # have been checked by a previous shape function.
    860     ret = tensor_shape.scalar()  # Empty list.
    861     for concat_input in tensor.op.inputs[1:]:
    862       # `concat_input` must be a vector. Attempt to evaluate it as a shape,
    863       # and concatenate it with `ret`.
    864       ret = ret.concatenate(constant_value_as_shape(concat_input))
    865     return ret
    866   elif tensor.op.type == "ConcatV2":
    867     # We assume that `tensor.op.inputs[-1]` evaluates to 0, as this is
    868     # the only legal value when concatenating vectors, and it will
    869     # have been checked by a previous shape function.
    870     ret = tensor_shape.scalar()  # Empty list.
    871     for concat_input in tensor.op.inputs[:-1]:
    872       # `concat_input` must be a vector. Attempt to evaluate it as a shape,
    873       # and concatenate it with `ret`.
    874       ret = ret.concatenate(constant_value_as_shape(concat_input))
    875     return ret
    876   elif tensor.op.type == "StridedSlice":
    877     try:
    878       begin = constant_value(tensor.op.inputs[1])
    879       end = constant_value(tensor.op.inputs[2])
    880       strides = constant_value(tensor.op.inputs[3])
    881       if begin is not None and end is not None and strides is not None:
    882         begin = begin[0]
    883         end = end[0]
    884         strides = strides[0]
    885         begin_mask = tensor.op.get_attr("begin_mask")
    886         if begin_mask == 1:
    887           begin = None
    888         end_mask = tensor.op.get_attr("end_mask")
    889         if end_mask == 1:
    890           end = None
    891 
    892         ellipsis_mask = tensor.op.get_attr("ellipsis_mask")
    893         new_axis_mask = tensor.op.get_attr("new_axis_mask")
    894         shrink_axis_mask = tensor.op.get_attr("shrink_axis_mask")
    895         valid_attributes = (not ellipsis_mask and not new_axis_mask and
    896                             not shrink_axis_mask and (not begin_mask or
    897                                                       (begin_mask == 1)) and
    898                             (not end_mask or (end_mask == 1)))
    899         if valid_attributes:  # additional inputs not supported
    900           prev = constant_value_as_shape(tensor.op.inputs[0])
    901           prev = prev[begin:end:strides]
    902           ret = tensor_shape.TensorShape(prev)
    903           return ret
    904 
    905     except ValueError:  # Could come from get_attr or slicing prev.
    906       pass
    907     except TypeError:  # Could come from slicing prev.
    908       pass
    909 
    910   ret = tensor_shape.unknown_shape(shape[0].value)
    911   value = constant_value(tensor)
    912   if value is not None:
    913     ret = ret.merge_with(
    914         tensor_shape.TensorShape([d if d >= 0 else None for d in value]))
    915   return ret
    916 
    917 
    918 def is_tensor(x):  # pylint: disable=invalid-name
    919   """Check whether `x` is of tensor type.
    920 
     921   Check whether an object is a tensor. Equivalent to
     922   `isinstance(x, (tf.Tensor, tf.SparseTensor, tf.Variable))`.
    923 
    924   Args:
    925     x: A python object to check.
    926 
    927   Returns:
    928     `True` if `x` is a tensor, `False` if not.
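
  For example (an illustrative sketch, not part of the original docstring;
  assumes `tf.constant`, which is not imported here):

  ```python
  is_tensor(tf.constant([1.0]))  # True
  is_tensor(np.array([1.0]))     # False
  ```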
    929   """
    930   return isinstance(x, ops._TensorLike) or ops.is_dense_tensor_like(x)  # pylint: disable=protected-access
    931