1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/types.h" 17 #include "tensorflow/core/framework/register_types.h" 18 19 #include "tensorflow/core/lib/strings/str_util.h" 20 #include "tensorflow/core/lib/strings/strcat.h" 21 #include "tensorflow/core/platform/logging.h" 22 23 namespace tensorflow { 24 25 bool DeviceType::operator<(const DeviceType& other) const { 26 return type_ < other.type_; 27 } 28 29 bool DeviceType::operator==(const DeviceType& other) const { 30 return type_ == other.type_; 31 } 32 33 std::ostream& operator<<(std::ostream& os, const DeviceType& d) { 34 os << d.type(); 35 return os; 36 } 37 38 const char* const DEVICE_CPU = "CPU"; 39 const char* const DEVICE_GPU = "GPU"; 40 const char* const DEVICE_SYCL = "SYCL"; 41 42 const std::string DeviceName<Eigen::ThreadPoolDevice>::value = DEVICE_CPU; 43 #if GOOGLE_CUDA 44 const std::string DeviceName<Eigen::GpuDevice>::value = DEVICE_GPU; 45 #endif // GOOGLE_CUDA 46 #ifdef TENSORFLOW_USE_SYCL 47 const std::string DeviceName<Eigen::SyclDevice>::value = DEVICE_SYCL; 48 #endif // TENSORFLOW_USE_SYCL 49 50 namespace { 51 string DataTypeStringInternal(DataType dtype) { 52 switch (dtype) { 53 case DT_INVALID: 54 return "INVALID"; 55 case DT_FLOAT: 56 return "float"; 57 case DT_DOUBLE: 58 return "double"; 59 case DT_INT32: 60 return "int32"; 61 case DT_UINT32: 62 return "uint32"; 63 case DT_UINT8: 64 return "uint8"; 65 case DT_UINT16: 66 return "uint16"; 67 case DT_INT16: 68 return "int16"; 69 case DT_INT8: 70 return "int8"; 71 case DT_STRING: 72 return "string"; 73 case DT_COMPLEX64: 74 return "complex64"; 75 case DT_COMPLEX128: 76 return "complex128"; 77 case DT_INT64: 78 return "int64"; 79 case DT_UINT64: 80 return "uint64"; 81 case DT_BOOL: 82 return "bool"; 83 case DT_QINT8: 84 return "qint8"; 85 case DT_QUINT8: 86 return "quint8"; 87 case DT_QUINT16: 88 return "quint16"; 89 case DT_QINT16: 90 return "qint16"; 91 case DT_QINT32: 92 return "qint32"; 93 case DT_BFLOAT16: 94 return "bfloat16"; 95 case DT_HALF: 96 return "half"; 97 case DT_RESOURCE: 98 return "resource"; 99 case DT_VARIANT: 100 return "variant"; 101 default: 102 LOG(ERROR) << "Unrecognized DataType enum value " << dtype; 103 return strings::StrCat("unknown dtype enum (", dtype, ")"); 104 } 105 } 106 } // end namespace 107 108 string DataTypeString(DataType dtype) { 109 if (IsRefType(dtype)) { 110 DataType non_ref = static_cast<DataType>(dtype - kDataTypeRefOffset); 111 return strings::StrCat(DataTypeStringInternal(non_ref), "_ref"); 112 } 113 return DataTypeStringInternal(dtype); 114 } 115 116 bool DataTypeFromString(StringPiece sp, DataType* dt) { 117 if (str_util::EndsWith(sp, "_ref")) { 118 sp.remove_suffix(4); 119 DataType non_ref; 120 if (DataTypeFromString(sp, &non_ref) && !IsRefType(non_ref)) { 121 *dt = static_cast<DataType>(non_ref + kDataTypeRefOffset); 122 return true; 123 } else { 124 return false; 125 } 126 } 127 128 if (sp == "float" || sp == "float32") { 129 *dt = DT_FLOAT; 130 return true; 131 } else if (sp == "double" || sp == "float64") { 132 *dt = DT_DOUBLE; 133 return true; 134 } else if (sp == "int32") { 135 *dt = DT_INT32; 136 return true; 137 } else if (sp == "uint32") { 138 *dt = DT_UINT32; 139 return true; 140 } else if (sp == "uint8") { 141 *dt = DT_UINT8; 142 return true; 143 } else if (sp == "uint16") { 144 *dt = DT_UINT16; 145 return true; 146 } else if (sp == "int16") { 147 *dt = DT_INT16; 148 return true; 149 } else if (sp == "int8") { 150 *dt = DT_INT8; 151 return true; 152 } else if (sp == "string") { 153 *dt = DT_STRING; 154 return true; 155 } else if (sp == "complex64") { 156 *dt = DT_COMPLEX64; 157 return true; 158 } else if (sp == "complex128") { 159 *dt = DT_COMPLEX128; 160 return true; 161 } else if (sp == "int64") { 162 *dt = DT_INT64; 163 return true; 164 } else if (sp == "uint64") { 165 *dt = DT_UINT64; 166 return true; 167 } else if (sp == "bool") { 168 *dt = DT_BOOL; 169 return true; 170 } else if (sp == "qint8") { 171 *dt = DT_QINT8; 172 return true; 173 } else if (sp == "quint8") { 174 *dt = DT_QUINT8; 175 return true; 176 } else if (sp == "qint16") { 177 *dt = DT_QINT16; 178 return true; 179 } else if (sp == "quint16") { 180 *dt = DT_QUINT16; 181 return true; 182 } else if (sp == "qint32") { 183 *dt = DT_QINT32; 184 return true; 185 } else if (sp == "bfloat16") { 186 *dt = DT_BFLOAT16; 187 return true; 188 } else if (sp == "half" || sp == "float16") { 189 *dt = DT_HALF; 190 return true; 191 } else if (sp == "resource") { 192 *dt = DT_RESOURCE; 193 return true; 194 } else if (sp == "variant") { 195 *dt = DT_VARIANT; 196 return true; 197 } 198 return false; 199 } 200 201 string DeviceTypeString(const DeviceType& device_type) { 202 return device_type.type(); 203 } 204 205 string DataTypeSliceString(const DataTypeSlice types) { 206 string out; 207 for (auto it = types.begin(); it != types.end(); ++it) { 208 strings::StrAppend(&out, ((it == types.begin()) ? "" : ", "), 209 DataTypeString(*it)); 210 } 211 return out; 212 } 213 214 bool DataTypeAlwaysOnHost(DataType dt) { 215 // Includes DT_STRING and DT_RESOURCE. 216 switch (dt) { 217 case DT_STRING: 218 case DT_STRING_REF: 219 case DT_RESOURCE: 220 return true; 221 default: 222 return false; 223 } 224 } 225 226 int DataTypeSize(DataType dt) { 227 #define CASE(T) \ 228 case DataTypeToEnum<T>::value: \ 229 return sizeof(T); 230 switch (dt) { 231 TF_CALL_POD_TYPES(CASE); 232 TF_CALL_QUANTIZED_TYPES(CASE); 233 // TF_CALL_QUANTIZED_TYPES() macro does no cover quint16 and qint16, since 234 // they are not supported widely, but are explicitly listed here for 235 // bitcast. 236 TF_CALL_qint16(CASE); 237 TF_CALL_quint16(CASE); 238 239 // uint32 and uint64 aren't included in TF_CALL_POD_TYPES because we 240 // don't want to define kernels for them at this stage to avoid binary 241 // bloat. 242 TF_CALL_uint32(CASE); 243 TF_CALL_uint64(CASE); 244 default: 245 return 0; 246 } 247 #undef CASE 248 } 249 250 } // namespace tensorflow 251