Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_FRAMEWORK_REGISTER_TYPES_TRAITS_H_
     17 #define TENSORFLOW_FRAMEWORK_REGISTER_TYPES_TRAITS_H_
     18 // This file is used by cuda code and must remain compilable by nvcc.
     19 
#include <type_traits>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     21 typedef Eigen::ThreadPoolDevice CPUDevice;
     22 typedef Eigen::GpuDevice GPUDevice;
     23 
     24 #ifdef TENSORFLOW_USE_SYCL
     25 typedef Eigen::SyclDevice SYCLDevice;
     26 #endif  // TENSORFLOW_USE_SYCL
     27 
     28 #include "tensorflow/core/framework/numeric_types.h"
     29 #include "tensorflow/core/platform/types.h"
     30 
     31 namespace tensorflow {
     32 
     33 // Remap POD types by size to equivalent proxy types. This works
     34 // since all we are doing is copying data around.
     35 struct UnusableProxyType;
     36 template <typename Device, int size>
     37 struct proxy_type_pod {
     38   typedef UnusableProxyType type;
     39 };
     40 template <>
     41 struct proxy_type_pod<CPUDevice, 16> {
     42   typedef ::tensorflow::complex128 type;
     43 };
     44 template <>
     45 struct proxy_type_pod<CPUDevice, 8> {
     46   typedef ::tensorflow::int64 type;
     47 };
     48 template <>
     49 struct proxy_type_pod<CPUDevice, 4> {
     50   typedef ::tensorflow::int32 type;
     51 };
     52 template <>
     53 struct proxy_type_pod<CPUDevice, 2> {
     54   typedef ::tensorflow::int16 type;
     55 };
     56 template <>
     57 struct proxy_type_pod<CPUDevice, 1> {
     58   typedef ::tensorflow::int8 type;
     59 };
     60 template <>
     61 struct proxy_type_pod<GPUDevice, 8> {
     62   typedef double type;
     63 };
     64 template <>
     65 struct proxy_type_pod<GPUDevice, 4> {
     66   typedef float type;
     67 };
     68 template <>
     69 struct proxy_type_pod<GPUDevice, 2> {
     70   typedef Eigen::half type;
     71 };
     72 
     73 #ifdef TENSORFLOW_USE_SYCL
     74 template <>
     75 struct proxy_type_pod<SYCLDevice, 8> {
     76   typedef double type;
     77 };
     78 template <>
     79 struct proxy_type_pod<SYCLDevice, 4> {
     80   typedef float type;
     81 };
     82 #endif  // TENSORFLOW_USE_SYCL
     83 
     84 /// If POD we use proxy_type_pod, otherwise this maps to identiy.
     85 template <typename Device, typename T>
     86 struct proxy_type {
     87   typedef typename std::conditional<
     88       std::is_arithmetic<T>::value,
     89       typename proxy_type_pod<Device, sizeof(T)>::type, T>::type type;
     90   static_assert(sizeof(type) == sizeof(T), "proxy_type_pod is not valid");
     91 };
     92 
/// The active proxy types, per device. Each macro takes a macro `m` and
/// expands TF_CALL_<type>(m) for every listed element type, in the order
/// written (presumably to instantiate/register kernels that use proxy_type).
/// The TF_CALL_<type> helpers are not defined in this file; presumably they
/// come from register_types.h — confirm at the use site.
///
/// NOTE(review): these lists are not exactly the set of proxy_type_pod targets.
/// uint16 is listed for CPU although proxy_type_pod<CPUDevice, 2> maps to
/// int16, and int32 is listed for GPU although 4-byte GPU types proxy to float.
/// Presumably intentional; confirm against the kernels that use these macros.
#define TF_CALL_CPU_PROXY_TYPES(m)                                     \
  TF_CALL_int64(m) TF_CALL_int32(m) TF_CALL_uint16(m) TF_CALL_int16(m) \
      TF_CALL_int8(m) TF_CALL_complex128(m)
// GPU list: double, float, half and int32.
#define TF_CALL_GPU_PROXY_TYPES(m) \
  TF_CALL_double(m) TF_CALL_float(m) TF_CALL_half(m) TF_CALL_int32(m)
#ifdef TENSORFLOW_USE_SYCL
// SYCL list (SYCL builds only): double, float and int32.
#define TF_CALL_SYCL_PROXY_TYPES(m) \
  TF_CALL_double(m) TF_CALL_float(m) TF_CALL_int32(m)
#endif  // TENSORFLOW_USE_SYCL
    103 }  // namespace tensorflow
    104 
    105 #endif  // TENSORFLOW_FRAMEWORK_REGISTER_TYPES_TRAITS_H_
    106