/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_H_
#define TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_H_
// This file is used by cuda code and must remain compilable by nvcc.

#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/platform/types.h"

// Two sets of macros:
// - TF_CALL_float, TF_CALL_double, etc., which call the given macro with
//   the type name as the only parameter - except on platforms for which
//   the type should not be included.
// - Macros to apply another macro to lists of supported types. These also call
//   into TF_CALL_float, TF_CALL_double, etc., so they filter by target platform
//   as well.
// If you change the lists of types, please also update the list in types.cc.
//
// See example uses of these macros in core/ops.
//
// Each of these TF_CALL_XXX_TYPES(m) macros invokes the macro "m" multiple
// times by passing each invocation a data type supported by TensorFlow.
//
// The different variations pass different subsets of the types.
// TF_CALL_ALL_TYPES(m) applies "m" to all types supported by TensorFlow.
// Which types are actually enabled depends on the compilation platform; see
// the per-type table below.
//
// A typical use is registering one template instantiation of an OpKernel per
// supported type, e.g.:
/*
   #define REGISTER_PARTITION(type)                                      \
     REGISTER_KERNEL_BUILDER(                                            \
         Name("Partition").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
         PartitionOp<type>);
   TF_CALL_ALL_TYPES(REGISTER_PARTITION)
   #undef REGISTER_PARTITION
*/

#if !defined(IS_MOBILE_PLATFORM) || defined(SUPPORT_SELECTIVE_REGISTRATION) || \
    defined(ANDROID_TEGRA)

// Non-mobile builds, and mobile builds with SUPPORT_SELECTIVE_REGISTRATION or
// ANDROID_TEGRA: every type is enabled, so each macro applies "m" to its type.

// Floating-point types.
#define TF_CALL_half(m) m(Eigen::half)
#define TF_CALL_bfloat16(m) m(::tensorflow::bfloat16)
#define TF_CALL_float(m) m(float)
#define TF_CALL_double(m) m(double)

// Signed integer types.
#define TF_CALL_int8(m) m(::tensorflow::int8)
#define TF_CALL_int16(m) m(::tensorflow::int16)
#define TF_CALL_int32(m) m(::tensorflow::int32)
#define TF_CALL_int64(m) m(::tensorflow::int64)

// Unsigned integer types.
#define TF_CALL_uint8(m) m(::tensorflow::uint8)
#define TF_CALL_uint16(m) m(::tensorflow::uint16)
#define TF_CALL_uint32(m) m(::tensorflow::uint32)
#define TF_CALL_uint64(m) m(::tensorflow::uint64)

// Complex types.
#define TF_CALL_complex64(m) m(::tensorflow::complex64)
#define TF_CALL_complex128(m) m(::tensorflow::complex128)

// Quantized types.
#define TF_CALL_qint8(m) m(::tensorflow::qint8)
#define TF_CALL_quint8(m) m(::tensorflow::quint8)
#define TF_CALL_qint16(m) m(::tensorflow::qint16)
#define TF_CALL_quint16(m) m(::tensorflow::quint16)
#define TF_CALL_qint32(m) m(::tensorflow::qint32)

// Remaining types.
#define TF_CALL_bool(m) m(bool)
#define TF_CALL_string(m) m(string)
#define TF_CALL_resource(m) m(::tensorflow::ResourceHandle)
#define TF_CALL_variant(m) m(::tensorflow::Variant)

#else  // defined(IS_MOBILE_PLATFORM) && no selective registration / TEGRA

// Mobile builds. Every TF_CALL_<type> macro is still defined, but the macro of
// a disabled type expands to nothing, so "m" is never applied to that type.

// Enabled on every mobile build.
#define TF_CALL_float(m) m(float)
#define TF_CALL_int32(m) m(::tensorflow::int32)
#define TF_CALL_bool(m) m(bool)

#if defined(__ANDROID_TYPES_FULL__)
// __ANDROID_TYPES_FULL__ additionally enables string, int64, half and the
// quantized types.
#define TF_CALL_string(m) m(string)
#define TF_CALL_int64(m) m(::tensorflow::int64)
#define TF_CALL_half(m) m(Eigen::half)
#define TF_CALL_qint8(m) m(::tensorflow::qint8)
#define TF_CALL_quint8(m) m(::tensorflow::quint8)
#define TF_CALL_qint16(m) m(::tensorflow::qint16)
#define TF_CALL_quint16(m) m(::tensorflow::quint16)
#define TF_CALL_qint32(m) m(::tensorflow::qint32)
#else  // !defined(__ANDROID_TYPES_FULL__)
#define TF_CALL_string(m)
#define TF_CALL_int64(m)
#define TF_CALL_half(m)
#define TF_CALL_qint8(m)
#define TF_CALL_quint8(m)
#define TF_CALL_qint16(m)
#define TF_CALL_quint16(m)
#define TF_CALL_qint32(m)
#endif  // defined(__ANDROID_TYPES_FULL__)

// Never enabled on mobile builds.
#define TF_CALL_double(m)
#define TF_CALL_uint32(m)
#define TF_CALL_uint8(m)
#define TF_CALL_int16(m)
#define TF_CALL_int8(m)
#define TF_CALL_resource(m)
#define TF_CALL_variant(m)
#define TF_CALL_complex64(m)
#define TF_CALL_uint64(m)
#define TF_CALL_bfloat16(m)
#define TF_CALL_uint16(m)
#define TF_CALL_complex128(m)

#endif  // end of TF_CALL_<type> defines

// Defines for sets of types.

