Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_KERNELS_CAST_OP_H_
     17 #define TENSORFLOW_KERNELS_CAST_OP_H_
     18 
#include <cstdint>
#include <cstring>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/types.h"
     26 
     27 namespace tensorflow {
     28 
     29 // Common base class of Cast kernels
     30 class CastOpBase : public OpKernel {
     31  public:
     32   explicit CastOpBase(OpKernelConstruction* ctx);
     33 
     34   void Compute(OpKernelContext* ctx) override;
     35 
     36  protected:
     37   DataType src_dtype_;
     38   DataType dst_dtype_;
     39   std::function<void(OpKernelContext*, const Tensor&, Tensor*)> work_ = nullptr;
     40 
     41   Status Unimplemented();
     42 
     43   TF_DISALLOW_COPY_AND_ASSIGN(CastOpBase);
     44 };
     45 
     46 // CPU implementation of Cast
     47 class CpuCastOp : public CastOpBase {
     48  public:
     49   explicit CpuCastOp(OpKernelConstruction* ctx);
     50 
     51  private:
     52   Status Prepare();
     53 };
     54 
     55 namespace functor {
     56 
     57 template <typename Device, typename Tout, typename Tin>
     58 void Cast(const Device& d, typename TTypes<Tout>::Flat o,
     59           typename TTypes<Tin>::ConstFlat i) {
     60   o.device(d) = i.template cast<Tout>();
     61 }
     62 
     63 template <typename Device, typename Tout, typename Tin>
     64 struct CastFunctor {
     65   void operator()(const Device& d, typename TTypes<Tout>::Flat o,
     66                   typename TTypes<Tin>::ConstFlat i);
     67 };
     68 
     69 }  // end namespace functor
     70 }  // end namespace tensorflow
     71 
     72 namespace Eigen {
     73 namespace internal {
     74 
     75 // Eigen can't convert to/from complex numbers, because it is limited to cases
     76 // that can be static_casted. But numpy is able to cast to/from complex, which
     77 // we want to replicate. So we add specializations for complex here.
     78 template <typename From, typename To>
     79 struct scalar_cast_op<std::complex<From>, To> {
     80   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE To
     81   operator()(const std::complex<From>& a) const {
     82     // Replicate numpy behavior of returning just the real part
     83     return static_cast<To>(a.real());
     84   }
     85 };
     86 
     87 template <typename From, typename To>
     88 struct scalar_cast_op<From, std::complex<To>> {
     89   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<To> operator()(
     90       const From& a) const {
     91     // Replicate numpy behavior of setting the imaginary part to 0
     92     return std::complex<To>(static_cast<To>(a), To(0));
     93   }
     94 };
     95 
     96 template <typename From, typename To>
     97 struct scalar_cast_op<std::complex<From>, std::complex<To>> {
     98   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<To> operator()(
     99       const std::complex<From>& a) const {
    100     return std::complex<To>(static_cast<To>(a.real()),
    101                             static_cast<To>(a.imag()));
    102   }
    103 };
    104 
    105 template <typename From, typename To>
    106 struct functor_traits_complex_impl {
    107   enum { Cost = NumTraits<To>::AddCost, PacketAccess = false };
    108 };
    109 
    110 template <typename From, typename To>
    111 struct functor_traits<scalar_cast_op<std::complex<From>, To>>
    112     : functor_traits_complex_impl<std::complex<From>, To> {};
    113 template <typename From, typename To>
    114 struct functor_traits<scalar_cast_op<From, std::complex<To>>>
    115     : functor_traits_complex_impl<From, std::complex<To>> {};
    116 // Needed to avoid ambiguous partial specialization
    117 template <typename From, typename To>
    118 struct functor_traits<scalar_cast_op<std::complex<From>, std::complex<To>>>
    119     : functor_traits_complex_impl<std::complex<From>, std::complex<To>> {};
    120 
    121 // Specialized cast op impls for bfloat16.
    122 template <>
    123 struct scalar_cast_op<::tensorflow::bfloat16, float> {
    124   EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
    125   typedef float result_type;
    126   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(
    127       const ::tensorflow::bfloat16& a) const {
    128     float ret;
    129     uint16_t* p = reinterpret_cast<uint16_t*>(&ret);
    130 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
    131     p[0] = a.value;
    132     p[1] = 0;
    133 #else
    134     static_assert(::tensorflow::port::kLittleEndian,
    135                   "Not a little endian system!");
    136     p[0] = 0;
    137     p[1] = a.value;
    138 #endif
    139     return ret;
    140   }
    141 };
    142 
    143 template <>
    144 struct functor_traits<scalar_cast_op<::tensorflow::bfloat16, float>> {
    145   enum { Cost = NumTraits<float>::AddCost, PacketAccess = false };
    146 };
    147 
    148 template <>
    149 struct scalar_cast_op<float, ::tensorflow::bfloat16> {
    150   EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
    151   typedef ::tensorflow::bfloat16 result_type;
    152   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const ::tensorflow::bfloat16 operator()(
    153       const float a) const {
    154     return ::tensorflow::bfloat16(a);
    155   }
    156 };
    157 
    158 template <>
    159 struct functor_traits<scalar_cast_op<float, ::tensorflow::bfloat16>> {
    160   enum { Cost = NumTraits<float>::AddCost, PacketAccess = false };
    161 };
    162 
    163 }  // namespace internal
    164 }  // namespace Eigen
    165 
    166 #endif  // TENSORFLOW_KERNELS_CAST_OP_H_
    167