Home | History | Annotate | Download | only in framework
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Library of dtypes (Tensor element types)."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import numpy as np
     21 
     22 from tensorflow.core.framework import types_pb2
     23 from tensorflow.python import pywrap_tensorflow
     24 from tensorflow.python.util.tf_export import tf_export
     25 
# Numpy-side type used for tf.bfloat16: it appears as the bfloat16 entry in
# `_NP_TO_TF` / `_TF_TO_NP`, and `DType.min` / `DType.max` call it to build
# bfloat16 limit values. It is supplied by the pywrap_tensorflow extension
# rather than by numpy itself.
_np_bfloat16 = pywrap_tensorflow.TF_bfloat16_type()
     27 
     28 
     29 @tf_export("DType")
     30 class DType(object):
     31   """Represents the type of the elements in a `Tensor`.
     32 
     33   The following `DType` objects are defined:
     34 
     35   * `tf.float16`: 16-bit half-precision floating-point.
     36   * `tf.float32`: 32-bit single-precision floating-point.
     37   * `tf.float64`: 64-bit double-precision floating-point.
     38   * `tf.bfloat16`: 16-bit truncated floating-point.
     39   * `tf.complex64`: 64-bit single-precision complex.
     40   * `tf.complex128`: 128-bit double-precision complex.
     41   * `tf.int8`: 8-bit signed integer.
     42   * `tf.uint8`: 8-bit unsigned integer.
     43   * `tf.uint16`: 16-bit unsigned integer.
     44   * `tf.uint32`: 32-bit unsigned integer.
     45   * `tf.uint64`: 64-bit unsigned integer.
     46   * `tf.int16`: 16-bit signed integer.
     47   * `tf.int32`: 32-bit signed integer.
     48   * `tf.int64`: 64-bit signed integer.
     49   * `tf.bool`: Boolean.
     50   * `tf.string`: String.
     51   * `tf.qint8`: Quantized 8-bit signed integer.
     52   * `tf.quint8`: Quantized 8-bit unsigned integer.
     53   * `tf.qint16`: Quantized 16-bit signed integer.
     54   * `tf.quint16`: Quantized 16-bit unsigned integer.
     55   * `tf.qint32`: Quantized 32-bit signed integer.
     56   * `tf.resource`: Handle to a mutable resource.
     57   * `tf.variant`: Values of arbitrary types.
     58 
     59   In addition, variants of these types with the `_ref` suffix are
     60   defined for reference-typed tensors.
     61 
     62   The `tf.as_dtype()` function converts numpy types and string type
     63   names to a `DType` object.
     64   """
     65 
     66   def __init__(self, type_enum):
     67     """Creates a new `DataType`.
     68 
     69     NOTE(mrry): In normal circumstances, you should not need to
     70     construct a `DataType` object directly. Instead, use the
     71     `tf.as_dtype()` function.
     72 
     73     Args:
     74       type_enum: A `types_pb2.DataType` enum value.
     75 
     76     Raises:
     77       TypeError: If `type_enum` is not a value `types_pb2.DataType`.
     78 
     79     """
     80     # TODO(mrry): Make the necessary changes (using __new__) to ensure
     81     # that calling this returns one of the interned values.
     82     type_enum = int(type_enum)
     83     if (type_enum not in types_pb2.DataType.values() or
     84         type_enum == types_pb2.DT_INVALID):
     85       raise TypeError(
     86           "type_enum is not a valid types_pb2.DataType: %s" % type_enum)
     87     self._type_enum = type_enum
     88 
     89   @property
     90   def _is_ref_dtype(self):
     91     """Returns `True` if this `DType` represents a reference type."""
     92     return self._type_enum > 100
     93 
     94   @property
     95   def _as_ref(self):
     96     """Returns a reference `DType` based on this `DType`."""
     97     if self._is_ref_dtype:
     98       return self
     99     else:
    100       return _INTERN_TABLE[self._type_enum + 100]
    101 
    102   @property
    103   def base_dtype(self):
    104     """Returns a non-reference `DType` based on this `DType`."""
    105     if self._is_ref_dtype:
    106       return _INTERN_TABLE[self._type_enum - 100]
    107     else:
    108       return self
    109 
    110   @property
    111   def real_dtype(self):
    112     """Returns the dtype correspond to this dtype's real part."""
    113     base = self.base_dtype
    114     if base == complex64:
    115       return float32
    116     elif base == complex128:
    117       return float64
    118     else:
    119       return self
    120 
    121   @property
    122   def is_numpy_compatible(self):
    123     numpy_incompatible = [
    124         types_pb2.DT_VARIANT, types_pb2.DT_VARIANT_REF, types_pb2.DT_RESOURCE,
    125         types_pb2.DT_RESOURCE_REF
    126     ]
    127     return self._type_enum not in numpy_incompatible
    128 
    129   @property
    130   def as_numpy_dtype(self):
    131     """Returns a `numpy.dtype` based on this `DType`."""
    132     return _TF_TO_NP[self._type_enum]
    133 
    134   @property
    135   def as_datatype_enum(self):
    136     """Returns a `types_pb2.DataType` enum value based on this `DType`."""
    137     return self._type_enum
    138 
    139   @property
    140   def is_bool(self):
    141     """Returns whether this is a boolean data type"""
    142     return self.base_dtype == bool
    143 
    144   @property
    145   def is_integer(self):
    146     """Returns whether this is a (non-quantized) integer type."""
    147     return (self.is_numpy_compatible and not self.is_quantized and
    148             np.issubdtype(self.as_numpy_dtype, np.integer))
    149 
    150   @property
    151   def is_floating(self):
    152     """Returns whether this is a (non-quantized, real) floating point type."""
    153     return ((self.is_numpy_compatible and
    154              np.issubdtype(self.as_numpy_dtype, np.floating)) or
    155             self.base_dtype == bfloat16)
    156 
    157   @property
    158   def is_complex(self):
    159     """Returns whether this is a complex floating point type."""
    160     return self.base_dtype in (complex64, complex128)
    161 
    162   @property
    163   def is_quantized(self):
    164     """Returns whether this is a quantized data type."""
    165     return self.base_dtype in [qint8, quint8, qint16, quint16, qint32]
    166 
    167   @property
    168   def is_unsigned(self):
    169     """Returns whether this type is unsigned.
    170 
    171     Non-numeric, unordered, and quantized types are not considered unsigned, and
    172     this function returns `False`.
    173 
    174     Returns:
    175       Whether a `DType` is unsigned.
    176     """
    177     try:
    178       return self.min == 0
    179     except TypeError:
    180       return False
    181 
    182   @property
    183   def min(self):
    184     """Returns the minimum representable value in this data type.
    185 
    186     Raises:
    187       TypeError: if this is a non-numeric, unordered, or quantized type.
    188 
    189     """
    190     if (self.is_quantized or
    191         self.base_dtype in (bool, string, complex64, complex128)):
    192       raise TypeError("Cannot find minimum value of %s." % self)
    193 
    194     # there is no simple way to get the min value of a dtype, we have to check
    195     # float and int types separately
    196     try:
    197       return np.finfo(self.as_numpy_dtype()).min
    198     except:  # bare except as possible raises by finfo not documented
    199       try:
    200         return np.iinfo(self.as_numpy_dtype()).min
    201       except:
    202         if self.base_dtype == bfloat16:
    203           return _np_bfloat16(float.fromhex("-0x1.FEp127"))
    204         raise TypeError("Cannot find minimum value of %s." % self)
    205 
    206   @property
    207   def max(self):
    208     """Returns the maximum representable value in this data type.
    209 
    210     Raises:
    211       TypeError: if this is a non-numeric, unordered, or quantized type.
    212 
    213     """
    214     if (self.is_quantized or
    215         self.base_dtype in (bool, string, complex64, complex128)):
    216       raise TypeError("Cannot find maximum value of %s." % self)
    217 
    218     # there is no simple way to get the max value of a dtype, we have to check
    219     # float and int types separately
    220     try:
    221       return np.finfo(self.as_numpy_dtype()).max
    222     except:  # bare except as possible raises by finfo not documented
    223       try:
    224         return np.iinfo(self.as_numpy_dtype()).max
    225       except:
    226         if self.base_dtype == bfloat16:
    227           return _np_bfloat16(float.fromhex("0x1.FEp127"))
    228         raise TypeError("Cannot find maximum value of %s." % self)
    229 
    230   @property
    231   def limits(self, clip_negative=True):
    232     """Return intensity limits, i.e. (min, max) tuple, of the dtype.
    233     Args:
    234       clip_negative : bool, optional
    235           If True, clip the negative range (i.e. return 0 for min intensity)
    236           even if the image dtype allows negative values.
    237     Returns
    238       min, max : tuple
    239         Lower and upper intensity limits.
    240     """
    241     min, max = dtype_range[self.as_numpy_dtype]  # pylint: disable=redefined-builtin
    242     if clip_negative:
    243       min = 0  # pylint: disable=redefined-builtin
    244     return min, max
    245 
    246   def is_compatible_with(self, other):
    247     """Returns True if the `other` DType will be converted to this DType.
    248 
    249     The conversion rules are as follows:
    250 
    251     ```python
    252     DType(T)       .is_compatible_with(DType(T))        == True
    253     DType(T)       .is_compatible_with(DType(T).as_ref) == True
    254     DType(T).as_ref.is_compatible_with(DType(T))        == False
    255     DType(T).as_ref.is_compatible_with(DType(T).as_ref) == True
    256     ```
    257 
    258     Args:
    259       other: A `DType` (or object that may be converted to a `DType`).
    260 
    261     Returns:
    262       True if a Tensor of the `other` `DType` will be implicitly converted to
    263       this `DType`.
    264     """
    265     other = as_dtype(other)
    266     return self._type_enum in (other.as_datatype_enum,
    267                                other.base_dtype.as_datatype_enum)
    268 
    269   def __eq__(self, other):
    270     """Returns True iff this DType refers to the same type as `other`."""
    271     if other is None:
    272       return False
    273     try:
    274       dtype = as_dtype(other).as_datatype_enum
    275       return self._type_enum == dtype  # pylint: disable=protected-access
    276     except TypeError:
    277       return False
    278 
    279   def __ne__(self, other):
    280     """Returns True iff self != other."""
    281     return not self.__eq__(other)
    282 
    283   @property
    284   def name(self):
    285     """Returns the string name for this `DType`."""
    286     return _TYPE_TO_STRING[self._type_enum]
    287 
    288   def __int__(self):
    289     return self._type_enum
    290 
    291   def __str__(self):
    292     return "<dtype: %r>" % self.name
    293 
    294   def __repr__(self):
    295     return "tf." + self.name
    296 
    297   def __hash__(self):
    298     return self._type_enum
    299 
    300   @property
    301   def size(self):
    302     if (self._type_enum == types_pb2.DT_VARIANT or
    303         self._type_enum == types_pb2.DT_RESOURCE):
    304       return 1
    305     return np.dtype(self.as_numpy_dtype).itemsize
    306 
    307 
    308 # Define data type range of numpy dtype
    309 dtype_range = {
    310     np.bool_: (False, True),
    311     np.bool8: (False, True),
    312     np.uint8: (0, 255),
    313     np.uint16: (0, 65535),
    314     np.int8: (-128, 127),
    315     np.int16: (-32768, 32767),
    316     np.int64: (-2**63, 2**63 - 1),
    317     np.uint64: (0, 2**64 - 1),
    318     np.int32: (-2**31, 2**31 - 1),
    319     np.uint32: (0, 2**32 - 1),
    320     np.float32: (-1, 1),
    321     np.float64: (-1, 1)
    322 }
    323 
    324 # Define standard wrappers for the types_pb2.DataType enum.
    325 resource = DType(types_pb2.DT_RESOURCE)
    326 tf_export("resource").export_constant(__name__, "resource")
    327 variant = DType(types_pb2.DT_VARIANT)
    328 tf_export("variant").export_constant(__name__, "variant")
    329 float16 = DType(types_pb2.DT_HALF)
    330 tf_export("float16").export_constant(__name__, "float16")
    331 half = float16
    332 tf_export("half").export_constant(__name__, "half")
    333 float32 = DType(types_pb2.DT_FLOAT)
    334 tf_export("float32").export_constant(__name__, "float32")
    335 float64 = DType(types_pb2.DT_DOUBLE)
    336 tf_export("float64").export_constant(__name__, "float64")
    337 double = float64
    338 tf_export("double").export_constant(__name__, "double")
    339 int32 = DType(types_pb2.DT_INT32)
    340 tf_export("int32").export_constant(__name__, "int32")
    341 uint8 = DType(types_pb2.DT_UINT8)
    342 tf_export("uint8").export_constant(__name__, "uint8")
    343 uint16 = DType(types_pb2.DT_UINT16)
    344 tf_export("uint16").export_constant(__name__, "uint16")
    345 uint32 = DType(types_pb2.DT_UINT32)
    346 uint64 = DType(types_pb2.DT_UINT64)
    347 int16 = DType(types_pb2.DT_INT16)
    348 tf_export("int16").export_constant(__name__, "int16")
    349 int8 = DType(types_pb2.DT_INT8)
    350 tf_export("int8").export_constant(__name__, "int8")
    351 string = DType(types_pb2.DT_STRING)
    352 tf_export("string").export_constant(__name__, "string")
    353 complex64 = DType(types_pb2.DT_COMPLEX64)
    354 tf_export("complex64").export_constant(__name__, "complex64")
    355 complex128 = DType(types_pb2.DT_COMPLEX128)
    356 tf_export("complex128").export_constant(__name__, "complex128")
    357 int64 = DType(types_pb2.DT_INT64)
    358 tf_export("int64").export_constant(__name__, "int64")
    359 bool = DType(types_pb2.DT_BOOL)  # pylint: disable=redefined-builtin
    360 tf_export("bool").export_constant(__name__, "bool")
    361 qint8 = DType(types_pb2.DT_QINT8)
    362 tf_export("qint8").export_constant(__name__, "qint8")
    363 quint8 = DType(types_pb2.DT_QUINT8)
    364 tf_export("quint8").export_constant(__name__, "quint8")
    365 qint16 = DType(types_pb2.DT_QINT16)
    366 tf_export("qint16").export_constant(__name__, "qint16")
    367 quint16 = DType(types_pb2.DT_QUINT16)
    368 tf_export("quint16").export_constant(__name__, "quint16")
    369 qint32 = DType(types_pb2.DT_QINT32)
    370 tf_export("qint32").export_constant(__name__, "qint32")
    371 resource_ref = DType(types_pb2.DT_RESOURCE_REF)
    372 variant_ref = DType(types_pb2.DT_VARIANT_REF)
    373 bfloat16 = DType(types_pb2.DT_BFLOAT16)
    374 tf_export("bfloat16").export_constant(__name__, "bfloat16")
    375 float16_ref = DType(types_pb2.DT_HALF_REF)
    376 half_ref = float16_ref
    377 float32_ref = DType(types_pb2.DT_FLOAT_REF)
    378 float64_ref = DType(types_pb2.DT_DOUBLE_REF)
    379 double_ref = float64_ref
    380 int32_ref = DType(types_pb2.DT_INT32_REF)
    381 uint32_ref = DType(types_pb2.DT_UINT32_REF)
    382 uint8_ref = DType(types_pb2.DT_UINT8_REF)
    383 uint16_ref = DType(types_pb2.DT_UINT16_REF)
    384 int16_ref = DType(types_pb2.DT_INT16_REF)
    385 int8_ref = DType(types_pb2.DT_INT8_REF)
    386 string_ref = DType(types_pb2.DT_STRING_REF)
    387 complex64_ref = DType(types_pb2.DT_COMPLEX64_REF)
    388 complex128_ref = DType(types_pb2.DT_COMPLEX128_REF)
    389 int64_ref = DType(types_pb2.DT_INT64_REF)
    390 uint64_ref = DType(types_pb2.DT_UINT64_REF)
    391 bool_ref = DType(types_pb2.DT_BOOL_REF)
    392 qint8_ref = DType(types_pb2.DT_QINT8_REF)
    393 quint8_ref = DType(types_pb2.DT_QUINT8_REF)
    394 qint16_ref = DType(types_pb2.DT_QINT16_REF)
    395 quint16_ref = DType(types_pb2.DT_QUINT16_REF)
    396 qint32_ref = DType(types_pb2.DT_QINT32_REF)
    397 bfloat16_ref = DType(types_pb2.DT_BFLOAT16_REF)
    398 
    399 # Maintain an intern table so that we don't have to create a large
    400 # number of small objects.
    401 _INTERN_TABLE = {
    402     types_pb2.DT_HALF: float16,
    403     types_pb2.DT_FLOAT: float32,
    404     types_pb2.DT_DOUBLE: float64,
    405     types_pb2.DT_INT32: int32,
    406     types_pb2.DT_UINT8: uint8,
    407     types_pb2.DT_UINT16: uint16,
    408     types_pb2.DT_UINT32: uint32,
    409     types_pb2.DT_UINT64: uint64,
    410     types_pb2.DT_INT16: int16,
    411     types_pb2.DT_INT8: int8,
    412     types_pb2.DT_STRING: string,
    413     types_pb2.DT_COMPLEX64: complex64,
    414     types_pb2.DT_COMPLEX128: complex128,
    415     types_pb2.DT_INT64: int64,
    416     types_pb2.DT_BOOL: bool,
    417     types_pb2.DT_QINT8: qint8,
    418     types_pb2.DT_QUINT8: quint8,
    419     types_pb2.DT_QINT16: qint16,
    420     types_pb2.DT_QUINT16: quint16,
    421     types_pb2.DT_QINT32: qint32,
    422     types_pb2.DT_BFLOAT16: bfloat16,
    423     types_pb2.DT_RESOURCE: resource,
    424     types_pb2.DT_VARIANT: variant,
    425     types_pb2.DT_HALF_REF: float16_ref,
    426     types_pb2.DT_FLOAT_REF: float32_ref,
    427     types_pb2.DT_DOUBLE_REF: float64_ref,
    428     types_pb2.DT_INT32_REF: int32_ref,
    429     types_pb2.DT_UINT32_REF: uint32_ref,
    430     types_pb2.DT_UINT8_REF: uint8_ref,
    431     types_pb2.DT_UINT16_REF: uint16_ref,
    432     types_pb2.DT_INT16_REF: int16_ref,
    433     types_pb2.DT_INT8_REF: int8_ref,
    434     types_pb2.DT_STRING_REF: string_ref,
    435     types_pb2.DT_COMPLEX64_REF: complex64_ref,
    436     types_pb2.DT_COMPLEX128_REF: complex128_ref,
    437     types_pb2.DT_INT64_REF: int64_ref,
    438     types_pb2.DT_UINT64_REF: uint64_ref,
    439     types_pb2.DT_BOOL_REF: bool_ref,
    440     types_pb2.DT_QINT8_REF: qint8_ref,
    441     types_pb2.DT_QUINT8_REF: quint8_ref,
    442     types_pb2.DT_QINT16_REF: qint16_ref,
    443     types_pb2.DT_QUINT16_REF: quint16_ref,
    444     types_pb2.DT_QINT32_REF: qint32_ref,
    445     types_pb2.DT_BFLOAT16_REF: bfloat16_ref,
    446     types_pb2.DT_RESOURCE_REF: resource_ref,
    447     types_pb2.DT_VARIANT_REF: variant_ref,
    448 }
    449 
    450 # Standard mappings between types_pb2.DataType values and string names.
    451 _TYPE_TO_STRING = {
    452     types_pb2.DT_HALF: "float16",
    453     types_pb2.DT_FLOAT: "float32",
    454     types_pb2.DT_DOUBLE: "float64",
    455     types_pb2.DT_INT32: "int32",
    456     types_pb2.DT_UINT8: "uint8",
    457     types_pb2.DT_UINT16: "uint16",
    458     types_pb2.DT_UINT32: "uint32",
    459     types_pb2.DT_UINT64: "uint64",
    460     types_pb2.DT_INT16: "int16",
    461     types_pb2.DT_INT8: "int8",
    462     types_pb2.DT_STRING: "string",
    463     types_pb2.DT_COMPLEX64: "complex64",
    464     types_pb2.DT_COMPLEX128: "complex128",
    465     types_pb2.DT_INT64: "int64",
    466     types_pb2.DT_BOOL: "bool",
    467     types_pb2.DT_QINT8: "qint8",
    468     types_pb2.DT_QUINT8: "quint8",
    469     types_pb2.DT_QINT16: "qint16",
    470     types_pb2.DT_QUINT16: "quint16",
    471     types_pb2.DT_QINT32: "qint32",
    472     types_pb2.DT_BFLOAT16: "bfloat16",
    473     types_pb2.DT_RESOURCE: "resource",
    474     types_pb2.DT_VARIANT: "variant",
    475     types_pb2.DT_HALF_REF: "float16_ref",
    476     types_pb2.DT_FLOAT_REF: "float32_ref",
    477     types_pb2.DT_DOUBLE_REF: "float64_ref",
    478     types_pb2.DT_INT32_REF: "int32_ref",
    479     types_pb2.DT_UINT32_REF: "uint32_ref",
    480     types_pb2.DT_UINT8_REF: "uint8_ref",
    481     types_pb2.DT_UINT16_REF: "uint16_ref",
    482     types_pb2.DT_INT16_REF: "int16_ref",
    483     types_pb2.DT_INT8_REF: "int8_ref",
    484     types_pb2.DT_STRING_REF: "string_ref",
    485     types_pb2.DT_COMPLEX64_REF: "complex64_ref",
    486     types_pb2.DT_COMPLEX128_REF: "complex128_ref",
    487     types_pb2.DT_INT64_REF: "int64_ref",
    488     types_pb2.DT_UINT64_REF: "uint64_ref",
    489     types_pb2.DT_BOOL_REF: "bool_ref",
    490     types_pb2.DT_QINT8_REF: "qint8_ref",
    491     types_pb2.DT_QUINT8_REF: "quint8_ref",
    492     types_pb2.DT_QINT16_REF: "qint16_ref",
    493     types_pb2.DT_QUINT16_REF: "quint16_ref",
    494     types_pb2.DT_QINT32_REF: "qint32_ref",
    495     types_pb2.DT_BFLOAT16_REF: "bfloat16_ref",
    496     types_pb2.DT_RESOURCE_REF: "resource_ref",
    497     types_pb2.DT_VARIANT_REF: "variant_ref",
    498 }
    499 _STRING_TO_TF = {
    500     value: _INTERN_TABLE[key]
    501     for key, value in _TYPE_TO_STRING.items()
    502 }
    503 # Add non-canonical aliases.
    504 _STRING_TO_TF["half"] = float16
    505 _STRING_TO_TF["half_ref"] = float16_ref
    506 _STRING_TO_TF["float"] = float32
    507 _STRING_TO_TF["float_ref"] = float32_ref
    508 _STRING_TO_TF["double"] = float64
    509 _STRING_TO_TF["double_ref"] = float64_ref
    510 
    511 # Numpy representation for quantized dtypes.
    512 #
    513 # These are magic strings that are used in the swig wrapper to identify
    514 # quantized types.
    515 # TODO(mrry,keveman): Investigate Numpy type registration to replace this
    516 # hard-coding of names.
    517 _np_qint8 = np.dtype([("qint8", np.int8, 1)])
    518 _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
    519 _np_qint16 = np.dtype([("qint16", np.int16, 1)])
    520 _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
    521 _np_qint32 = np.dtype([("qint32", np.int32, 1)])
    522 
    523 # _np_bfloat16 is defined by a module import.
    524 
    525 # Custom struct dtype for directly-fed ResourceHandles of supported type(s).
    526 np_resource = np.dtype([("resource", np.ubyte, 1)])
    527 
    528 # Standard mappings between types_pb2.DataType values and numpy.dtypes.
    529 _NP_TO_TF = frozenset([
    530     (np.float16, float16),
    531     (np.float32, float32),
    532     (np.float64, float64),
    533     (np.int32, int32),
    534     (np.int64, int64),
    535     (np.uint8, uint8),
    536     (np.uint16, uint16),
    537     (np.uint32, uint32),
    538     (np.uint64, uint64),
    539     (np.int16, int16),
    540     (np.int8, int8),
    541     (np.complex64, complex64),
    542     (np.complex128, complex128),
    543     (np.object, string),
    544     (np.bool, bool),
    545     (_np_qint8, qint8),
    546     (_np_quint8, quint8),
    547     (_np_qint16, qint16),
    548     (_np_quint16, quint16),
    549     (_np_qint32, qint32),
    550     (_np_bfloat16, bfloat16),
    551 ])
    552 _TF_TO_NP = {
    553     types_pb2.DT_HALF:
    554         np.float16,
    555     types_pb2.DT_FLOAT:
    556         np.float32,
    557     types_pb2.DT_DOUBLE:
    558         np.float64,
    559     types_pb2.DT_INT32:
    560         np.int32,
    561     types_pb2.DT_UINT8:
    562         np.uint8,
    563     types_pb2.DT_UINT16:
    564         np.uint16,
    565     types_pb2.DT_UINT32:
    566         np.uint32,
    567     types_pb2.DT_UINT64:
    568         np.uint64,
    569     types_pb2.DT_INT16:
    570         np.int16,
    571     types_pb2.DT_INT8:
    572         np.int8,
    573     # NOTE(touts): For strings we use np.object as it supports variable length
    574     # strings.
    575     types_pb2.DT_STRING:
    576         np.object,
    577     types_pb2.DT_COMPLEX64:
    578         np.complex64,
    579     types_pb2.DT_COMPLEX128:
    580         np.complex128,
    581     types_pb2.DT_INT64:
    582         np.int64,
    583     types_pb2.DT_BOOL:
    584         np.bool,
    585     types_pb2.DT_QINT8:
    586         _np_qint8,
    587     types_pb2.DT_QUINT8:
    588         _np_quint8,
    589     types_pb2.DT_QINT16:
    590         _np_qint16,
    591     types_pb2.DT_QUINT16:
    592         _np_quint16,
    593     types_pb2.DT_QINT32:
    594         _np_qint32,
    595     types_pb2.DT_BFLOAT16:
    596         _np_bfloat16,
    597 
    598     # Ref types
    599     types_pb2.DT_HALF_REF:
    600         np.float16,
    601     types_pb2.DT_FLOAT_REF:
    602         np.float32,
    603     types_pb2.DT_DOUBLE_REF:
    604         np.float64,
    605     types_pb2.DT_INT32_REF:
    606         np.int32,
    607     types_pb2.DT_UINT32_REF:
    608         np.uint32,
    609     types_pb2.DT_UINT8_REF:
    610         np.uint8,
    611     types_pb2.DT_UINT16_REF:
    612         np.uint16,
    613     types_pb2.DT_INT16_REF:
    614         np.int16,
    615     types_pb2.DT_INT8_REF:
    616         np.int8,
    617     types_pb2.DT_STRING_REF:
    618         np.object,
    619     types_pb2.DT_COMPLEX64_REF:
    620         np.complex64,
    621     types_pb2.DT_COMPLEX128_REF:
    622         np.complex128,
    623     types_pb2.DT_INT64_REF:
    624         np.int64,
    625     types_pb2.DT_UINT64_REF:
    626         np.uint64,
    627     types_pb2.DT_BOOL_REF:
    628         np.bool,
    629     types_pb2.DT_QINT8_REF:
    630         _np_qint8,
    631     types_pb2.DT_QUINT8_REF:
    632         _np_quint8,
    633     types_pb2.DT_QINT16_REF:
    634         _np_qint16,
    635     types_pb2.DT_QUINT16_REF:
    636         _np_quint16,
    637     types_pb2.DT_QINT32_REF:
    638         _np_qint32,
    639     types_pb2.DT_BFLOAT16_REF:
    640         _np_bfloat16,
    641 }
    642 
    643 QUANTIZED_DTYPES = frozenset([
    644     qint8, quint8, qint16, quint16, qint32, qint8_ref, quint8_ref, qint16_ref,
    645     quint16_ref, qint32_ref
    646 ])
    647 tf_export("QUANTIZED_DTYPES").export_constant(__name__, "QUANTIZED_DTYPES")
    648 
    649 
    650 @tf_export("as_dtype")
    651 def as_dtype(type_value):
    652   """Converts the given `type_value` to a `DType`.
    653 
    654   Args:
    655     type_value: A value that can be converted to a `tf.DType`
    656       object. This may currently be a `tf.DType` object, a
    657       [`DataType`
    658         enum](https://www.tensorflow.org/code/tensorflow/core/framework/types.proto),
    659       a string type name, or a `numpy.dtype`.
    660 
    661   Returns:
    662     A `DType` corresponding to `type_value`.
    663 
    664   Raises:
    665     TypeError: If `type_value` cannot be converted to a `DType`.
    666   """
    667   if isinstance(type_value, DType):
    668     return type_value
    669 
    670   try:
    671     return _INTERN_TABLE[type_value]
    672   except KeyError:
    673     pass
    674 
    675   try:
    676     return _STRING_TO_TF[type_value]
    677   except KeyError:
    678     pass
    679 
    680   if isinstance(type_value, np.dtype):
    681     # The numpy dtype for strings is variable length. We can not compare
    682     # dtype with a single constant (np.string does not exist) to decide
    683     # dtype is a "string" type. We need to compare the dtype.type to be
    684     # sure it's a string type.
    685     if type_value.type == np.string_ or type_value.type == np.unicode_:
    686       return string
    687 
    688   for key, val in _NP_TO_TF:
    689     try:
    690       if key == type_value:
    691         return val
    692     except TypeError as e:
    693       raise TypeError("Cannot convert {} to a dtype. {}".format(type_value, e))
    694 
    695   raise TypeError("Cannot convert value %r to a TensorFlow DType." % type_value)
    696