// TODO(b/111604096): Add uint32 and uint64 to TF_CALL_INTEGRAL_TYPES.
//
// The uint32 and uint64 types were introduced in 10/2017 to be used via XLA and
// thus were not included in TF_CALL_INTEGRAL_TYPES. Including them in
// TF_CALL_INTEGRAL_TYPES should only happen after evaluating the effect on the
// TF binary size and performance.
//
// Each set macro below expands to a sequence of TF_CALL_<type>(m) invocations,
// so a type whose per-type macro is empty on the current platform is skipped.
// "m" is applied in the textual order of those invocations.

// Integral types: int64, int32, uint16, int16, uint8, int8.
#define TF_CALL_INTEGRAL_TYPES(m) \
  TF_CALL_int64(m)                \
  TF_CALL_int32(m)                \
  TF_CALL_uint16(m)               \
  TF_CALL_int16(m)                \
  TF_CALL_uint8(m)                \
  TF_CALL_int8(m)

// Floating-point types: half, bfloat16, float, double.
#define TF_CALL_FLOAT_TYPES(m) \
  TF_CALL_half(m)              \
  TF_CALL_bfloat16(m)          \
  TF_CALL_float(m)             \
  TF_CALL_double(m)

// Real number types: the integral types, then the floating-point types.
#define TF_CALL_REAL_NUMBER_TYPES(m) \
  TF_CALL_INTEGRAL_TYPES(m)          \
  TF_CALL_FLOAT_TYPES(m)

// Same as TF_CALL_REAL_NUMBER_TYPES, but without bfloat16.
#define TF_CALL_REAL_NUMBER_TYPES_NO_BFLOAT16(m) \
  TF_CALL_INTEGRAL_TYPES(m)                      \
  TF_CALL_half(m)                                \
  TF_CALL_float(m)                               \
  TF_CALL_double(m)

// Real number types without int32; here the floating-point types come first.
#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \
  TF_CALL_FLOAT_TYPES(m)                      \
  TF_CALL_int64(m)                            \
  TF_CALL_uint16(m)                           \
  TF_CALL_int16(m)                            \
  TF_CALL_uint8(m)                            \
  TF_CALL_int8(m)

// All number types: the real number types plus complex64 and complex128.
#define TF_CALL_NUMBER_TYPES(m) \
  TF_CALL_REAL_NUMBER_TYPES(m)  \
  TF_CALL_complex64(m)          \
  TF_CALL_complex128(m)

// Same as TF_CALL_NUMBER_TYPES, but without int32.
#define TF_CALL_NUMBER_TYPES_NO_INT32(m) \
  TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m)  \
  TF_CALL_complex64(m)                   \
  TF_CALL_complex128(m)

// POD types: the number types plus bool.
#define TF_CALL_POD_TYPES(m) \
  TF_CALL_NUMBER_TYPES(m)    \
  TF_CALL_bool(m)

// Every type: the POD types plus string, resource and variant.
#define TF_CALL_ALL_TYPES(m) \
  TF_CALL_POD_TYPES(m)       \
  TF_CALL_string(m)          \
  TF_CALL_resource(m)        \
  TF_CALL_variant(m)

// The POD types plus string.
#define TF_CALL_POD_STRING_TYPES(m) \
  TF_CALL_POD_TYPES(m)              \
  TF_CALL_string(m)

// Number types supported on GPU: half, float, double.
#define TF_CALL_GPU_NUMBER_TYPES(m) \
  TF_CALL_half(m)                   \
  TF_CALL_float(m)                  \
  TF_CALL_double(m)

// Call "m" on all types supported on GPU.
201 #define TF_CALL_GPU_ALL_TYPES(m) \ 202 TF_CALL_GPU_NUMBER_TYPES(m) \ 203 TF_CALL_bool(m) TF_CALL_complex64(m) TF_CALL_complex128(m) 204 205 #define TF_CALL_GPU_NUMBER_TYPES_NO_HALF(m) TF_CALL_float(m) TF_CALL_double(m) 206 207 // Call "m" on all quantized types. 208 // TODO(cwhipkey): include TF_CALL_qint16(m) TF_CALL_quint16(m) 209 #define TF_CALL_QUANTIZED_TYPES(m) \ 210 TF_CALL_qint8(m) TF_CALL_quint8(m) TF_CALL_qint32(m) 211 212 // Types used for save and restore ops. 213 #define TF_CALL_SAVE_RESTORE_TYPES(m) \ 214 TF_CALL_INTEGRAL_TYPES(m) \ 215 TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m) TF_CALL_complex64(m) \ 216 TF_CALL_complex128(m) TF_CALL_bool(m) TF_CALL_string(m) \ 217 TF_CALL_QUANTIZED_TYPES(m) 218 219 #ifdef TENSORFLOW_SYCL_NO_DOUBLE 220 #define TF_CALL_SYCL_double(m) 221 #else // TENSORFLOW_SYCL_NO_DOUBLE 222 #define TF_CALL_SYCL_double(m) TF_CALL_double(m) 223 #endif // TENSORFLOW_SYCL_NO_DOUBLE 224 225 #ifdef __ANDROID_TYPES_SLIM__ 226 #define TF_CALL_SYCL_NUMBER_TYPES(m) TF_CALL_float(m) 227 #else // __ANDROID_TYPES_SLIM__ 228 #define TF_CALL_SYCL_NUMBER_TYPES(m) TF_CALL_float(m) TF_CALL_SYCL_double(m) 229 #endif // __ANDROID_TYPES_SLIM__ 230 231 #endif // TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_H_ 232