Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_KERNELS_MKL_TFCONV_OP_H_
     17 #define TENSORFLOW_CORE_KERNELS_MKL_TFCONV_OP_H_
     18 
     19 #ifdef INTEL_MKL
     20 
     21 #include <algorithm>
     22 #include <vector>
     23 #include "tensorflow/core/framework/numeric_op.h"
     24 #include "tensorflow/core/framework/op.h"
     25 #include "tensorflow/core/framework/op_kernel.h"
     26 #include "tensorflow/core/framework/register_types.h"
     27 #include "tensorflow/core/framework/tensor.h"
     28 #include "tensorflow/core/framework/tensor_shape.h"
     29 #include "tensorflow/core/kernels/ops_util.h"
     30 #include "tensorflow/core/platform/cpu_info.h"
     31 #include "tensorflow/core/platform/macros.h"
     32 #include "tensorflow/core/util/tensor_format.h"
     33 
     34 #include "mkl_dnn.h"
     35 #include "mkl_dnn_types.h"
     36 #include "tensorflow/core/util/mkl_util.h"
     37 
     38 #ifndef INTEL_MKL_ML
     39 using mkldnn::stream;
     40 #endif
     41 
     42 namespace tensorflow {
     43 typedef Eigen::ThreadPoolDevice CPUDevice;
     44 
     45 ///////////////////////////////////////////////////////////
     46 //               Op kernel
     47 ///////////////////////////////////////////////////////////
     48 
     49 template <typename Device, typename T>
     50 class MklToTfOp : public OpKernel {
     51  public:
     52   explicit MklToTfOp(OpKernelConstruction* context) : OpKernel(context) {
     53     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
     54     OP_REQUIRES_OK(context, context->GetAttr("T", &op_data_type));
     55     has_avx512f_ = port::TestCPUFeature(port::CPUFeature::AVX512F);
     56   }
     57 
     58   void Compute(OpKernelContext* context) override {
     59     ConvertMklToTf(this, context, data_format_str, op_data_type, has_avx512f_,
     60                    0);
     61     VLOG(1) << "MKLToTFConversion complete successfully.";
     62   }
     63 
     64 #ifndef INTEL_MKL_ML
     65   static void ConvertMklToTf(OpKernel* op_kernel, OpKernelContext* context,
     66                              string data_format_str, DataType op_data_type,
     67                              bool has_avx512f, uint input_number) {
     68     try {
     69       // Check that input tensor is in MKL format.
     70       const Tensor& input_tensor = MklGetInput(context, input_number);
     71       MklDnnShape input_shape;
     72       GetMklShape(context, input_number, &input_shape);
     73 
     74       // if input is already in Tf format, then copy input tensor to output.
     75       if (!input_shape.IsMklTensor()) {
     76         context->set_output(input_number, input_tensor);
     77         VLOG(1) << "MKLToTFConversion: No conversion needed, "
     78                 << "copying input to output";
     79         return;
     80       }
     81 
     82       // Check that input data type is same as operator data type and that it
     83       // is same as output data type.
     84       DataType input_data_type = op_kernel->input_type(input_number);
     85       DataType output_data_type = op_kernel->output_type(input_number);
     86       CHECK_EQ(op_data_type, input_data_type);
     87       CHECK_EQ(op_data_type, output_data_type);
     88 
     89       auto cpu_engine = engine(engine::cpu, 0);
     90       MklDnnData<T> input(&cpu_engine);
     91 
     92       // Get Mkl layout of input tensor.
     93       auto input_mkl_md = input_shape.GetMklLayout();
     94       // Get TensorFlow layout of input tensor. Expected output of conversion
     95       // has same layout as Tensorflow layout of input tensor.
     96       auto output_tf_md = input_shape.GetTfLayout();
     97       auto output_tf_pd = memory::primitive_desc(output_tf_md, cpu_engine);
     98       // Set input Mkl layout as the user layout.
     99       input.SetUsrMem(input_mkl_md, &input_tensor);
    100 
    101       // Allocate output tensor.
    102       TensorShape output_shape = input_shape.GetTfShape();
    103       Tensor* output_tensor = NULL;
    104       OP_REQUIRES_OK(context, context->allocate_output(
    105                                   input_number, output_shape, &output_tensor));
    106       CHECK_NOTNULL(output_tensor);
    107 
    108       // Do we need to reorder Mkl layout into TensorFlow layout?
    109       if (input.IsReorderNeeded(output_tf_pd)) {
    110         // Insert reorder between Mkl layout and TensorFlow layout.
    111         std::vector<primitive> net;
    112         CHECK_EQ(input.CheckReorderToOpMem(output_tf_pd, output_tensor, &net),
    113                  true);
    114         stream(stream::kind::eager).submit(net).wait();
    115       } else {
    116         // If not, just forward input tensor to output tensor.
    117         CHECK(output_tensor->CopyFrom(input_tensor, output_shape));
    118       }
    119     } catch (mkldnn::error& e) {
    120       string error_msg = "Status: " + std::to_string(e.status) +
    121                          ", message: " + std::string(e.message) + ", in file " +
    122                          std::string(__FILE__) + ":" + std::to_string(__LINE__);
    123       OP_REQUIRES_OK(
    124           context,
    125           errors::Aborted("Operation received an exception:", error_msg));
    126     }
    127   }
    128 #else
    129   static void ConvertMklToTf(OpKernel* op_kernel, OpKernelContext* context,
    130                              string data_format_str, DataType op_data_type,
    131                              bool has_avx512f, uint32 input_number) {
    132     // Check that input tensor is in MKL format.
    133     const Tensor& input_tensor = MklGetInput(context, input_number);
    134     MklShape input_shape;
    135     GetMklShape(context, input_number, &input_shape);
    136 
    137     // if input is already in Tf format, then just copy input tensor to output.
    138     if (!input_shape.IsMklTensor()) {
    139       context->set_output(input_number, input_tensor);
    140       VLOG(1) << "MKLToTFConversion: No conversion needed, "
    141               << "copying input to output";
    142       return;
    143     }
    144 
    145     // Check that input data type is same as operator data type and that it is
    146     // same as output data type.
    147     DataType input_data_type = op_kernel->input_type(input_number);
    148     DataType output_data_type = op_kernel->output_type(input_number);
    149     CHECK_EQ(op_data_type, input_data_type);
    150     CHECK_EQ(op_data_type, output_data_type);
    151 
    152     TensorShape output_shape;
    153     size_t ndims = input_shape.GetDimension();
    154     size_t* in_sizes = new size_t[ndims];
    155     for (size_t i = 0; i < ndims; i++) {
    156       // Outermost to innermost dimension
    157       output_shape.AddDim(input_shape.GetSizes()[input_shape.tf_dim_idx(i)]);
    158       in_sizes[i] = input_shape.GetSizes()[i];
    159     }
    160 
    161     // Allocate output tensor.
    162     Tensor* output_tensor = NULL;
    163     OP_REQUIRES_OK(context, context->allocate_output(input_number, output_shape,
    164                                                      &output_tensor));
    165 
    166     dnnLayout_t output_layout =
    167         static_cast<dnnLayout_t>(input_shape.GetTfLayout());
    168     // Execute DNNConversion.
    169     void* input_buffer =
    170         static_cast<void*>(const_cast<T*>(input_tensor.flat<T>().data()));
    171     delete[] in_sizes;
    172     void* output_buffer =
    173         static_cast<void*>(const_cast<T*>(output_tensor->flat<T>().data()));
    174     input_shape.GetConvertedFlatData(output_layout, input_buffer,
    175                                      output_buffer);
    176     VLOG(1) << "MKLToTFConversion complete successfully.";
    177   }
    178 #endif
    179 
    180  private:
    181   /// Data format of the operation
    182   string data_format_str;
    183 
    184   /// Data type of the operation
    185   DataType op_data_type;
    186 
    187   /// CPUIDInfo
    188   bool has_avx512f_ = false;
    189 };
    190 
    191 ///////////////////////////////////////////////////////////
    192 //               Register kernel
    193 ///////////////////////////////////////////////////////////
    194 
    195 #define REGISTER_CPU(T)                                             \
    196   REGISTER_KERNEL_BUILDER(Name("_MklToTf")                          \
    197                               .Device(DEVICE_CPU)                   \
    198                               .TypeConstraint<T>("T")               \
    199                               .Label(mkl_op_registry::kMklOpLabel), \
    200                           MklToTfOp<CPUDevice, T>);
    201 
    202 TF_CALL_NUMBER_TYPES(REGISTER_CPU);
    203 #undef REGISTER_CPU
    204 }  // namespace tensorflow
    205 #endif  // INTEL_MKL
    206 #endif  // TENSORFLOW_CORE_KERNELS_MKL_TFCONV_OP_H_
    207