Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifdef INTEL_MKL
     17 
     18 #include <algorithm>
     19 #include <vector>
     20 #include "tensorflow/core/framework/numeric_op.h"
     21 #include "tensorflow/core/framework/op.h"
     22 #include "tensorflow/core/framework/op_kernel.h"
     23 #include "tensorflow/core/framework/register_types.h"
     24 #include "tensorflow/core/framework/tensor.h"
     25 #include "tensorflow/core/framework/tensor_shape.h"
     26 #include "tensorflow/core/kernels/ops_util.h"
     27 #include "tensorflow/core/platform/cpu_info.h"
     28 #include "tensorflow/core/platform/macros.h"
     29 #include "tensorflow/core/util/tensor_format.h"
     30 
     31 #include "tensorflow/core/kernels/mkl_tfconv_op.h"
     32 #include "tensorflow/core/util/mkl_util.h"
     33 
     34 #ifndef INTEL_MKL_ML
     35 #include "mkldnn.hpp"
     36 
     37 using mkldnn::stream;
     38 #endif
     39 
     40 namespace tensorflow {
     41 typedef Eigen::ThreadPoolDevice CPUDevice;
     42 
     43 ///////////////////////////////////////////////////////////
     44 //               Op kernel
     45 // Checks and ensures that the 2 inputs are compatible for mkl binary ops.
     46 // Here's the basic logic:
     47 //
     48 // if both inputs are in TF format:
     49 //   pass the inputs through to the output
     50 // else if both inputs are in mkl format:
     51 //   if both have the same shape:
     52 //     pass the inputs through to the output
     53 //   else:
     54 //     convert both to TF
     55 // else if one is TF and one is MKL:
     56 //   if broadcast is needed:
     57 //     convert the MKL format input to TF format
     58 //   else:
     59 //     convert the TF format input to MKL format
     60 ///////////////////////////////////////////////////////////
     61 
     62 #ifdef INTEL_MKL_ML
// Legacy MKL-ML implementation of the input-conversion kernel.
//
// Normalizes the two inputs of an MKL element-wise binary op so the
// downstream op receives a compatible pair (see the decision table in the
// overview comment above): TF+TF pass through, MKL+MKL pass through when
// shapes match (otherwise both become TF so broadcast can happen), and a
// mixed MKL/TF pair is unified in whichever direction avoids broadcast.
template <typename Device, typename T>
class MklInputConversionOp : public OpKernel {
 public:
  explicit MklInputConversionOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
    OP_REQUIRES_OK(context, context->GetAttr("T", &op_data_type));
    // Cache the CPU capability once; it is forwarded to
    // MklToTfOp::ConvertMklToTf on every conversion below.
    has_avx512f_ = port::TestCPUFeature(port::CPUFeature::AVX512F);
  }

 private:
  void Compute(OpKernelContext* context) override {
    // Fetch both inputs together with their MKL shape metadata, which tells
    // us whether each tensor is in MKL format or plain TF format.
    const Tensor& input_tensor_0 = MklGetInput(context, 0);
    MklShape input_shape_0;
    GetMklShape(context, 0, &input_shape_0);

    const Tensor& input_tensor_1 = MklGetInput(context, 1);
    MklShape input_shape_1;
    GetMklShape(context, 1, &input_shape_1);

    bool tf_shapes_are_same = MklCompareShapes(&context->input(0).shape(),
                                               &context->input(1).shape());

    VLOG(1) << "MklInputConversionOp: Input shapes are "
            << (tf_shapes_are_same ? "*same*" : "*different*") << ": "
            << context->input(0).shape().DebugString() << " and "
            << context->input(1).shape().DebugString();

    // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    // If both inputs are in TF format, just copy input tensors to output.
    if (!input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) {
      VLOG(1) << "MklInputConversionOp: No conversion needed, "
              << "copying TF inputs to output";

      ForwardTfTensorInToOut(context, 0, 0);
      ForwardTfTensorInToOut(context, 1, 1);
      return;
    }

    // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    // If both inputs are in MKL format
    if (input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) {
      // If both have the same shape, pass them through unchanged.
      if (tf_shapes_are_same) {
        VLOG(1) << "MklInputConversionOp: No conversion needed, "
                << "copying MKL inputs with identical shapes to output";

        ForwardMklTensorInToOut(context, 0, 0);
        ForwardMklTensorInToOut(context, 1, 1);
        return;
      }

      // Sanity check: differing TF shapes with identical MKL shapes would
      // mean the shape metadata is inconsistent — fail fast.
      bool mkl_shapes_are_same =
          MklCompareShapes(&input_shape_0, &input_shape_1);
      if (mkl_shapes_are_same) {
        CHECK(false) << "MklInputConversionOp: Unexpected: TF shapes are "
                        "different but MKL shapes are same";
      }

      // Both have different shapes, so broadcast will be necessary.
      // Convert to TF and pass both tensors through (we can't do broadcast
      // with MKL tensors)
      VLOG(1) << "MklInputConversionOp: Broadcast needed, "
              << "converted MKL inputs to TF format";

      MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
                                           op_data_type, has_avx512f_, 0);
      MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
                                           op_data_type, has_avx512f_, 1);
      SetDummyMklShapeOutput(context, 0);
      SetDummyMklShapeOutput(context, 1);
      return;
    }

    // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    // One input is MKL and one is TF. If no broadcast is needed, convert
    // the TF tensor to MKL, otherwise convert the MKL tensor to TF format
    VLOG(1) << "MklInputConversionOp: Inputs in different formats (MKL/TF)";

    // Aliases identifying which input is MKL and which is TF; exactly one
    // of the two branches below must hold or the kernel aborts.
    const Tensor* mkl_tensor;
    const MklShape* mkl_shape;
    const Tensor* tf_tensor;
    MklShape* tf_mkl_shape;  // non-const: reused as scratch for the conversion
    uint32 mkl_tensor_index;
    uint32 tf_tensor_index;
    if (input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) {
      mkl_tensor = &input_tensor_0;
      mkl_shape = &input_shape_0;
      mkl_tensor_index = 0;
      tf_tensor = &input_tensor_1;
      tf_mkl_shape = &input_shape_1;
      tf_tensor_index = 1;
    } else if (!input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) {
      mkl_tensor = &input_tensor_1;
      mkl_shape = &input_shape_1;
      mkl_tensor_index = 1;
      tf_tensor = &input_tensor_0;
      tf_mkl_shape = &input_shape_0;
      tf_tensor_index = 0;
    } else {
      CHECK(false) << "MklInputConversionOp: Unexpected combination of input "
                      "shapes for MKL "
                   << "element-wise op";
    }

    // Broadcast is needed if the shapes are not the same. Sameness is
    // approximated by comparing total element counts.
    // NOTE(review): two different shapes with equal element counts would be
    // treated as "no broadcast" here — presumably callers never produce that
    // combination; confirm against the element-wise op registrations.
    bool broadcast_needed;

    size_t in0_size = 1;
    for (size_t i = 0; i < mkl_shape->GetDimension(); ++i)
      in0_size *= mkl_shape->tf_dim_size(i);

    size_t in1_size = 1;
    for (size_t i = 0; i < tf_tensor->shape().dims(); ++i)
      in1_size *= tf_tensor->shape().dim_size(i);

    broadcast_needed = (in0_size != in1_size);

    if (!broadcast_needed) {
      // Both shapes are same, convert the TF input to MKL
      VLOG(1) << "MklInputConversionOp: No broadcast needed.";
      VLOG(1) << "MklInputConversionOp: Converting input " << tf_tensor_index
              << " to MKL format";

      // Create MklShape describing the output in the same layout as the
      // MKL-side input.
      Tensor* tensor_out;
      MklShape mkl_output_mkl_shape;
      mkl_output_mkl_shape.SetMklTensor(true);
      mkl_output_mkl_shape.SetTfLayout(mkl_shape->GetDimension(),
                                       mkl_shape->GetSizes(),
                                       mkl_shape->GetStrides());
      mkl_output_mkl_shape.SetTfDimOrder(mkl_shape->GetDimension());

      // ** Temporarily borrow the layout from the MKL input **
      // (released below — MklShape's destructor would otherwise free a
      // layout it does not own).
      mkl_output_mkl_shape.SetMklLayout(mkl_shape->GetCurLayout());

      // Create output tensor
      AllocateOutputSetMklShape(context, tf_tensor_index, &tensor_out,
                                mkl_tensor->shape(), mkl_output_mkl_shape);

      // Since the shapes are the same, use information from the other tensor.
      // This mutates the TF input's (local) MklShape as scratch so it can
      // drive the data-format conversion below.
      tf_mkl_shape->SetTfLayout(mkl_shape->GetDimension(),
                                mkl_shape->GetSizes(), mkl_shape->GetStrides());
      // Convert the data format: read the TF tensor's flat data and write the
      // MKL-laid-out result into the freshly allocated output.
      tf_mkl_shape->GetConvertedFlatData(
          mkl_shape->GetCurLayout(),
          const_cast<T*>(tf_tensor->flat<T>().data()),
          const_cast<T*>(tensor_out->flat<T>().data()));

      // ** Release the borrowed layout to avoid double deletion
      //    in the destructor call **
      mkl_output_mkl_shape.SetMklLayout(nullptr);

      // -- The tensor in MKL format passes through --
      ForwardMklTensorInToOut(context, mkl_tensor_index, mkl_tensor_index);
    } else {
      // Broadcast is needed, so convert the MKL input to TF
      VLOG(1) << "MklInputConversionOp: Broadcast needed.";
      VLOG(1) << "MklInputConversionOp: Converting input " << mkl_tensor_index
              << " to TF format";
      MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
                                           op_data_type, has_avx512f_,
                                           mkl_tensor_index);
      SetDummyMklShapeOutput(context, mkl_tensor_index);

      // The tensor in TF format passes through
      ForwardTfTensorInToOut(context, tf_tensor_index, tf_tensor_index);
    }

    VLOG(1) << "MklInputConversionOp: Shapes (output): "
            << context->mutable_output(0)->shape().DebugString() << " and "
            << context->mutable_output(1)->shape().DebugString();

    VLOG(1) << "MklInputConversion completed successfully.";
  }

 private:
  /// Data format of the operation ("data_format" attr, e.g. NHWC/NCHW)
  string data_format_str;

  /// Data type of the operation ("T" attr)
  DataType op_data_type;

  /// True if the CPU supports AVX512F; cached at construction time.
  bool has_avx512f_ = false;
};
    251 
    252 #else
    253 
