1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_FRAMEWORK_TYPES_H_ 17 #define TENSORFLOW_FRAMEWORK_TYPES_H_ 18 19 #include <map> 20 #include <set> 21 #include <string> 22 23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 24 // Disable clang-format to prevent 'FixedPoint' header from being included 25 // before 'Tensor' header on which it depends. 26 // clang-format off 27 #include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint" 28 // clang-format on 29 #include "tensorflow/core/framework/bfloat16.h" 30 #include "tensorflow/core/framework/numeric_types.h" 31 #include "tensorflow/core/framework/resource_handle.h" 32 #include "tensorflow/core/framework/types.pb.h" 33 #include "tensorflow/core/framework/variant.h" 34 #include "tensorflow/core/lib/core/stringpiece.h" 35 #include "tensorflow/core/lib/gtl/array_slice.h" 36 #include "tensorflow/core/lib/gtl/inlined_vector.h" 37 #include "tensorflow/core/platform/logging.h" 38 #include "tensorflow/core/platform/types.h" 39 40 namespace tensorflow { 41 42 // MemoryType is used to describe whether input or output Tensors of 43 // an OpKernel should reside in "Host memory" (e.g., CPU memory) or 44 // "Device" Memory (CPU memory for CPU devices, GPU memory for GPU 45 // devices). 46 enum MemoryType { 47 DEVICE_MEMORY = 0, 48 HOST_MEMORY = 1, 49 }; 50 51 // A DeviceType is just a string, but we wrap it up in a class to give 52 // some type checking as we're passing these around 53 class DeviceType { 54 public: 55 DeviceType(const char* type) // NOLINT(runtime/explicit) 56 : type_(type) {} 57 58 explicit DeviceType(StringPiece type) : type_(type.data(), type.size()) {} 59 60 const char* type() const { return type_.c_str(); } 61 const string& type_string() const { return type_; } 62 63 bool operator<(const DeviceType& other) const; 64 bool operator==(const DeviceType& other) const; 65 bool operator!=(const DeviceType& other) const { return !(*this == other); } 66 67 private: 68 string type_; 69 }; 70 std::ostream& operator<<(std::ostream& os, const DeviceType& d); 71 72 // Convenient constants that can be passed to a DeviceType constructor 73 TF_EXPORT extern const char* const DEVICE_CPU; // "CPU" 74 TF_EXPORT extern const char* const DEVICE_GPU; // "GPU" 75 TF_EXPORT extern const char* const DEVICE_SYCL; // "SYCL" 76 77 template <typename Device> 78 struct DeviceName {}; 79 80 template <> 81 struct DeviceName<Eigen::ThreadPoolDevice> { 82 static const std::string value; 83 }; 84 85 #if GOOGLE_CUDA 86 template <> 87 struct DeviceName<Eigen::GpuDevice> { 88 static const std::string value; 89 }; 90 #endif // GOOGLE_CUDA 91 92 #ifdef TENSORFLOW_USE_SYCL 93 template <> 94 struct DeviceName<Eigen::SyclDevice> { 95 static const std::string value; 96 }; 97 #endif // TENSORFLOW_USE_SYCL 98 99 typedef gtl::InlinedVector<MemoryType, 4> MemoryTypeVector; 100 typedef gtl::ArraySlice<MemoryType> MemoryTypeSlice; 101 102 typedef gtl::InlinedVector<DataType, 4> DataTypeVector; 103 typedef gtl::ArraySlice<DataType> DataTypeSlice; 104 105 typedef gtl::InlinedVector<DeviceType, 4> DeviceTypeVector; 106 107 // Convert the enums to strings for errors: 108 string DataTypeString(DataType dtype); 109 string DeviceTypeString(const DeviceType& device_type); 110 string DataTypeSliceString(const DataTypeSlice dtypes); 111 inline string DataTypeVectorString(const DataTypeVector& dtypes) { 112 return DataTypeSliceString(dtypes); 113 } 114 115 // DataTypeSet represents a set of DataType values as a simple and efficient 116 // bit mask. Note that DataTypeSet cannot represent all DataType values; it 117 // cannot represent any of the DT_*_REF values. 118 class DataTypeSet { 119 private: 120 const uint32 mask_; 121 122 static constexpr uint32 kNumBits = 32; 123 124 public: 125 constexpr DataTypeSet(const DataTypeSet& other) : mask_(other.mask_) {} 126 explicit constexpr DataTypeSet(uint32 mask) : mask_(mask) {} 127 128 constexpr bool Contains(DataType dt) const { 129 return (static_cast<uint32>(dt) < kNumBits) && 130 ((mask_ >> static_cast<uint32>(dt)) & 1u) != 0u; 131 } 132 133 class Iterator { 134 const DataTypeSet& set_; 135 uint32 pos_; 136 137 public: 138 Iterator(const DataTypeSet& set, uint32 pos) : set_(set), pos_(pos) { 139 DCHECK_LE(pos, kNumBits); 140 } 141 DataType operator*() const { return static_cast<DataType>(pos_); } 142 Iterator& operator++() { 143 ++pos_; 144 DCHECK_LE(pos_, kNumBits); 145 if (pos_ < kNumBits) { 146 uint32 remaining_mask = set_.mask_ >> pos_; 147 if (remaining_mask != 0u) { 148 pos_ += ctz_uint32(remaining_mask); 149 } 150 } 151 DCHECK_LE(pos_, kNumBits); 152 return *this; 153 } 154 bool operator==(const Iterator& other) const { return pos_ == other.pos_; } 155 bool operator!=(const Iterator& other) const { return !(*this == other); } 156 size_t operator-(const Iterator& other) const { 157 return this->pos_ - other.pos_; 158 } 159 }; 160 161 static uint32 ctz_uint32(uint32 x) { 162 DCHECK_NE(x, 0u); 163 #ifdef __GNUC__ 164 return __builtin_ctz(x); 165 #else 166 uint32 n = 0u; 167 while ((x & 1u) == 0u) { 168 x >>= 1; 169 ++n; 170 } 171 return n; 172 #endif 173 } 174 175 static uint32 clz_uint32(uint32 x) { 176 DCHECK_NE(x, 0u); 177 #ifdef __GNUC__ 178 return __builtin_clz(x); 179 #else 180 uint32 n = 0u; 181 while ((x >> (kNumBits - 1u)) == 0u) { 182 x <<= 1; 183 ++n; 184 } 185 return n; 186 #endif 187 } 188 189 Iterator begin() const { 190 // The begin position is the index of the first bit set to 1 in the entire 191 // bit mask. If there are no bits set to 1, then the index is 0. 192 if (mask_ != 0) { 193 return Iterator(*this, ctz_uint32(mask_)); 194 } 195 // The set is empty. 196 return Iterator(*this, 0); 197 } 198 199 Iterator end() const { 200 // The end position is the index of the highest bit that is set, plus 1. 201 // If there are no bits set to 1, then the index is 0. 202 if (mask_ != 0) { 203 return Iterator(*this, kNumBits - clz_uint32(mask_)); 204 } 205 // The set is empty. 206 return Iterator(*this, 0); 207 } 208 209 size_t size() const { 210 #if defined(__GNUC__) 211 return __builtin_popcount(mask_); 212 #else 213 size_t n = 0; 214 uint32 x = mask_; 215 while (x > 0) { 216 n += x & 1u; 217 x >>= 1; 218 } 219 return n; 220 #endif 221 } 222 223 constexpr DataTypeSet operator|(const DataTypeSet& other) const { 224 return DataTypeSet(mask_ | other.mask_); 225 } 226 }; 227 228 // If "sp" names a valid type, store it in "*dt" and return true. Otherwise, 229 // return false. 230 bool DataTypeFromString(StringPiece sp, DataType* dt); 231 232 constexpr inline DataTypeSet ToSet(DataType dt) { 233 return DataTypeSet(1u << static_cast<uint32>(dt)); 234 } 235 236 // DT_FLOAT + kDataTypeRefOffset == DT_FLOAT_REF, etc. 237 enum { kDataTypeRefOffset = 100 }; 238 inline bool IsRefType(DataType dtype) { 239 return dtype > static_cast<DataType>(kDataTypeRefOffset); 240 } 241 inline DataType MakeRefType(DataType dtype) { 242 DCHECK(!IsRefType(dtype)); 243 return static_cast<DataType>(dtype + kDataTypeRefOffset); 244 } 245 inline DataType RemoveRefType(DataType dtype) { 246 DCHECK(IsRefType(dtype)); 247 return static_cast<DataType>(dtype - kDataTypeRefOffset); 248 } 249 inline DataType BaseType(DataType dtype) { 250 return IsRefType(dtype) ? RemoveRefType(dtype) : dtype; 251 } 252 253 // Returns true if the actual type is the same as or ref of the expected type. 254 inline bool TypesCompatible(DataType expected, DataType actual) { 255 return expected == actual || expected == BaseType(actual); 256 } 257 258 // Does not include _ref types. 259 constexpr DataTypeSet kAllTypes = 260 ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_UINT8) | 261 ToSet(DT_INT16) | ToSet(DT_UINT16) | ToSet(DT_INT8) | ToSet(DT_STRING) | 262 ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_INT64) | 263 ToSet(DT_BOOL) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) | ToSet(DT_QINT16) | 264 ToSet(DT_QUINT16) | ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_RESOURCE) | 265 ToSet(DT_VARIANT) | ToSet(DT_UINT32) | ToSet(DT_UINT64) | 266 ToSet(DT_BFLOAT16); 267 inline const DataTypeSet& AllTypes() { return kAllTypes; } 268 269 #if !defined(IS_MOBILE_PLATFORM) || defined(SUPPORT_SELECTIVE_REGISTRATION) 270 271 // Types that support '<' and '>'. 272 constexpr DataTypeSet kRealNumberTypes = 273 ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_INT64) | 274 ToSet(DT_UINT8) | ToSet(DT_INT16) | ToSet(DT_INT8) | ToSet(DT_UINT16) | 275 ToSet(DT_HALF) | ToSet(DT_UINT32) | ToSet(DT_UINT64) | ToSet(DT_BFLOAT16); 276 inline const DataTypeSet RealNumberTypes() { return kRealNumberTypes; } 277 278 // Return the list of all numeric types. 279 // Includes complex and quantized types. 280 // NOTE: On Android, we only include the float and int32 types for now. 281 const DataTypeSet kNumberTypes = 282 ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT64) | ToSet(DT_INT32) | 283 ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_INT16) | ToSet(DT_INT8) | 284 ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_QINT8) | 285 ToSet(DT_QUINT8) | ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_UINT32) | 286 ToSet(DT_UINT64) | ToSet(DT_BFLOAT16); 287 inline const DataTypeSet& NumberTypes() { return kNumberTypes; } 288 289 constexpr DataTypeSet kQuantizedTypes = ToSet(DT_QINT8) | ToSet(DT_QUINT8) | 290 ToSet(DT_QINT16) | ToSet(DT_QUINT16) | 291 ToSet(DT_QINT32); 292 inline const DataTypeSet& QuantizedTypes() { return kQuantizedTypes; } 293 294 // Types that support '<' and '>', including quantized types. 295 const DataTypeSet kRealAndQuantizedTypes = 296 ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_INT64) | 297 ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_UINT16) | ToSet(DT_INT8) | 298 ToSet(DT_QINT8) | ToSet(DT_QUINT8) | ToSet(DT_QINT16) | ToSet(DT_QUINT16) | 299 ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_BFLOAT16); 300 inline const DataTypeSet& RealAndQuantizedTypes() { 301 return kRealAndQuantizedTypes; 302 } 303 304 #elif defined(__ANDROID_TYPES_FULL__) 305 306 constexpr DataTypeSet kRealNumberTypes = 307 ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_INT64) | ToSet(DT_HALF); 308 inline DataTypeSet RealNumberTypes() { return kRealNumberTypes; } 309 310 constexpr DataTypeSet kNumberTypes = 311 ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_INT64) | ToSet(DT_QINT8) | 312 ToSet(DT_QUINT8) | ToSet(DT_QINT32) | ToSet(DT_HALF); 313 inline DataTypeSet NumberTypes() { return kNumberTypes; } 314 315 constexpr DataTypeSet kQuantizedTypes = ToSet(DT_QINT8) | ToSet(DT_QUINT8) | 316 ToSet(DT_QINT16) | ToSet(DT_QUINT16) | 317 ToSet(DT_QINT32); 318 inline DataTypeSet QuantizedTypes() { return kQuantizedTypes; } 319 320 constexpr DataTypeSet kRealAndQuantizedTypes = 321 ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_INT64) | ToSet(DT_QINT8) | 322 ToSet(DT_QUINT8) | ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32) | 323 ToSet(DT_HALF); 324 inline DataTypeSet RealAndQuantizedTypes() { return kRealAndQuantizedTypes; } 325 326 #else // defined(IS_MOBILE_PLATFORM) && !defined(__ANDROID_TYPES_FULL__) 327 328 constexpr DataTypeSet kRealNumberTypes = ToSet(DT_FLOAT) | ToSet(DT_INT32); 329 inline DataTypeSet RealNumberTypes() { return kRealNumberTypes; } 330 331 constexpr DataTypeSet kNumberTypes = ToSet(DT_FLOAT) | ToSet(DT_INT32) | 332 ToSet(DT_QINT8) | ToSet(DT_QUINT8) | 333 ToSet(DT_QINT32); 334 inline DataTypeSet NumberTypes() { return kNumberTypes; } 335 336 constexpr DataTypeSet kQuantizedTypes = ToSet(DT_QINT8) | ToSet(DT_QUINT8) | 337 ToSet(DT_QINT16) | ToSet(DT_QUINT16) | 338 ToSet(DT_QINT32); 339 inline DataTypeSet QuantizedTypes() { return kQuantizedTypes; } 340 341 constexpr DataTypeSet kRealAndQuantizedTypes = 342 ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) | 343 ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32); 344 inline DataTypeSet RealAndQuantizedTypes() { return kRealAndQuantizedTypes; } 345 346 #endif // defined(IS_MOBILE_PLATFORM) 347 348 // Validates type T for whether it is a supported DataType. 349 template <class T> 350 struct IsValidDataType; 351 352 // DataTypeToEnum<T>::v() and DataTypeToEnum<T>::value are the DataType 353 // constants for T, e.g. DataTypeToEnum<float>::v() is DT_FLOAT. 354 template <class T> 355 struct DataTypeToEnum { 356 static_assert(IsValidDataType<T>::value, "Specified Data Type not supported"); 357 }; // Specializations below 358 359 // EnumToDataType<VALUE>::Type is the type for DataType constant VALUE, e.g. 360 // EnumToDataType<DT_FLOAT>::Type is float. 361 template <DataType VALUE> 362 struct EnumToDataType {}; // Specializations below 363 364 // Template specialization for both DataTypeToEnum and EnumToDataType. 365 #define MATCH_TYPE_AND_ENUM(TYPE, ENUM) \ 366 template <> \ 367 struct DataTypeToEnum<TYPE> { \ 368 static DataType v() { return ENUM; } \ 369 static DataType ref() { return MakeRefType(ENUM); } \ 370 static constexpr DataType value = ENUM; \ 371 }; \ 372 template <> \ 373 struct IsValidDataType<TYPE> { \ 374 static constexpr bool value = true; \ 375 }; \ 376 template <> \ 377 struct EnumToDataType<ENUM> { \ 378 typedef TYPE Type; \ 379 } 380 381 MATCH_TYPE_AND_ENUM(float, DT_FLOAT); 382 MATCH_TYPE_AND_ENUM(double, DT_DOUBLE); 383 MATCH_TYPE_AND_ENUM(int32, DT_INT32); 384 MATCH_TYPE_AND_ENUM(uint32, DT_UINT32); 385 MATCH_TYPE_AND_ENUM(uint16, DT_UINT16); 386 MATCH_TYPE_AND_ENUM(uint8, DT_UINT8); 387 MATCH_TYPE_AND_ENUM(int16, DT_INT16); 388 MATCH_TYPE_AND_ENUM(int8, DT_INT8); 389 MATCH_TYPE_AND_ENUM(string, DT_STRING); 390 MATCH_TYPE_AND_ENUM(complex64, DT_COMPLEX64); 391 MATCH_TYPE_AND_ENUM(complex128, DT_COMPLEX128); 392 MATCH_TYPE_AND_ENUM(int64, DT_INT64); 393 MATCH_TYPE_AND_ENUM(uint64, DT_UINT64); 394 MATCH_TYPE_AND_ENUM(bool, DT_BOOL); 395 MATCH_TYPE_AND_ENUM(qint8, DT_QINT8); 396 MATCH_TYPE_AND_ENUM(quint8, DT_QUINT8); 397 MATCH_TYPE_AND_ENUM(qint16, DT_QINT16); 398 MATCH_TYPE_AND_ENUM(quint16, DT_QUINT16); 399 MATCH_TYPE_AND_ENUM(qint32, DT_QINT32); 400 MATCH_TYPE_AND_ENUM(bfloat16, DT_BFLOAT16); 401 MATCH_TYPE_AND_ENUM(Eigen::half, DT_HALF); 402 MATCH_TYPE_AND_ENUM(ResourceHandle, DT_RESOURCE); 403 MATCH_TYPE_AND_ENUM(Variant, DT_VARIANT); 404 405 #undef MATCH_TYPE_AND_ENUM 406 407 // All types not specialized are marked invalid. 408 template <class T> 409 struct IsValidDataType { 410 static constexpr bool value = false; 411 }; 412 413 // Extra validity checking; not part of public API. 414 static_assert(IsValidDataType<int64>::value, "Incorrect impl for int64"); 415 static_assert(IsValidDataType<int32>::value, "Incorrect impl for int32"); 416 417 // TODO(jeff): Maybe unify this with Tensor::CanUseDMA, or the underlying 418 // is_simple<T> in tensor.cc (and possible choose a more general name?) 419 constexpr DataTypeSet kDataTypesCanUseMemcpy = 420 ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_UINT32) | 421 ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_INT16) | ToSet(DT_INT8) | 422 ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_INT64) | 423 ToSet(DT_UINT64) | ToSet(DT_BOOL) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) | 424 ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32) | 425 ToSet(DT_BFLOAT16) | ToSet(DT_HALF); 426 inline bool DataTypeCanUseMemcpy(DataType dt) { 427 return kDataTypesCanUseMemcpy.Contains(dt); 428 } 429 430 // Returns true iff 'dt' is a real, non-quantized floating point type. 431 constexpr DataTypeSet kDataTypeIsFloating = 432 ToSet(DT_HALF) | ToSet(DT_BFLOAT16) | ToSet(DT_FLOAT) | ToSet(DT_DOUBLE); 433 inline bool DataTypeIsFloating(DataType dt) { 434 return kDataTypeIsFloating.Contains(dt); 435 } 436 437 // Returns true iff 'dt' is a complex type. 438 constexpr DataTypeSet kDataTypeIsComplex = 439 ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128); 440 inline bool DataTypeIsComplex(DataType dt) { 441 return kDataTypeIsComplex.Contains(dt); 442 } 443 444 inline bool DataTypeIsQuantized(DataType dt) { 445 return kQuantizedTypes.Contains(dt); 446 } 447 448 // Is the dtype nonquantized integral? 449 constexpr DataTypeSet kDataTypeIsInteger = 450 ToSet(DT_INT8) | ToSet(DT_UINT8) | ToSet(DT_INT16) | ToSet(DT_UINT16) | 451 ToSet(DT_INT32) | ToSet(DT_UINT32) | ToSet(DT_INT64) | ToSet(DT_UINT64); 452 inline bool DataTypeIsInteger(DataType dt) { 453 return kDataTypeIsInteger.Contains(dt); 454 } 455 456 // Is the dtype a signed integral type? 457 constexpr DataTypeSet kDataTypeIsSigned = 458 ToSet(DT_INT8) | ToSet(DT_INT16) | ToSet(DT_INT32) | ToSet(DT_INT64); 459 inline bool DataTypeIsSigned(DataType dt) { 460 return kDataTypeIsSigned.Contains(dt); 461 } 462 463 // Is the dtype an unsigned integral type? 464 constexpr DataTypeSet kDataTypeIsUnsigned = 465 ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_UINT32) | ToSet(DT_UINT64); 466 inline bool DataTypeIsUnsigned(DataType dt) { 467 return kDataTypeIsUnsigned.Contains(dt); 468 } 469 470 // Returns a 0 on failure 471 int DataTypeSize(DataType dt); 472 473 // Types that always sit on host: DT_STRING, DT_STRING_REF, DT_RESOURCE. 474 // For DT_RESOURCE, the handle always sits on host (even if the underlying 475 // object has device-allocated resources). 476 bool DataTypeAlwaysOnHost(DataType dt); 477 478 } // namespace tensorflow 479 480 #endif // TENSORFLOW_FRAMEWORK_TYPES_H_ 481