/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/cast_op.h"

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

#include "tensorflow/core/kernels/cast_op_impl.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

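// Expands FN(arg0, dsttype) once for every destination type supported by the
// Cast op, so a single invocation covers all casts from a given source type.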
#define CURRY_TYPES2(FN, arg0)   \
  FN(arg0, bool);                \
  FN(arg0, uint8);               \
  FN(arg0, int8);                \
  FN(arg0, uint16);              \
  FN(arg0, int16);               \
  FN(arg0, int32);               \
  FN(arg0, int64);               \
  FN(arg0, Eigen::half);         \
  FN(arg0, float);               \
  FN(arg0, double);              \
  FN(arg0, std::complex<float>); \
  FN(arg0, std::complex<double>)

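// Reads the source and destination dtypes from the op's "SrcT" and "DstT"
// attributes; the device-specific subclasses use them in Prepare() to select
// a cast worker.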
CastOpBase::CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &src_dtype_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &dst_dtype_));
}

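// If no cast worker was prepared (SrcT == DstT), the cast is an identity and
// the input tensor is forwarded as-is; otherwise an output tensor of the same
// shape is allocated and filled by the prepared worker.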
void CastOpBase::Compute(OpKernelContext* ctx) {
  const Tensor& inp = ctx->input(0);
  if (work_ == nullptr) {
    ctx->set_output(0, inp);
  } else {
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
    work_(ctx, inp, out);
  }
}

Status CastOpBase::Unimplemented() {
  return errors::Unimplemented("Cast ", DataTypeString(src_dtype_), " to ",
                               DataTypeString(dst_dtype_), " is not supported");
}

CpuCastOp::CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
  OP_REQUIRES_OK(ctx, Prepare());
}

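// Picks the CPU cast worker based on the source dtype. If no branch matches
// the source dtype, or the selected helper returns null for the destination
// dtype, work_ stays null and an Unimplemented error is reported.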
Status CpuCastOp::Prepare() {
  if (src_dtype_ == dst_dtype_) {
    work_ = nullptr;  // Identity
    return Status::OK();
  }
  if (src_dtype_ == DT_BOOL) {
    work_ = GetCpuCastFromBool(dst_dtype_);
  } else if (src_dtype_ == DT_UINT8) {
    work_ = GetCpuCastFromUint8(dst_dtype_);
  } else if (src_dtype_ == DT_INT8) {
    work_ = GetCpuCastFromInt8(dst_dtype_);
  } else if (src_dtype_ == DT_UINT16) {
    work_ = GetCpuCastFromUint16(dst_dtype_);
  } else if (src_dtype_ == DT_INT16) {
    work_ = GetCpuCastFromInt16(dst_dtype_);
  } else if (src_dtype_ == DT_INT32) {
    work_ = GetCpuCastFromInt32(dst_dtype_);
  } else if (src_dtype_ == DT_INT64) {
    work_ = GetCpuCastFromInt64(dst_dtype_);
  } else if (src_dtype_ == DT_HALF) {
    work_ = GetCpuCastFromHalf(dst_dtype_);
  } else if (src_dtype_ == DT_FLOAT) {
    work_ = GetCpuCastFromFloat(dst_dtype_);
  } else if (src_dtype_ == DT_DOUBLE) {
    work_ = GetCpuCastFromDouble(dst_dtype_);
  } else if (src_dtype_ == DT_COMPLEX64) {
    work_ = GetCpuCastFromComplex64(dst_dtype_);
  } else if (src_dtype_ == DT_COMPLEX128) {
    work_ = GetCpuCastFromComplex128(dst_dtype_);
  } else if (src_dtype_ == DT_BFLOAT16) {
    work_ = GetCpuCastFromBfloat(dst_dtype_);
  }

  // TODO(sesse): If CPU casting to or from Eigen::half ever becomes a
  // bottleneck, we could probably implement specialized support for
  // vectorized versions (not the least based on F16C for Haswell
  // or newer).

  return work_ == nullptr ? Unimplemented() : Status::OK();
}

#if GOOGLE_CUDA
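// GPU variant of the cast kernel; mirrors CpuCastOp::Prepare() but dispatches
// to the GPU cast workers.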
class GpuCastOp : public CastOpBase {
 public:
  explicit GpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
    OP_REQUIRES_OK(ctx, Prepare());
  }

 private:
  Status Prepare() {
    if (src_dtype_ == dst_dtype_) {
      work_ = nullptr;  // Identity
      return Status::OK();
    }
    if (src_dtype_ == DT_BOOL) {
      work_ = GetGpuCastFromBool(dst_dtype_);
    } else if (src_dtype_ == DT_UINT8) {
      work_ = GetGpuCastFromUint8(dst_dtype_);
    } else if (src_dtype_ == DT_INT8) {
      work_ = GetGpuCastFromInt8(dst_dtype_);
    } else if (src_dtype_ == DT_UINT16) {
      work_ = GetGpuCastFromUint16(dst_dtype_);
    } else if (src_dtype_ == DT_INT16) {
      work_ = GetGpuCastFromInt16(dst_dtype_);
    } else if (src_dtype_ == DT_INT32) {
      work_ = GetGpuCastFromInt32(dst_dtype_);
    } else if (src_dtype_ == DT_INT64) {
      work_ = GetGpuCastFromInt64(dst_dtype_);
    } else if (src_dtype_ == DT_HALF) {
      work_ = GetGpuCastFromHalf(dst_dtype_);
    } else if (src_dtype_ == DT_FLOAT) {
      work_ = GetGpuCastFromFloat(dst_dtype_);
    } else if (src_dtype_ == DT_DOUBLE) {
      work_ = GetGpuCastFromDouble(dst_dtype_);
    } else if (src_dtype_ == DT_COMPLEX64) {
      work_ = GetGpuCastFromComplex64(dst_dtype_);
    } else if (src_dtype_ == DT_COMPLEX128) {
      work_ = GetGpuCastFromComplex128(dst_dtype_);
    } else if (src_dtype_ == DT_BFLOAT16) {
      work_ = GetGpuCastFromBfloat(dst_dtype_);
    }

    return work_ == nullptr ? Unimplemented() : Status::OK();
  }
};
#endif  // GOOGLE_CUDA

#undef CAST_CASE

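// The CPU kernel dispatches on dtype at runtime in Prepare(), so one
// registration without type constraints covers every supported cast.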
REGISTER_KERNEL_BUILDER(Name("Cast").Device(DEVICE_CPU), CpuCastOp);

#if GOOGLE_CUDA
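// GPU kernels are registered per (SrcT, DstT) pair: each CURRY_TYPES2 line
// below expands to one REGISTER_CAST_GPU per destination type for the given
// source type, with float <-> bfloat16 added explicitly.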
#define REGISTER_CAST_GPU(srctype, dsttype)                    \
  REGISTER_KERNEL_BUILDER(Name("Cast")                         \
                              .TypeConstraint<srctype>("SrcT") \
                              .TypeConstraint<dsttype>("DstT") \
                              .Device(DEVICE_GPU),             \
                          GpuCastOp)

CURRY_TYPES2(REGISTER_CAST_GPU, bool);
CURRY_TYPES2(REGISTER_CAST_GPU, uint8);
CURRY_TYPES2(REGISTER_CAST_GPU, int8);
CURRY_TYPES2(REGISTER_CAST_GPU, uint16);
CURRY_TYPES2(REGISTER_CAST_GPU, int16);
CURRY_TYPES2(REGISTER_CAST_GPU, int32);
CURRY_TYPES2(REGISTER_CAST_GPU, int64);
CURRY_TYPES2(REGISTER_CAST_GPU, Eigen::half);
CURRY_TYPES2(REGISTER_CAST_GPU, float);
CURRY_TYPES2(REGISTER_CAST_GPU, double);
CURRY_TYPES2(REGISTER_CAST_GPU, std::complex<float>);
CURRY_TYPES2(REGISTER_CAST_GPU, std::complex<double>);
REGISTER_CAST_GPU(float, bfloat16);
REGISTER_CAST_GPU(bfloat16, float);

#undef REGISTER_CAST_GPU
#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
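// SYCL variant of the cast kernel; it handles a smaller set of source dtypes
// (bool, int32, int64, float, double) than the CPU and GPU kernels.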
class SyclCastOp : public CastOpBase {
 public:
  explicit SyclCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
    OP_REQUIRES_OK(ctx, Prepare());
  }

 private:
  Status Prepare() {
    if (src_dtype_ == dst_dtype_) {
      work_ = nullptr;  // Identity
      return Status::OK();
    }
    if (src_dtype_ == DT_BOOL) {
      work_ = GetSyclCastFromBool(dst_dtype_);
    } else if (src_dtype_ == DT_INT32) {
      work_ = GetSyclCastFromInt32(dst_dtype_);
    } else if (src_dtype_ == DT_INT64) {
      work_ = GetSyclCastFromInt64(dst_dtype_);
    } else if (src_dtype_ == DT_FLOAT) {
      work_ = GetSyclCastFromFloat(dst_dtype_);
    } else if (src_dtype_ == DT_DOUBLE) {
      work_ = GetSyclCastFromDouble(dst_dtype_);
    }

    return work_ == nullptr ? Unimplemented() : Status::OK();
  }
};

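// As with the GPU build, register one SYCL kernel per (SrcT, DstT) pair for
// the supported source types.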
#define REGISTER_CAST_SYCL(srctype, dsttype)                   \
  REGISTER_KERNEL_BUILDER(Name("Cast")                         \
                              .TypeConstraint<srctype>("SrcT") \
                              .TypeConstraint<dsttype>("DstT") \
                              .Device(DEVICE_SYCL),            \
                          SyclCastOp)
CURRY_TYPES2(REGISTER_CAST_SYCL, bool);
CURRY_TYPES2(REGISTER_CAST_SYCL, int32);
CURRY_TYPES2(REGISTER_CAST_SYCL, int64);
CURRY_TYPES2(REGISTER_CAST_SYCL, float);
CURRY_TYPES2(REGISTER_CAST_SYCL, double);

#undef REGISTER_CAST_SYCL

#endif  // TENSORFLOW_USE_SYCL

#undef CURRY_TYPES2

// HostCast differs from Cast in that its input and output are in host memory.
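// The GPU and SYCL registrations below reuse CpuCastOp because both the input
// and the output are pinned to host memory.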
REGISTER_KERNEL_BUILDER(Name("_HostCast").Device(DEVICE_CPU), CpuCastOp);
REGISTER_KERNEL_BUILDER(
    Name("_HostCast").Device(DEVICE_GPU).HostMemory("x").HostMemory("y"),
    CpuCastOp);
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(
    Name("_HostCast").Device(DEVICE_SYCL).HostMemory("x").HostMemory("y"),
    CpuCastOp);
#endif  // TENSORFLOW_USE_SYCL
}  // end namespace tensorflow