Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_
     17 #define TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_
     18 // This file is used by cuda code and must remain compilable by nvcc.
     19 
     20 #include "tensorflow/core/framework/numeric_types.h"
     21 #include "tensorflow/core/framework/resource_handle.h"
     22 #include "tensorflow/core/framework/variant.h"
     23 #include "tensorflow/core/platform/types.h"
     24 
     25 // Two sets of macros:
     26 // - TF_CALL_float, TF_CALL_double, etc. which call the given macro with
     27 //   the type name as the only parameter - except on platforms for which
     28 //   the type should not be included.
     29 // - Macros to apply another macro to lists of supported types. These also call
     30 //   into TF_CALL_float, TF_CALL_double, etc. so they filter by target platform
     31 //   as well.
     32 // If you change the lists of types, please also update the list in types.cc.
     33 //
     34 // See example uses of these macros in core/ops.
     35 //
     36 //
     37 // Each of these TF_CALL_XXX_TYPES(m) macros invokes the macro "m" multiple
     38 // times by passing each invocation a data type supported by TensorFlow.
     39 //
     40 // The different variations pass different subsets of the types.
// TF_CALL_ALL_TYPES(m) applies "m" to all types supported by TensorFlow.
// The set of types depends on the compilation platform.
//
     44 // This can be used to register a different template instantiation of
     45 // an OpKernel for different signatures, e.g.:
     46 /*
     47    #define REGISTER_PARTITION(type)                                      \
     48      REGISTER_KERNEL_BUILDER(                                            \
     49          Name("Partition").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
     50          PartitionOp<type>);
     51    TF_CALL_ALL_TYPES(REGISTER_PARTITION)
     52    #undef REGISTER_PARTITION
     53 */
     54 
     55 #if !defined(IS_MOBILE_PLATFORM) || defined(SUPPORT_SELECTIVE_REGISTRATION) || \
     56     defined(ANDROID_TEGRA)
     57 
     58 // All types are supported, so all macros are invoked.
     59 //
     60 // Note: macros are defined in same order as types in types.proto, for
     61 // readability.
     62 #define TF_CALL_float(m) m(float)
     63 #define TF_CALL_double(m) m(double)
     64 #define TF_CALL_int32(m) m(::tensorflow::int32)
     65 #define TF_CALL_uint32(m) m(::tensorflow::uint32)
     66 #define TF_CALL_uint8(m) m(::tensorflow::uint8)
     67 #define TF_CALL_int16(m) m(::tensorflow::int16)
     68 
     69 #define TF_CALL_int8(m) m(::tensorflow::int8)
     70 #define TF_CALL_string(m) m(string)
     71 #define TF_CALL_resource(m) m(::tensorflow::ResourceHandle)
     72 #define TF_CALL_variant(m) m(::tensorflow::Variant)
     73 #define TF_CALL_complex64(m) m(::tensorflow::complex64)
     74 #define TF_CALL_int64(m) m(::tensorflow::int64)
     75 #define TF_CALL_uint64(m) m(::tensorflow::uint64)
     76 #define TF_CALL_bool(m) m(bool)
     77 
     78 #define TF_CALL_qint8(m) m(::tensorflow::qint8)
     79 #define TF_CALL_quint8(m) m(::tensorflow::quint8)
     80 #define TF_CALL_qint32(m) m(::tensorflow::qint32)
     81 #define TF_CALL_bfloat16(m) m(::tensorflow::bfloat16)
     82 #define TF_CALL_qint16(m) m(::tensorflow::qint16)
     83 
     84 #define TF_CALL_quint16(m) m(::tensorflow::quint16)
     85 #define TF_CALL_uint16(m) m(::tensorflow::uint16)
     86 #define TF_CALL_complex128(m) m(::tensorflow::complex128)
     87 #define TF_CALL_half(m) m(Eigen::half)
     88 
     89 #elif defined(__ANDROID_TYPES_FULL__)
     90 
     91 // Only string, half, float, int32, int64, bool, and quantized types
     92 // supported.
     93 #define TF_CALL_float(m) m(float)
     94 #define TF_CALL_double(m)
     95 #define TF_CALL_int32(m) m(::tensorflow::int32)
     96 #define TF_CALL_uint32(m)
     97 #define TF_CALL_uint8(m)
     98 #define TF_CALL_int16(m)
     99 
    100 #define TF_CALL_int8(m)
    101 #define TF_CALL_string(m) m(string)
    102 #define TF_CALL_resource(m)
    103 #define TF_CALL_variant(m)
    104 #define TF_CALL_complex64(m)
    105 #define TF_CALL_int64(m) m(::tensorflow::int64)
    106 #define TF_CALL_uint64(m)
    107 #define TF_CALL_bool(m) m(bool)
    108 
    109 #define TF_CALL_qint8(m) m(::tensorflow::qint8)
    110 #define TF_CALL_quint8(m) m(::tensorflow::quint8)
    111 #define TF_CALL_qint32(m) m(::tensorflow::qint32)
    112 #define TF_CALL_bfloat16(m)
    113 #define TF_CALL_qint16(m) m(::tensorflow::qint16)
    114 
    115 #define TF_CALL_quint16(m) m(::tensorflow::quint16)
    116 #define TF_CALL_uint16(m)
    117 #define TF_CALL_complex128(m)
    118 #define TF_CALL_half(m) m(Eigen::half)
    119 
    120 #else  // defined(IS_MOBILE_PLATFORM) && !defined(__ANDROID_TYPES_FULL__)
    121 
    122 // Only float, int32, and bool are supported.
    123 #define TF_CALL_float(m) m(float)
    124 #define TF_CALL_double(m)
    125 #define TF_CALL_int32(m) m(::tensorflow::int32)
    126 #define TF_CALL_uint32(m)
    127 #define TF_CALL_uint8(m)
    128 #define TF_CALL_int16(m)
    129 
    130 #define TF_CALL_int8(m)
    131 #define TF_CALL_string(m)
    132 #define TF_CALL_resource(m)
    133 #define TF_CALL_variant(m)
    134 #define TF_CALL_complex64(m)
    135 #define TF_CALL_int64(m)
    136 #define TF_CALL_uint64(m)
    137 #define TF_CALL_bool(m) m(bool)
    138 
    139 #define TF_CALL_qint8(m)
    140 #define TF_CALL_quint8(m)
    141 #define TF_CALL_qint32(m)
    142 #define TF_CALL_bfloat16(m)
    143 #define TF_CALL_qint16(m)
    144 
    145 #define TF_CALL_quint16(m)
    146 #define TF_CALL_uint16(m)
    147 #define TF_CALL_complex128(m)
    148 #define TF_CALL_half(m)
    149 
    150 #endif  // defined(IS_MOBILE_PLATFORM)  - end of TF_CALL_type defines
    151 
    152 // Defines for sets of types.
    153 
    154 #define TF_CALL_INTEGRAL_TYPES(m)                                      \
    155   TF_CALL_int64(m) TF_CALL_int32(m) TF_CALL_uint16(m) TF_CALL_int16(m) \
    156       TF_CALL_uint8(m) TF_CALL_int8(m)
    157 
    158 #define TF_CALL_REAL_NUMBER_TYPES(m) \
    159   TF_CALL_INTEGRAL_TYPES(m)          \
    160   TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m)
    161 
    162 #define TF_CALL_REAL_NUMBER_TYPES_NO_BFLOAT16(m) \
    163   TF_CALL_INTEGRAL_TYPES(m) TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m)
    164 
    165 #define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m)                              \
    166   TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m)   \
    167       TF_CALL_int64(m) TF_CALL_uint16(m) TF_CALL_int16(m) TF_CALL_uint8(m) \
    168           TF_CALL_int8(m)
    169 
    170 // Call "m" for all number types, including complex64 and complex128.
    171 #define TF_CALL_NUMBER_TYPES(m) \
    172   TF_CALL_REAL_NUMBER_TYPES(m) TF_CALL_complex64(m) TF_CALL_complex128(m)
    173 
    174 #define TF_CALL_NUMBER_TYPES_NO_INT32(m) \
    175   TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m)  \
    176   TF_CALL_complex64(m) TF_CALL_complex128(m)
    177 
    178 #define TF_CALL_POD_TYPES(m) TF_CALL_NUMBER_TYPES(m) TF_CALL_bool(m)
    179 
    180 // Call "m" on all types.
    181 #define TF_CALL_ALL_TYPES(m) \
    182   TF_CALL_POD_TYPES(m) TF_CALL_string(m) TF_CALL_resource(m) TF_CALL_variant(m)
    183 
    184 // Call "m" on POD and string types.
    185 #define TF_CALL_POD_STRING_TYPES(m) TF_CALL_POD_TYPES(m) TF_CALL_string(m)
    186 
    187 // Call "m" on all number types supported on GPU.
    188 #define TF_CALL_GPU_NUMBER_TYPES(m) \
    189   TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m)
    190 
    191 // Call "m" on all types supported on GPU.
    192 #define TF_CALL_GPU_ALL_TYPES(m) \
    193   TF_CALL_GPU_NUMBER_TYPES(m)    \
    194   TF_CALL_bool(m) TF_CALL_complex64(m) TF_CALL_complex128(m)
    195 
    196 #define TF_CALL_GPU_NUMBER_TYPES_NO_HALF(m) TF_CALL_float(m) TF_CALL_double(m)
    197 
    198 // Call "m" on all quantized types.
    199 // TODO(cwhipkey): include TF_CALL_qint16(m) TF_CALL_quint16(m)
    200 #define TF_CALL_QUANTIZED_TYPES(m) \
    201   TF_CALL_qint8(m) TF_CALL_quint8(m) TF_CALL_qint32(m)
    202 
    203 // Types used for save and restore ops.
    204 #define TF_CALL_SAVE_RESTORE_TYPES(m)                                     \
    205   TF_CALL_INTEGRAL_TYPES(m)                                               \
    206   TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m) TF_CALL_complex64(m) \
    207       TF_CALL_complex128(m) TF_CALL_bool(m) TF_CALL_string(m)             \
    208           TF_CALL_QUANTIZED_TYPES(m)
    209 
    210 #ifdef TENSORFLOW_SYCL_NO_DOUBLE
    211 #define TF_CALL_SYCL_double(m)
    212 #else  // TENSORFLOW_SYCL_NO_DOUBLE
    213 #define TF_CALL_SYCL_double(m) TF_CALL_double(m)
    214 #endif  // TENSORFLOW_SYCL_NO_DOUBLE
    215 
    216 #ifdef __ANDROID_TYPES_SLIM__
    217 #define TF_CALL_SYCL_NUMBER_TYPES(m) TF_CALL_float(m)
    218 #else  // __ANDROID_TYPES_SLIM__
    219 #define TF_CALL_SYCL_NUMBER_TYPES(m) TF_CALL_float(m) TF_CALL_SYCL_double(m)
    220 #endif  // __ANDROID_TYPES_SLIM__
    221 
    222 #endif  // TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_
    223