Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_H_
     17 #define TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_H_
     18 // This file is used by cuda code and must remain compilable by nvcc.
     19 
     20 #include "tensorflow/core/framework/numeric_types.h"
     21 #include "tensorflow/core/framework/resource_handle.h"
     22 #include "tensorflow/core/framework/variant.h"
     23 #include "tensorflow/core/platform/types.h"
     24 
// Two sets of macros:
// - TF_CALL_float, TF_CALL_double, etc., each of which calls the given macro
//   with the type name as its only argument, except on platforms for which
//   the type should not be included (there the macro expands to nothing).
// - Macros that apply another macro to lists of supported types. These are
//   built from TF_CALL_float, TF_CALL_double, etc., so they filter by target
//   platform as well.
     32 // If you change the lists of types, please also update the list in types.cc.
     33 //
     34 // See example uses of these macros in core/ops.
     35 //
// Each of these TF_CALL_XXX_TYPES(m) macros invokes the macro "m" multiple
// times, passing each invocation one data type supported by TensorFlow.
//
// The different variations pass different subsets of the types.
// TF_CALL_ALL_TYPES(m) applies "m" to all types supported by TensorFlow.
// The set of types depends on the compilation platform.
//
     44 // This can be used to register a different template instantiation of
     45 // an OpKernel for different signatures, e.g.:
     46 /*
     47    #define REGISTER_PARTITION(type)                                      \
     48      REGISTER_KERNEL_BUILDER(                                            \
     49          Name("Partition").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
     50          PartitionOp<type>);
     51    TF_CALL_ALL_TYPES(REGISTER_PARTITION)
     52    #undef REGISTER_PARTITION
     53 */
     54 
     55 #if !defined(IS_MOBILE_PLATFORM) || defined(SUPPORT_SELECTIVE_REGISTRATION) || \
     56     defined(ANDROID_TEGRA)
     57 
     58 // All types are supported, so all macros are invoked.
     59 //
     60 // Note: macros are defined in same order as types in types.proto, for
     61 // readability.
     62 #define TF_CALL_float(m) m(float)
     63 #define TF_CALL_double(m) m(double)
     64 #define TF_CALL_int32(m) m(::tensorflow::int32)
     65 #define TF_CALL_uint32(m) m(::tensorflow::uint32)
     66 #define TF_CALL_uint8(m) m(::tensorflow::uint8)
     67 #define TF_CALL_int16(m) m(::tensorflow::int16)
     68 
     69 #define TF_CALL_int8(m) m(::tensorflow::int8)
     70 #define TF_CALL_string(m) m(string)
     71 #define TF_CALL_resource(m) m(::tensorflow::ResourceHandle)
     72 #define TF_CALL_variant(m) m(::tensorflow::Variant)
     73 #define TF_CALL_complex64(m) m(::tensorflow::complex64)
     74 #define TF_CALL_int64(m) m(::tensorflow::int64)
     75 #define TF_CALL_uint64(m) m(::tensorflow::uint64)
     76 #define TF_CALL_bool(m) m(bool)
     77 
     78 #define TF_CALL_qint8(m) m(::tensorflow::qint8)
     79 #define TF_CALL_quint8(m) m(::tensorflow::quint8)
     80 #define TF_CALL_qint32(m) m(::tensorflow::qint32)
     81 #define TF_CALL_bfloat16(m) m(::tensorflow::bfloat16)
     82 #define TF_CALL_qint16(m) m(::tensorflow::qint16)
     83 
     84 #define TF_CALL_quint16(m) m(::tensorflow::quint16)
     85 #define TF_CALL_uint16(m) m(::tensorflow::uint16)
     86 #define TF_CALL_complex128(m) m(::tensorflow::complex128)
     87 #define TF_CALL_half(m) m(Eigen::half)
     88 
     89 #elif defined(__ANDROID_TYPES_FULL__)
     90 
     91 // Only string, half, float, int32, int64, bool, and quantized types
     92 // supported.
     93 #define TF_CALL_float(m) m(float)
     94 #define TF_CALL_double(m)
     95 #define TF_CALL_int32(m) m(::tensorflow::int32)
     96 #define TF_CALL_uint32(m)
     97 #define TF_CALL_uint8(m)
     98 #define TF_CALL_int16(m)
     99 
    100 #define TF_CALL_int8(m)
    101 #define TF_CALL_string(m) m(string)
    102 #define TF_CALL_resource(m)
    103 #define TF_CALL_variant(m)
    104 #define TF_CALL_complex64(m)
    105 #define TF_CALL_int64(m) m(::tensorflow::int64)
    106 #define TF_CALL_uint64(m)
    107 #define TF_CALL_bool(m) m(bool)
    108 
    109 #define TF_CALL_qint8(m) m(::tensorflow::qint8)
    110 #define TF_CALL_quint8(m) m(::tensorflow::quint8)
    111 #define TF_CALL_qint32(m) m(::tensorflow::qint32)
    112 #define TF_CALL_bfloat16(m)
    113 #define TF_CALL_qint16(m) m(::tensorflow::qint16)
    114 
    115 #define TF_CALL_quint16(m) m(::tensorflow::quint16)
    116 #define TF_CALL_uint16(m)
    117 #define TF_CALL_complex128(m)
    118 #define TF_CALL_half(m) m(Eigen::half)
    119 
    120 #else  // defined(IS_MOBILE_PLATFORM) && !defined(__ANDROID_TYPES_FULL__)
    121 
    122 // Only float, int32, and bool are supported.
    123 #define TF_CALL_float(m) m(float)
    124 #define TF_CALL_double(m)
    125 #define TF_CALL_int32(m) m(::tensorflow::int32)
    126 #define TF_CALL_uint32(m)
    127 #define TF_CALL_uint8(m)
    128 #define TF_CALL_int16(m)
    129 
    130 #define TF_CALL_int8(m)
    131 #define TF_CALL_string(m)
    132 #define TF_CALL_resource(m)
    133 #define TF_CALL_variant(m)
    134 #define TF_CALL_complex64(m)
    135 #define TF_CALL_int64(m)
    136 #define TF_CALL_uint64(m)
    137 #define TF_CALL_bool(m) m(bool)
    138 
    139 #define TF_CALL_qint8(m)
    140 #define TF_CALL_quint8(m)
    141 #define TF_CALL_qint32(m)
    142 #define TF_CALL_bfloat16(m)
    143 #define TF_CALL_qint16(m)
    144 
    145 #define TF_CALL_quint16(m)
    146 #define TF_CALL_uint16(m)
    147 #define TF_CALL_complex128(m)
    148 #define TF_CALL_half(m)
    149 
    150 #endif  // defined(IS_MOBILE_PLATFORM)  - end of TF_CALL_type defines
    151 
    152 // Defines for sets of types.
    153 
    154 // TODO(b/111604096): Add uint32 and uint64 to TF_CALL_INTEGRAL_TYPES.
    155 //
    156 // The uint32 and uint64 types were introduced in 10/2017 to be used via XLA and
    157 // thus were not included in TF_CALL_INTEGRAL_TYPES. Including them in
    158 // TF_CALL_INTEGRAL_TYPES should only happen after evaluating the effect on the
    159 // TF binary size and performance.
    160 #define TF_CALL_INTEGRAL_TYPES(m)                                      \
    161   TF_CALL_int64(m) TF_CALL_int32(m) TF_CALL_uint16(m) TF_CALL_int16(m) \
    162       TF_CALL_uint8(m) TF_CALL_int8(m)
    163 
    164 #define TF_CALL_FLOAT_TYPES(m) \
    165   TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m)
    166 
    167 #define TF_CALL_REAL_NUMBER_TYPES(m) \
    168   TF_CALL_INTEGRAL_TYPES(m)          \
    169   TF_CALL_FLOAT_TYPES(m)
    170 
    171 #define TF_CALL_REAL_NUMBER_TYPES_NO_BFLOAT16(m) \
    172   TF_CALL_INTEGRAL_TYPES(m) TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m)
    173 
    174 #define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m)                              \
    175   TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m)   \
    176       TF_CALL_int64(m) TF_CALL_uint16(m) TF_CALL_int16(m) TF_CALL_uint8(m) \
    177           TF_CALL_int8(m)
    178 
    179 // Call "m" for all number types, including complex64 and complex128.
    180 #define TF_CALL_NUMBER_TYPES(m) \
    181   TF_CALL_REAL_NUMBER_TYPES(m) TF_CALL_complex64(m) TF_CALL_complex128(m)
    182 
    183 #define TF_CALL_NUMBER_TYPES_NO_INT32(m) \
    184   TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m)  \
    185   TF_CALL_complex64(m) TF_CALL_complex128(m)
    186 
    187 #define TF_CALL_POD_TYPES(m) TF_CALL_NUMBER_TYPES(m) TF_CALL_bool(m)
    188 
    189 // Call "m" on all types.
    190 #define TF_CALL_ALL_TYPES(m) \
    191   TF_CALL_POD_TYPES(m) TF_CALL_string(m) TF_CALL_resource(m) TF_CALL_variant(m)
    192 
    193 // Call "m" on POD and string types.
    194 #define TF_CALL_POD_STRING_TYPES(m) TF_CALL_POD_TYPES(m) TF_CALL_string(m)
    195 
    196 // Call "m" on all number types supported on GPU.
    197 #define TF_CALL_GPU_NUMBER_TYPES(m) \
    198   TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m)
    199 
    200 // Call "m" on all types supported on GPU.
    201 #define TF_CALL_GPU_ALL_TYPES(m) \
    202   TF_CALL_GPU_NUMBER_TYPES(m)    \
    203   TF_CALL_bool(m) TF_CALL_complex64(m) TF_CALL_complex128(m)
    204 
    205 #define TF_CALL_GPU_NUMBER_TYPES_NO_HALF(m) TF_CALL_float(m) TF_CALL_double(m)
    206 
    207 // Call "m" on all quantized types.
    208 // TODO(cwhipkey): include TF_CALL_qint16(m) TF_CALL_quint16(m)
    209 #define TF_CALL_QUANTIZED_TYPES(m) \
    210   TF_CALL_qint8(m) TF_CALL_quint8(m) TF_CALL_qint32(m)
    211 
    212 // Types used for save and restore ops.
    213 #define TF_CALL_SAVE_RESTORE_TYPES(m)                                     \
    214   TF_CALL_INTEGRAL_TYPES(m)                                               \
    215   TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m) TF_CALL_complex64(m) \
    216       TF_CALL_complex128(m) TF_CALL_bool(m) TF_CALL_string(m)             \
    217           TF_CALL_QUANTIZED_TYPES(m)
    218 
    219 #ifdef TENSORFLOW_SYCL_NO_DOUBLE
    220 #define TF_CALL_SYCL_double(m)
    221 #else  // TENSORFLOW_SYCL_NO_DOUBLE
    222 #define TF_CALL_SYCL_double(m) TF_CALL_double(m)
    223 #endif  // TENSORFLOW_SYCL_NO_DOUBLE
    224 
    225 #ifdef __ANDROID_TYPES_SLIM__
    226 #define TF_CALL_SYCL_NUMBER_TYPES(m) TF_CALL_float(m)
    227 #else  // __ANDROID_TYPES_SLIM__
    228 #define TF_CALL_SYCL_NUMBER_TYPES(m) TF_CALL_float(m) TF_CALL_SYCL_double(m)
    229 #endif  // __ANDROID_TYPES_SLIM__
    230 
    231 #endif  // TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_H_
    232