/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/cast_op.h"

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

#include "tensorflow/core/kernels/cast_op_impl.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

// Expands FN(arg0, T) once for every destination type T supported by Cast.
#define CURRY_TYPES2(FN, arg0)   \
  FN(arg0, bool);                \
  FN(arg0, uint8);               \
  FN(arg0, int8);                \
  FN(arg0, uint16);              \
  FN(arg0, int16);               \
  FN(arg0, int32);               \
  FN(arg0, int64);               \
  FN(arg0, Eigen::half);         \
  FN(arg0, float);               \
  FN(arg0, double);              \
  FN(arg0, std::complex<float>); \
  FN(arg0, std::complex<double>)

CastOpBase::CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &src_dtype_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &dst_dtype_));
}

void CastOpBase::Compute(OpKernelContext* ctx) {
  const Tensor& inp = ctx->input(0);
  if (work_ == nullptr) {
    // Identity cast: forward the input tensor without copying.
    ctx->set_output(0, inp);
  } else {
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
    work_(ctx, inp, out);
  }
}

Status CastOpBase::Unimplemented() {
  return errors::Unimplemented("Cast ", DataTypeString(src_dtype_), " to ",
                               DataTypeString(dst_dtype_),
                               " is not supported");
}

CpuCastOp::CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
  OP_REQUIRES_OK(ctx, Prepare());
}

Status CpuCastOp::Prepare() {
  if (src_dtype_ == dst_dtype_) {
    work_ = nullptr;  // Identity
    return Status::OK();
  }
  if (src_dtype_ == DT_BOOL) {
    work_ = GetCpuCastFromBool(dst_dtype_);
  } else if (src_dtype_ == DT_UINT8) {
    work_ = GetCpuCastFromUint8(dst_dtype_);
  } else if (src_dtype_ == DT_INT8) {
    work_ = GetCpuCastFromInt8(dst_dtype_);
  } else if (src_dtype_ == DT_UINT16) {
    work_ = GetCpuCastFromUint16(dst_dtype_);
  } else if (src_dtype_ == DT_INT16) {
    work_ = GetCpuCastFromInt16(dst_dtype_);
  } else if (src_dtype_ == DT_INT32) {
    work_ = GetCpuCastFromInt32(dst_dtype_);
  } else if (src_dtype_ == DT_INT64) {
    work_ = GetCpuCastFromInt64(dst_dtype_);
  } else if (src_dtype_ == DT_HALF) {
    work_ = GetCpuCastFromHalf(dst_dtype_);
  } else if (src_dtype_ == DT_FLOAT) {
    work_ = GetCpuCastFromFloat(dst_dtype_);
  } else if (src_dtype_ == DT_DOUBLE) {
    work_ = GetCpuCastFromDouble(dst_dtype_);
  } else if (src_dtype_ == DT_COMPLEX64) {
    work_ = GetCpuCastFromComplex64(dst_dtype_);
  } else if (src_dtype_ == DT_COMPLEX128) {
    work_ = GetCpuCastFromComplex128(dst_dtype_);
  } else if (src_dtype_ == DT_BFLOAT16) {
    work_ = GetCpuCastFromBfloat(dst_dtype_);
  }

  // TODO(sesse): If CPU casting to or from Eigen::half ever becomes a
  // bottleneck, we could probably implement specialized support for
  // vectorized versions (not the least based on F16C for Haswell
  // or newer).

  return work_ == nullptr ? Unimplemented() : Status::OK();
}

#if GOOGLE_CUDA
class GpuCastOp : public CastOpBase {
 public:
  explicit GpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
    OP_REQUIRES_OK(ctx, Prepare());
  }

 private:
  Status Prepare() {
    if (src_dtype_ == dst_dtype_) {
      work_ = nullptr;  // Identity
      return Status::OK();
    }
    if (src_dtype_ == DT_BOOL) {
      work_ = GetGpuCastFromBool(dst_dtype_);
    } else if (src_dtype_ == DT_UINT8) {
      work_ = GetGpuCastFromUint8(dst_dtype_);
    } else if (src_dtype_ == DT_INT8) {
      work_ = GetGpuCastFromInt8(dst_dtype_);
    } else if (src_dtype_ == DT_UINT16) {
      work_ = GetGpuCastFromUint16(dst_dtype_);
    } else if (src_dtype_ == DT_INT16) {
      work_ = GetGpuCastFromInt16(dst_dtype_);
    } else if (src_dtype_ == DT_INT32) {
      work_ = GetGpuCastFromInt32(dst_dtype_);
    } else if (src_dtype_ == DT_INT64) {
      work_ = GetGpuCastFromInt64(dst_dtype_);
    } else if (src_dtype_ == DT_HALF) {
      work_ = GetGpuCastFromHalf(dst_dtype_);
    } else if (src_dtype_ == DT_FLOAT) {
      work_ = GetGpuCastFromFloat(dst_dtype_);
    } else if (src_dtype_ == DT_DOUBLE) {
      work_ = GetGpuCastFromDouble(dst_dtype_);
    } else if (src_dtype_ == DT_COMPLEX64) {
      work_ = GetGpuCastFromComplex64(dst_dtype_);
    } else if (src_dtype_ == DT_COMPLEX128) {
      work_ = GetGpuCastFromComplex128(dst_dtype_);
    } else if (src_dtype_ == DT_BFLOAT16) {
      work_ = GetGpuCastFromBfloat(dst_dtype_);
    }

    return work_ == nullptr ? Unimplemented() : Status::OK();
  }
};
#endif  // GOOGLE_CUDA

