/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_
#define TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_
// This file is used by cuda code and must remain compilable by nvcc.

#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/platform/types.h"

// Two sets of macros:
// - TF_CALL_float, TF_CALL_double, etc. which call the given macro with
//   the type name as the only parameter - except on platforms for which
//   the type should not be included, where they expand to nothing.
// - Macros to apply another macro to lists of supported types. These also call
//   into TF_CALL_float, TF_CALL_double, etc. so they filter by target platform
//   as well.
// If you change the lists of types, please also update the list in types.cc.
//
// See example uses of these macros in core/ops.
//
//
// Each of these TF_CALL_XXX_TYPES(m) macros invokes the macro "m" multiple
// times by passing each invocation a data type supported by TensorFlow.
//
// The different variations pass different subsets of the types.
// TF_CALL_ALL_TYPES(m) applies "m" to all types supported by TensorFlow.
// The set of types depends on the compilation platform.
//
// This can be used to register a different template instantiation of
// an OpKernel for different signatures, e.g.:
/*
#define REGISTER_PARTITION(type)                                      \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("Partition").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      PartitionOp<type>);
TF_CALL_ALL_TYPES(REGISTER_PARTITION)
#undef REGISTER_PARTITION
*/

#if !defined(IS_MOBILE_PLATFORM) || defined(SUPPORT_SELECTIVE_REGISTRATION) || \
    defined(ANDROID_TEGRA)

// All types are supported, so all macros are invoked.
//
// Note: macros are defined in roughly the same order as types in
// types.proto, for readability.
//
// Every type name is fully qualified (e.g. ::tensorflow::string rather than a
// bare `string`) so the expansion names the intended type regardless of the
// namespace the macro is invoked from. A bare `string` only resolves inside
// namespace tensorflow (or wherever some other `string` happens to be
// visible).
#define TF_CALL_float(m) m(float)
#define TF_CALL_double(m) m(double)
#define TF_CALL_int32(m) m(::tensorflow::int32)
#define TF_CALL_uint32(m) m(::tensorflow::uint32)
#define TF_CALL_uint8(m) m(::tensorflow::uint8)
#define TF_CALL_int16(m) m(::tensorflow::int16)

#define TF_CALL_int8(m) m(::tensorflow::int8)
#define TF_CALL_string(m) m(::tensorflow::string)
#define TF_CALL_resource(m) m(::tensorflow::ResourceHandle)
#define TF_CALL_variant(m) m(::tensorflow::Variant)
#define TF_CALL_complex64(m) m(::tensorflow::complex64)
#define TF_CALL_int64(m) m(::tensorflow::int64)
#define TF_CALL_uint64(m) m(::tensorflow::uint64)
#define TF_CALL_bool(m) m(bool)

#define TF_CALL_qint8(m) m(::tensorflow::qint8)
#define TF_CALL_quint8(m) m(::tensorflow::quint8)
#define TF_CALL_qint32(m) m(::tensorflow::qint32)
#define TF_CALL_bfloat16(m) m(::tensorflow::bfloat16)
#define TF_CALL_qint16(m) m(::tensorflow::qint16)

#define TF_CALL_quint16(m) m(::tensorflow::quint16)
#define TF_CALL_uint16(m) m(::tensorflow::uint16)
#define TF_CALL_complex128(m) m(::tensorflow::complex128)
#define TF_CALL_half(m) m(::Eigen::half)

#elif defined(__ANDROID_TYPES_FULL__)

// Reached only when IS_MOBILE_PLATFORM is defined and neither
// SUPPORT_SELECTIVE_REGISTRATION nor ANDROID_TEGRA is.
// Only string, half, float, int32, int64, bool, and quantized types
// supported; every other TF_CALL_<type> expands to nothing.
#define TF_CALL_float(m) m(float)
#define TF_CALL_double(m)
#define TF_CALL_int32(m) m(::tensorflow::int32)
#define TF_CALL_uint32(m)
#define TF_CALL_uint8(m)
#define TF_CALL_int16(m)

#define TF_CALL_int8(m)
#define TF_CALL_string(m) m(::tensorflow::string)
#define TF_CALL_resource(m)
#define TF_CALL_variant(m)
#define TF_CALL_complex64(m)
#define TF_CALL_int64(m) m(::tensorflow::int64)
#define TF_CALL_uint64(m)
#define TF_CALL_bool(m) m(bool)

#define TF_CALL_qint8(m) m(::tensorflow::qint8)
#define TF_CALL_quint8(m) m(::tensorflow::quint8)
#define TF_CALL_qint32(m) m(::tensorflow::qint32)
#define TF_CALL_bfloat16(m)
#define TF_CALL_qint16(m) m(::tensorflow::qint16)

#define TF_CALL_quint16(m) m(::tensorflow::quint16)
#define TF_CALL_uint16(m)
#define TF_CALL_complex128(m)
#define TF_CALL_half(m) m(::Eigen::half)

#else  // defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION)
       // && !defined(ANDROID_TEGRA) && !defined(__ANDROID_TYPES_FULL__)

// Only float, int32, and bool are supported; every other TF_CALL_<type>
// expands to nothing.
#define TF_CALL_float(m) m(float)
#define TF_CALL_double(m)
#define TF_CALL_int32(m) m(::tensorflow::int32)
#define TF_CALL_uint32(m)
#define TF_CALL_uint8(m)
#define TF_CALL_int16(m)

#define TF_CALL_int8(m)
#define TF_CALL_string(m)
#define TF_CALL_resource(m)
#define TF_CALL_variant(m)
#define TF_CALL_complex64(m)
#define TF_CALL_int64(m)
#define TF_CALL_uint64(m)
#define TF_CALL_bool(m) m(bool)

