Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/framework/types.h"
     17 #include "tensorflow/core/framework/register_types.h"
     18 
     19 #include "tensorflow/core/lib/strings/str_util.h"
     20 #include "tensorflow/core/lib/strings/strcat.h"
     21 #include "tensorflow/core/platform/logging.h"
     22 
     23 namespace tensorflow {
     24 
     25 bool DeviceType::operator<(const DeviceType& other) const {
     26   return type_ < other.type_;
     27 }
     28 
     29 bool DeviceType::operator==(const DeviceType& other) const {
     30   return type_ == other.type_;
     31 }
     32 
     33 std::ostream& operator<<(std::ostream& os, const DeviceType& d) {
     34   os << d.type();
     35   return os;
     36 }
     37 
// Device-type name strings used throughout the framework.
const char* const DEVICE_CPU = "CPU";
const char* const DEVICE_GPU = "GPU";
const char* const DEVICE_SYCL = "SYCL";

// Out-of-line definitions of the static DeviceName<Device>::value members,
// each initialized from the matching DEVICE_* constant above. The GPU and
// SYCL definitions are compiled only when GOOGLE_CUDA / TENSORFLOW_USE_SYCL
// are enabled for this build.
const std::string DeviceName<Eigen::ThreadPoolDevice>::value = DEVICE_CPU;
#if GOOGLE_CUDA
const std::string DeviceName<Eigen::GpuDevice>::value = DEVICE_GPU;
#endif  // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
const std::string DeviceName<Eigen::SyclDevice>::value = DEVICE_SYCL;
#endif  // TENSORFLOW_USE_SYCL
     49 
     50 namespace {
     51 string DataTypeStringInternal(DataType dtype) {
     52   switch (dtype) {
     53     case DT_INVALID:
     54       return "INVALID";
     55     case DT_FLOAT:
     56       return "float";
     57     case DT_DOUBLE:
     58       return "double";
     59     case DT_INT32:
     60       return "int32";
     61     case DT_UINT32:
     62       return "uint32";
     63     case DT_UINT8:
     64       return "uint8";
     65     case DT_UINT16:
     66       return "uint16";
     67     case DT_INT16:
     68       return "int16";
     69     case DT_INT8:
     70       return "int8";
     71     case DT_STRING:
     72       return "string";
     73     case DT_COMPLEX64:
     74       return "complex64";
     75     case DT_COMPLEX128:
     76       return "complex128";
     77     case DT_INT64:
     78       return "int64";
     79     case DT_UINT64:
     80       return "uint64";
     81     case DT_BOOL:
     82       return "bool";
     83     case DT_QINT8:
     84       return "qint8";
     85     case DT_QUINT8:
     86       return "quint8";
     87     case DT_QUINT16:
     88       return "quint16";
     89     case DT_QINT16:
     90       return "qint16";
     91     case DT_QINT32:
     92       return "qint32";
     93     case DT_BFLOAT16:
     94       return "bfloat16";
     95     case DT_HALF:
     96       return "half";
     97     case DT_RESOURCE:
     98       return "resource";
     99     case DT_VARIANT:
    100       return "variant";
    101     default:
    102       LOG(ERROR) << "Unrecognized DataType enum value " << dtype;
    103       return strings::StrCat("unknown dtype enum (", dtype, ")");
    104   }
    105 }
    106 }  // end namespace
    107 
    108 string DataTypeString(DataType dtype) {
    109   if (IsRefType(dtype)) {
    110     DataType non_ref = static_cast<DataType>(dtype - kDataTypeRefOffset);
    111     return strings::StrCat(DataTypeStringInternal(non_ref), "_ref");
    112   }
    113   return DataTypeStringInternal(dtype);
    114 }
    115 
    116 bool DataTypeFromString(StringPiece sp, DataType* dt) {
    117   if (str_util::EndsWith(sp, "_ref")) {
    118     sp.remove_suffix(4);
    119     DataType non_ref;
    120     if (DataTypeFromString(sp, &non_ref) && !IsRefType(non_ref)) {
    121       *dt = static_cast<DataType>(non_ref + kDataTypeRefOffset);
    122       return true;
    123     } else {
    124       return false;
    125     }
    126   }
    127 
    128   if (sp == "float" || sp == "float32") {
    129     *dt = DT_FLOAT;
    130     return true;
    131   } else if (sp == "double" || sp == "float64") {
    132     *dt = DT_DOUBLE;
    133     return true;
    134   } else if (sp == "int32") {
    135     *dt = DT_INT32;
    136     return true;
    137   } else if (sp == "uint32") {
    138     *dt = DT_UINT32;
    139     return true;
    140   } else if (sp == "uint8") {
    141     *dt = DT_UINT8;
    142     return true;
    143   } else if (sp == "uint16") {
    144     *dt = DT_UINT16;
    145     return true;
    146   } else if (sp == "int16") {
    147     *dt = DT_INT16;
    148     return true;
    149   } else if (sp == "int8") {
    150     *dt = DT_INT8;
    151     return true;
    152   } else if (sp == "string") {
    153     *dt = DT_STRING;
    154     return true;
    155   } else if (sp == "complex64") {
    156     *dt = DT_COMPLEX64;
    157     return true;
    158   } else if (sp == "complex128") {
    159     *dt = DT_COMPLEX128;
    160     return true;
    161   } else if (sp == "int64") {
    162     *dt = DT_INT64;
    163     return true;
    164   } else if (sp == "uint64") {
    165     *dt = DT_UINT64;
    166     return true;
    167   } else if (sp == "bool") {
    168     *dt = DT_BOOL;
    169     return true;
    170   } else if (sp == "qint8") {
    171     *dt = DT_QINT8;
    172     return true;
    173   } else if (sp == "quint8") {
    174     *dt = DT_QUINT8;
    175     return true;
    176   } else if (sp == "qint16") {
    177     *dt = DT_QINT16;
    178     return true;
    179   } else if (sp == "quint16") {
    180     *dt = DT_QUINT16;
    181     return true;
    182   } else if (sp == "qint32") {
    183     *dt = DT_QINT32;
    184     return true;
    185   } else if (sp == "bfloat16") {
    186     *dt = DT_BFLOAT16;
    187     return true;
    188   } else if (sp == "half" || sp == "float16") {
    189     *dt = DT_HALF;
    190     return true;
    191   } else if (sp == "resource") {
    192     *dt = DT_RESOURCE;
    193     return true;
    194   } else if (sp == "variant") {
    195     *dt = DT_VARIANT;
    196     return true;
    197   }
    198   return false;
    199 }
    200 
    201 string DeviceTypeString(const DeviceType& device_type) {
    202   return device_type.type();
    203 }
    204 
    205 string DataTypeSliceString(const DataTypeSlice types) {
    206   string out;
    207   for (auto it = types.begin(); it != types.end(); ++it) {
    208     strings::StrAppend(&out, ((it == types.begin()) ? "" : ", "),
    209                        DataTypeString(*it));
    210   }
    211   return out;
    212 }
    213 
    214 bool DataTypeAlwaysOnHost(DataType dt) {
    215   // Includes DT_STRING and DT_RESOURCE.
    216   switch (dt) {
    217     case DT_STRING:
    218     case DT_STRING_REF:
    219     case DT_RESOURCE:
    220       return true;
    221     default:
    222       return false;
    223   }
    224 }
    225 
// Returns sizeof(T) for the C++ type T that DataTypeToEnum associates with
// `dt`, for every type enumerated by the TF_CALL_* macros below. Returns 0
// for any dtype those macros do not enumerate.
// NOTE(review): the exact set of types each TF_CALL_* macro expands to lives
// in register_types.h. Presumably variable-size types (e.g. DT_STRING) and
// the *_REF variants fall through to 0; confirm there.
int DataTypeSize(DataType dt) {
// CASE(T) expands to one `case` label returning sizeof(T). It is applied to
// each type by the TF_CALL_* macros and #undef'd after the switch.
#define CASE(T)                  \
  case DataTypeToEnum<T>::value: \
    return sizeof(T);
  switch (dt) {
    TF_CALL_POD_TYPES(CASE);
    TF_CALL_QUANTIZED_TYPES(CASE);
    // The TF_CALL_QUANTIZED_TYPES() macro does not cover quint16 and qint16,
    // since they are not widely supported, but they are listed explicitly
    // here for bitcast.
    TF_CALL_qint16(CASE);
    TF_CALL_quint16(CASE);

    // uint32 and uint64 aren't included in TF_CALL_POD_TYPES because we
    // don't want to define kernels for them at this stage to avoid binary
    // bloat.
    TF_CALL_uint32(CASE);
    TF_CALL_uint64(CASE);
    default:
      return 0;
  }
#undef CASE
}
    249 
    250 }  // namespace tensorflow
    251