#undef CAST_CASE

REGISTER_KERNEL_BUILDER(Name("Cast").Device(DEVICE_CPU), CpuCastOp);

#if GOOGLE_CUDA
#define REGISTER_CAST_GPU(srctype, dsttype)                    \
  REGISTER_KERNEL_BUILDER(Name("Cast")                         \
                              .TypeConstraint<srctype>("SrcT") \
                              .TypeConstraint<dsttype>("DstT") \
                              .Device(DEVICE_GPU),             \
                          GpuCastOp)

CURRY_TYPES2(REGISTER_CAST_GPU, bool);
CURRY_TYPES2(REGISTER_CAST_GPU, uint8);
CURRY_TYPES2(REGISTER_CAST_GPU, int8);
CURRY_TYPES2(REGISTER_CAST_GPU, uint16);
CURRY_TYPES2(REGISTER_CAST_GPU, int16);
CURRY_TYPES2(REGISTER_CAST_GPU, int32);
CURRY_TYPES2(REGISTER_CAST_GPU, int64);
CURRY_TYPES2(REGISTER_CAST_GPU, Eigen::half);
CURRY_TYPES2(REGISTER_CAST_GPU, float);
CURRY_TYPES2(REGISTER_CAST_GPU, double);
CURRY_TYPES2(REGISTER_CAST_GPU, std::complex<float>);
CURRY_TYPES2(REGISTER_CAST_GPU, std::complex<double>);
REGISTER_CAST_GPU(float, bfloat16);
REGISTER_CAST_GPU(bfloat16, float);

#undef REGISTER_CAST_GPU
#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
class SyclCastOp : public CastOpBase {
 public:
  explicit SyclCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
    OP_REQUIRES_OK(ctx, Prepare());
  }

 private:
  Status Prepare() {
    if (src_dtype_ == dst_dtype_) {
      work_ = nullptr;  // Identity
      return Status::OK();
    }
    if (src_dtype_ == DT_BOOL) {
      work_ = GetSyclCastFromBool(dst_dtype_);
    } else if (src_dtype_ == DT_INT32) {
      work_ = GetSyclCastFromInt32(dst_dtype_);
    } else if (src_dtype_ == DT_INT64) {
      work_ = GetSyclCastFromInt64(dst_dtype_);
    } else if (src_dtype_ == DT_FLOAT) {
      work_ = GetSyclCastFromFloat(dst_dtype_);
    } else if (src_dtype_ == DT_DOUBLE) {
      work_ = GetSyclCastFromDouble(dst_dtype_);
    }

    return work_ == nullptr ? Unimplemented() : Status::OK();
  }
};

#define REGISTER_CAST_SYCL(srctype, dsttype)                   \
  REGISTER_KERNEL_BUILDER(Name("Cast")                         \
                              .TypeConstraint<srctype>("SrcT") \
                              .TypeConstraint<dsttype>("DstT") \
                              .Device(DEVICE_SYCL),            \
                          SyclCastOp)
CURRY_TYPES2(REGISTER_CAST_SYCL, bool);
CURRY_TYPES2(REGISTER_CAST_SYCL, int32);
CURRY_TYPES2(REGISTER_CAST_SYCL, int64);
CURRY_TYPES2(REGISTER_CAST_SYCL, float);
CURRY_TYPES2(REGISTER_CAST_SYCL, double);

#undef REGISTER_CAST_SYCL

#endif  // TENSORFLOW_USE_SYCL

#undef CURRY_TYPES2

// HostCast differs from Cast in that its input and output are in host memory.
REGISTER_KERNEL_BUILDER(Name("_HostCast").Device(DEVICE_CPU), CpuCastOp);
REGISTER_KERNEL_BUILDER(
    Name("_HostCast").Device(DEVICE_GPU).HostMemory("x").HostMemory("y"),
    CpuCastOp);
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(
    Name("_HostCast").Device(DEVICE_SYCL).HostMemory("x").HostMemory("y"),
    CpuCastOp);
#endif  // TENSORFLOW_USE_SYCL
}  // end namespace tensorflow