#define TF_CALL_qint8(m)
#define TF_CALL_quint8(m)
#define TF_CALL_qint32(m)
#define TF_CALL_bfloat16(m)
#define TF_CALL_qint16(m)

#define TF_CALL_quint16(m)
#define TF_CALL_uint16(m)
#define TF_CALL_complex128(m)
#define TF_CALL_half(m)

#endif  // defined(IS_MOBILE_PLATFORM) - end of TF_CALL_type defines

// Defines for sets of types.

// Integer types: int64, int32, uint16, int16, uint8, int8.
// uint32 and uint64 are not part of this list.
#define TF_CALL_INTEGRAL_TYPES(m) \
  TF_CALL_int64(m)                \
  TF_CALL_int32(m)                \
  TF_CALL_uint16(m)               \
  TF_CALL_int16(m)                \
  TF_CALL_uint8(m)                \
  TF_CALL_int8(m)

// Every integer type, then half, bfloat16, float and double.
#define TF_CALL_REAL_NUMBER_TYPES(m) \
  TF_CALL_INTEGRAL_TYPES(m)          \
  TF_CALL_half(m)                    \
  TF_CALL_bfloat16(m)                \
  TF_CALL_float(m)                   \
  TF_CALL_double(m)

// Same as TF_CALL_REAL_NUMBER_TYPES without bfloat16.
#define TF_CALL_REAL_NUMBER_TYPES_NO_BFLOAT16(m) \
  TF_CALL_INTEGRAL_TYPES(m)                      \
  TF_CALL_half(m)                                \
  TF_CALL_float(m)                               \
  TF_CALL_double(m)

// Same set as TF_CALL_REAL_NUMBER_TYPES without int32; note that here the
// floating-point types come first and the integer types follow.
#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \
  TF_CALL_half(m)                             \
  TF_CALL_bfloat16(m)                         \
  TF_CALL_float(m)                            \
  TF_CALL_double(m)                           \
  TF_CALL_int64(m)                            \
  TF_CALL_uint16(m)                           \
  TF_CALL_int16(m)                            \
  TF_CALL_uint8(m)                            \
  TF_CALL_int8(m)

// Call "m" for all number types: the real number types followed by
// complex64 and complex128.
#define TF_CALL_NUMBER_TYPES(m) \
  TF_CALL_REAL_NUMBER_TYPES(m)  \
  TF_CALL_complex64(m)          \
  TF_CALL_complex128(m)

// Same as TF_CALL_NUMBER_TYPES without int32.
#define TF_CALL_NUMBER_TYPES_NO_INT32(m) \
  TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m)  \
  TF_CALL_complex64(m)                   \
  TF_CALL_complex128(m)

// POD types: every number type followed by bool.
#define TF_CALL_POD_TYPES(m) \
  TF_CALL_NUMBER_TYPES(m)    \
  TF_CALL_bool(m)

// Call "m" on all types: POD types, then string, resource and variant.
#define TF_CALL_ALL_TYPES(m) \
  TF_CALL_POD_TYPES(m)       \
  TF_CALL_string(m)          \
  TF_CALL_resource(m)        \
  TF_CALL_variant(m)

// Call "m" on POD types followed by string.
#define TF_CALL_POD_STRING_TYPES(m) \
  TF_CALL_POD_TYPES(m)              \
  TF_CALL_string(m)

// Call "m" on all number types supported on GPU: half, float, double.
#define TF_CALL_GPU_NUMBER_TYPES(m) \
  TF_CALL_half(m)                   \
  TF_CALL_float(m)                  \
  TF_CALL_double(m)

// Call "m" on all types supported on GPU: the GPU number types, then bool,
// complex64 and complex128.
#define TF_CALL_GPU_ALL_TYPES(m) \
  TF_CALL_GPU_NUMBER_TYPES(m)    \
  TF_CALL_bool(m)                \
  TF_CALL_complex64(m)           \
  TF_CALL_complex128(m)

// GPU number types without half: float, double.
#define TF_CALL_GPU_NUMBER_TYPES_NO_HALF(m) \
  TF_CALL_float(m)                          \
  TF_CALL_double(m)

// Call "m" on all quantized types.
199 // TODO(cwhipkey): include TF_CALL_qint16(m) TF_CALL_quint16(m) 200 #define TF_CALL_QUANTIZED_TYPES(m) \ 201 TF_CALL_qint8(m) TF_CALL_quint8(m) TF_CALL_qint32(m) 202 203 // Types used for save and restore ops. 204 #define TF_CALL_SAVE_RESTORE_TYPES(m) \ 205 TF_CALL_INTEGRAL_TYPES(m) \ 206 TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m) TF_CALL_complex64(m) \ 207 TF_CALL_complex128(m) TF_CALL_bool(m) TF_CALL_string(m) \ 208 TF_CALL_QUANTIZED_TYPES(m) 209 210 #ifdef TENSORFLOW_SYCL_NO_DOUBLE 211 #define TF_CALL_SYCL_double(m) 212 #else // TENSORFLOW_SYCL_NO_DOUBLE 213 #define TF_CALL_SYCL_double(m) TF_CALL_double(m) 214 #endif // TENSORFLOW_SYCL_NO_DOUBLE 215 216 #ifdef __ANDROID_TYPES_SLIM__ 217 #define TF_CALL_SYCL_NUMBER_TYPES(m) TF_CALL_float(m) 218 #else // __ANDROID_TYPES_SLIM__ 219 #define TF_CALL_SYCL_NUMBER_TYPES(m) TF_CALL_float(m) TF_CALL_SYCL_double(m) 220 #endif // __ANDROID_TYPES_SLIM__ 221 222 #endif // TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_ 223