// MKL-DNN implementation of the input-conversion kernel.
//
// Same contract as the legacy version above, with one extra case: two MKL
// inputs with identical TF shapes may still carry different MKL-DNN memory
// formats, in which case input 0 is reordered into input 1's format.
template <typename Device, typename T>
class MklInputConversionOp : public OpKernel {
 public:
  explicit MklInputConversionOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
    OP_REQUIRES_OK(context, context->GetAttr("T", &op_data_type));
    // Cache the CPU capability once; it is forwarded to
    // MklToTfOp::ConvertMklToTf on every conversion below.
    has_avx512f_ = port::TestCPUFeature(port::CPUFeature::AVX512F);
  }

 private:
  void Compute(OpKernelContext* context) override {
    // Fetch both inputs together with their MKL-DNN shape metadata, which
    // tells us whether each tensor is in MKL format or plain TF format.
    const Tensor& input_tensor_0 = MklGetInput(context, 0);
    MklDnnShape input_shape_0;
    GetMklShape(context, 0, &input_shape_0);

    const Tensor& input_tensor_1 = MklGetInput(context, 1);
    MklDnnShape input_shape_1;
    GetMklShape(context, 1, &input_shape_1);

    bool tf_shapes_are_same =
        context->input(0).shape() == context->input(1).shape();

    VLOG(1) << "MklInputConversionOp: Input shapes are "
            << (tf_shapes_are_same ? "*same*" : "*different*") << ": "
            << context->input(0).shape().DebugString() << " and "
            << context->input(1).shape().DebugString();

    // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    // If both inputs are in TF format, just copy input tensors to output.
    if (!input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) {
      VLOG(1) << "MklInputConversionOp: No conversion needed, "
              << "copying TF inputs to output";

      ForwardTfTensorInToOut(context, 0, 0);
      ForwardTfTensorInToOut(context, 1, 1);
      return;
    }

    // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    // If both inputs are in MKL format
    if (input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) {
      if (tf_shapes_are_same) {
        auto input0_md = input_shape_0.GetMklLayout();
        auto input1_md = input_shape_1.GetMklLayout();

        // If both have the same shape and same format, pass them through
        if (input0_md.data.format == input1_md.data.format) {
          VLOG(1) << "MklInputConversionOp: No conversion needed, "
                  << "copying MKL inputs with identical shapes to output";

          ForwardMklTensorInToOut(context, 0, 0);
          ForwardMklTensorInToOut(context, 1, 1);
          return;
        } else {
          VLOG(1) << "MklInputConversionOp: Shape is same, but format is "
                     "different, "
                  << "need to convert to same format";

          // Convert input0, and keep input1 unchanged
          // Create MklDnnShape for output mkl tensor based on input0
          Tensor* tensor_out;
          MklDnnShape mkl_output_mkl_shape;
          mkl_output_mkl_shape.SetMklTensor(true);
          mkl_output_mkl_shape.SetElemType(MklDnnType<T>());
          mkl_output_mkl_shape.SetTfLayout(input_shape_0.GetDimension(),
                                           input_shape_0.GetSizesAsMklDnnDims(),
                                           input_shape_0.GetTfDataFormat());

          // Get MKL layout from input1 as destination layout
          // NOTE(review): input1_md is a local; this assumes SetMklLayout
          // copies the descriptor rather than retaining the pointer —
          // consistent with the lack of any release call here, but confirm
          // against MklDnnShape in mkl_util.h.
          mkl_output_mkl_shape.SetMklLayout(&input1_md);

          // Create output Mkl tensor for index 0
          AllocateOutputSetMklShape(context, 0, &tensor_out,
                                    input_tensor_0.shape(),
                                    mkl_output_mkl_shape);

          // Create MklDnnData object for input0 tensor
          auto cpu_engine = engine(engine::cpu, 0);
          MklDnnData<T> input(&cpu_engine);
          input.SetUsrMem(input0_md, &input_tensor_0);

          // Create reorder from input0's layout to input1's layout and run
          // it eagerly; CheckReorderToOpMem must report that a reorder was
          // actually added (formats differ, so one is always required).
          std::vector<primitive> net;
          CHECK_EQ(input.CheckReorderToOpMem(
                       memory::primitive_desc(input1_md, cpu_engine),
                       tensor_out, &net),
                   true);
          stream(stream::kind::eager).submit(net).wait();

          // Input1 will be passed through
          ForwardMklTensorInToOut(context, 1, 1);
          return;
        }
      }

      // Sanity check: differing TF shapes with identical MKL shapes would
      // mean the shape metadata is inconsistent — fail fast.
      bool mkl_shapes_are_same = input_shape_0 == input_shape_1;
      if (mkl_shapes_are_same) {
        CHECK(false) << "MklInputConversionOp: Unexpected: TF shapes are "
                        "different but MKL shapes are same";
      }

      // Both have different shapes, so broadcast will be necessary.
      // Convert to TF and pass both tensors through (we can't do broadcast
      // with MKL tensors)
      VLOG(1) << "MklInputConversionOp: Broadcast needed, "
              << "converted MKL inputs to TF format";

      MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
                                           op_data_type, has_avx512f_, 0);
      MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
                                           op_data_type, has_avx512f_, 1);
      SetDummyMklShapeOutput(context, 0);
      SetDummyMklShapeOutput(context, 1);
      return;
    }

    // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    // One input is MKL and one is TF. If no broadcast is needed, convert
    // the TF tensor to MKL, otherwise convert the MKL tensor to TF format
    VLOG(1) << "MklInputConversionOp: Inputs in different formats (MKL/TF)";

    // Aliases identifying which input is MKL and which is TF; exactly one
    // of the two branches below must hold or the kernel aborts.
    const Tensor* mkl_tensor;
    const MklDnnShape* mkl_shape;
    const Tensor* tf_tensor;
    MklDnnShape* tf_mkl_shape;
    uint mkl_tensor_index;
    uint tf_tensor_index;
    if (input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) {
      mkl_tensor = &input_tensor_0;
      mkl_shape = &input_shape_0;
      mkl_tensor_index = 0;
      tf_tensor = &input_tensor_1;
      tf_mkl_shape = &input_shape_1;
      tf_tensor_index = 1;
    } else if (!input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) {
      mkl_tensor = &input_tensor_1;
      mkl_shape = &input_shape_1;
      mkl_tensor_index = 1;
      tf_tensor = &input_tensor_0;
      tf_mkl_shape = &input_shape_0;
      tf_tensor_index = 0;
    } else {
      CHECK(false) << "MklInputConversionOp: Unexpected combination of input "
                      "shapes for MKL "
                   << "element-wise op";
    }

    // Broadcast is needed if the shapes are not the same. Sameness is
    // approximated by comparing total element counts.
    // NOTE(review): two different shapes with equal element counts would be
    // treated as "no broadcast" here — presumably callers never produce that
    // combination; confirm against the element-wise op registrations.
    bool broadcast_needed;

    size_t in0_size = 1;
    for (size_t i = 0; i < mkl_shape->GetDimension(); ++i)
      in0_size *= mkl_shape->TfDimSize(i);

    size_t in1_size = 1;
    for (size_t i = 0; i < tf_tensor->shape().dims(); ++i)
      in1_size *= tf_tensor->shape().dim_size(i);

    broadcast_needed = (in0_size != in1_size);

    if (!broadcast_needed) {
      // Both shapes are same, convert the TF input to MKL
      VLOG(1) << "MklInputConversionOp: No broadcast needed.";
      VLOG(1) << "MklInputConversionOp: Converting input " << tf_tensor_index
              << " to MKL format";

      // Create MklDnnShape for output Mkl tensor.
      Tensor* tensor_out;
      MklDnnShape mkl_output_mkl_shape;
      mkl_output_mkl_shape.SetMklTensor(true);
      mkl_output_mkl_shape.SetElemType(MklDnnType<T>());
      mkl_output_mkl_shape.SetTfLayout(mkl_shape->GetDimension(),
                                       mkl_shape->GetSizesAsMklDnnDims(),
                                       mkl_shape->GetTfDataFormat());
      // ** Temporarily borrow the layout from the MKL input **
      // output_mkl_md is a local copy, so unlike the MKL-ML path no explicit
      // release is performed before it goes out of scope.
      auto output_mkl_md = mkl_shape->GetMklLayout();
      mkl_output_mkl_shape.SetMklLayout(&output_mkl_md);

      // Create output Mkl tensor
      AllocateOutputSetMklShape(context, tf_tensor_index, &tensor_out,
                                mkl_tensor->shape(), mkl_output_mkl_shape);

      // Create MklDnnData object for input tensor. Input tensor is in
      // Tensorflow layout.
      auto cpu_engine = engine(engine::cpu, 0);
      MklDnnData<T> tf_input(&cpu_engine);
      auto input_tf_md = mkl_output_mkl_shape.GetTfLayout();
      tf_input.SetUsrMem(input_tf_md, tf_tensor);

      // Create reorder between tensorflow layout and Mkl layout and run it
      // eagerly; a reorder must have been added since the layouts differ.
      std::vector<primitive> net;
      CHECK_EQ(tf_input.CheckReorderToOpMem(
                   memory::primitive_desc(output_mkl_md, cpu_engine),
                   tensor_out, &net),
               true);
      stream(stream::kind::eager).submit(net).wait();

      // -- The tensor in MKL format passes through --
      ForwardMklTensorInToOut(context, mkl_tensor_index, mkl_tensor_index);
    } else {
      // Broadcast is needed, so convert the MKL input to TF
      VLOG(1) << "MklInputConversionOp: Broadcast needed.";
      VLOG(1) << "MklInputConversionOp: Converting input " << mkl_tensor_index
              << " to TF format";
      MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
                                           op_data_type, has_avx512f_,
                                           mkl_tensor_index);
      SetDummyMklShapeOutput(context, mkl_tensor_index);

      // The tensor in TF format passes through
      ForwardTfTensorInToOut(context, tf_tensor_index, tf_tensor_index);
    }

    VLOG(1) << "MklInputConversionOp: Shapes (output): "
            << context->mutable_output(0)->shape().DebugString() << " and "
            << context->mutable_output(1)->shape().DebugString();

    VLOG(1) << "MklInputConversion completed successfully.";
  }

 private:
  /// Data format of the operation ("data_format" attr, e.g. NHWC/NCHW)
  string data_format_str;

  /// Data type of the operation ("T" attr)
  DataType op_data_type;

  /// True if the CPU supports AVX512F; cached at construction time.
  bool has_avx512f_ = false;
};
    486 
    487 #endif
    488 
    489 ///////////////////////////////////////////////////////////
    490 //               Register kernel
    491 ///////////////////////////////////////////////////////////
    492 
// Registers the conversion kernel for type T on CPU. The kMklOpLabel label
// restricts it to graphs rewritten for MKL, pairing it with the MKL
// element-wise ops that consume its outputs.
#define REGISTER_CPU(T)                                             \
  REGISTER_KERNEL_BUILDER(Name("_MklInputConversion")               \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<T>("T")               \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklInputConversionOp<CPUDevice, T>);

// TODO(nhasabni): We cannot support all number types since MklDnn does
// not support all of them; only float is registered for now.
// TF_CALL_NUMBER_TYPES(REGISTER_CPU);
TF_CALL_float(REGISTER_CPU);
#undef REGISTER_CPU
    505 }  // namespace tensorflow
    506 #endif  // INTEL_MKL
    507