1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_KERNELS_CAST_OP_H_ 17 #define TENSORFLOW_KERNELS_CAST_OP_H_ 18 19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 20 #include "tensorflow/core/framework/bfloat16.h" 21 #include "tensorflow/core/framework/op_kernel.h" 22 #include "tensorflow/core/framework/tensor_types.h" 23 #include "tensorflow/core/framework/types.h" 24 #include "tensorflow/core/platform/cpu_info.h" 25 #include "tensorflow/core/platform/types.h" 26 27 namespace tensorflow { 28 29 // Common base class of Cast kernels 30 class CastOpBase : public OpKernel { 31 public: 32 explicit CastOpBase(OpKernelConstruction* ctx); 33 34 void Compute(OpKernelContext* ctx) override; 35 36 protected: 37 DataType src_dtype_; 38 DataType dst_dtype_; 39 std::function<void(OpKernelContext*, const Tensor&, Tensor*)> work_ = nullptr; 40 41 Status Unimplemented(); 42 43 TF_DISALLOW_COPY_AND_ASSIGN(CastOpBase); 44 }; 45 46 // CPU implementation of Cast 47 class CpuCastOp : public CastOpBase { 48 public: 49 explicit CpuCastOp(OpKernelConstruction* ctx); 50 51 private: 52 Status Prepare(); 53 }; 54 55 namespace functor { 56 57 template <typename Device, typename Tout, typename Tin> 58 void Cast(const Device& d, typename TTypes<Tout>::Flat o, 59 typename TTypes<Tin>::ConstFlat i) { 60 o.device(d) = i.template cast<Tout>(); 61 } 62 63 template <typename Device, typename Tout, typename Tin> 64 struct CastFunctor { 65 void operator()(const Device& d, typename TTypes<Tout>::Flat o, 66 typename TTypes<Tin>::ConstFlat i); 67 }; 68 69 } // end namespace functor 70 } // end namespace tensorflow 71 72 namespace Eigen { 73 namespace internal { 74 75 // Eigen can't convert to/from complex numbers, because it is limited to cases 76 // that can be static_casted. But numpy is able to cast to/from complex, which 77 // we want to replicate. So we add specializations for complex here. 78 template <typename From, typename To> 79 struct scalar_cast_op<std::complex<From>, To> { 80 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE To 81 operator()(const std::complex<From>& a) const { 82 // Replicate numpy behavior of returning just the real part 83 return static_cast<To>(a.real()); 84 } 85 }; 86 87 template <typename From, typename To> 88 struct scalar_cast_op<From, std::complex<To>> { 89 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<To> operator()( 90 const From& a) const { 91 // Replicate numpy behavior of setting the imaginary part to 0 92 return std::complex<To>(static_cast<To>(a), To(0)); 93 } 94 }; 95 96 template <typename From, typename To> 97 struct scalar_cast_op<std::complex<From>, std::complex<To>> { 98 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<To> operator()( 99 const std::complex<From>& a) const { 100 return std::complex<To>(static_cast<To>(a.real()), 101 static_cast<To>(a.imag())); 102 } 103 }; 104 105 template <typename From, typename To> 106 struct functor_traits_complex_impl { 107 enum { Cost = NumTraits<To>::AddCost, PacketAccess = false }; 108 }; 109 110 template <typename From, typename To> 111 struct functor_traits<scalar_cast_op<std::complex<From>, To>> 112 : functor_traits_complex_impl<std::complex<From>, To> {}; 113 template <typename From, typename To> 114 struct functor_traits<scalar_cast_op<From, std::complex<To>>> 115 : functor_traits_complex_impl<From, std::complex<To>> {}; 116 // Needed to avoid ambiguous partial specialization 117 template <typename From, typename To> 118 struct functor_traits<scalar_cast_op<std::complex<From>, std::complex<To>>> 119 : functor_traits_complex_impl<std::complex<From>, std::complex<To>> {}; 120 121 // Specialized cast op impls for bfloat16. 122 template <> 123 struct scalar_cast_op<::tensorflow::bfloat16, float> { 124 EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op) 125 typedef float result_type; 126 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()( 127 const ::tensorflow::bfloat16& a) const { 128 float ret; 129 uint16_t* p = reinterpret_cast<uint16_t*>(&ret); 130 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ 131 p[0] = a.value; 132 p[1] = 0; 133 #else 134 static_assert(::tensorflow::port::kLittleEndian, 135 "Not a little endian system!"); 136 p[0] = 0; 137 p[1] = a.value; 138 #endif 139 return ret; 140 } 141 }; 142 143 template <> 144 struct functor_traits<scalar_cast_op<::tensorflow::bfloat16, float>> { 145 enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; 146 }; 147 148 template <> 149 struct scalar_cast_op<float, ::tensorflow::bfloat16> { 150 EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op) 151 typedef ::tensorflow::bfloat16 result_type; 152 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const ::tensorflow::bfloat16 operator()( 153 const float a) const { 154 return ::tensorflow::bfloat16(a); 155 } 156 }; 157 158 template <> 159 struct functor_traits<scalar_cast_op<float, ::tensorflow::bfloat16>> { 160 enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; 161 }; 162 163 } // namespace internal 164 } // namespace Eigen 165 166 #endif // TENSORFLOW_KERNELS_CAST_OP_H_ 167