/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef INTEL_MKL

#include <memory>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
#include "tensorflow/core/util/mkl_util.h"

#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
using mkldnn::stream;
#endif

namespace tensorflow {
using CPUDevice = Eigen::ThreadPoolDevice;
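// MklReshapeOp implements the _MklReshape kernel. It behaves like the
// standard Reshape op, but also handles inputs that carry an MKL layout:
// such inputs are either forwarded with an updated shape or converted back
// to TensorFlow layout before the requested shape is applied.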
template <typename Device, typename T>
class MklReshapeOp : public OpKernel {
 public:
  explicit MklReshapeOp(OpKernelConstruction* context) : OpKernel(context) {}

#ifdef INTEL_MKL_ML
  void Compute(OpKernelContext* context) override {
    const Tensor& input = MklGetInput(context, 0);
    const Tensor& sizes = MklGetInput(context, 1);

    // Preliminary validation of sizes.
    OP_REQUIRES(context, IsLegacyVector(sizes.shape()),
                errors::InvalidArgument("sizes input must be 1-D, not shape ",
                                        sizes.shape().DebugString()));

    // Compute the output shape.  Determine product of specified
    // dimensions, and find the index of the unspecified one.
    TensorShape shape;
    int64 product = 1;
    int unknown_index = -1;
    switch (sizes.dtype()) {
      case DT_INT32:
        OP_REQUIRES_OK(context, ValidateSizes<int32>(sizes, &product,
                                                     &unknown_index, &shape));
        break;
      case DT_INT64:
        OP_REQUIRES_OK(context, ValidateSizes<int64>(sizes, &product,
                                                     &unknown_index, &shape));
        break;
      default:
        context->CtxFailure(errors::InvalidArgument(
            "desired shape must be a DT_INT32 or DT_INT64 vector, not a ",
            DataTypeString(sizes.dtype())));
        return;
    }
    if (unknown_index != -1) {
      OP_REQUIRES(
          context, product > 0,
          errors::InvalidArgument("Reshape cannot infer the missing input size "
                                  "for an empty tensor unless all specified "
                                  "input sizes are non-zero"));
      const int64 missing = input.NumElements() / product;
      OP_REQUIRES(
          context, product * missing == input.NumElements(),
          errors::InvalidArgument(
              "Input to reshape is a tensor with ", input.NumElements(),
              " values, but the requested shape requires a multiple of ",
              product));
      shape.set_dim(unknown_index, missing);
    }
    OP_REQUIRES(context, shape.num_elements() == input.NumElements(),
                errors::InvalidArgument("Input to reshape is a tensor with ",
                                        input.NumElements(),
                                        " values, but the requested shape has ",
                                        shape.num_elements()));

    MklShape mkl_shape_input;
    GetMklShape(context, 0, &mkl_shape_input);
    bool input_in_mkl_format = mkl_shape_input.IsMklTensor();
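    // If the input is in MKL layout, reconstruct its TensorFlow shape and,
    // when the requested shape differs, convert the data back to TensorFlow
    // layout into the newly allocated output tensor.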
    if (input_in_mkl_format) {
      TensorShape& shape_to = shape;
      TensorShape shape_from;
      for (size_t i = 0; i < mkl_shape_input.GetDimension(); i++) {
        // Outermost to innermost dimension
        shape_from.AddDim(
            mkl_shape_input.GetSizes()[mkl_shape_input.tf_dim_idx(i)]);
      }

      if (shape_from == shape_to) {
        CopyMklTensorInToOut(context, 0, 0);
        return;
      } else {
        // Allocate output tensor.
        Tensor* output_tensor = NULL;
        MklShape mkl_shape_output;
        mkl_shape_output.SetMklTensor(false);
        AllocateOutputSetMklShape(context, 0, &output_tensor, shape_to,
                                  mkl_shape_output);

        // Get output layout pointer.
        dnnLayout_t output_layout =
            static_cast<dnnLayout_t>(mkl_shape_input.GetTfLayout());

        // Execute DNNConversion.
        // Note: we assume an MKL tensor always has float as its data type.
        void* input_buffer =
            static_cast<void*>(const_cast<float*>(input.flat<float>().data()));
        void* output_buffer = static_cast<void*>(
            const_cast<float*>(output_tensor->flat<float>().data()));
        mkl_shape_input.GetConvertedFlatData(output_layout, input_buffer,
                                             output_buffer);

        VLOG(1) << "MKLToTFConversion completed successfully.";
        return;
      }
    } else {
      CopyTfTensorInToOutWithShape(context, 0, 0, shape);
    }
  }

#else

 private:
  // When the input tensor is in MKL layout and we are reshaping it to a
  // shape different from its actual shape, we use the MKLDNN reorder
  // primitive to put the tensor back in TensorFlow layout. But we can
  // sometimes skip this reordering. This function checks for all such cases.
  bool SkipReorder(const MklDnnShape& mkl_shape_input,
                   const TensorShape& reshape_to) {
    CHECK_EQ(mkl_shape_input.IsMklTensor(), true);
    bool ret = false;

    // If Tensorflow's data format and the underlying format maintained by
    // MKLDNN are equivalent (both are NHWC or both are NCHW), then we can
    // safely return true.
    auto input_mkl_md = mkl_shape_input.GetMklLayout();
    if (mkl_shape_input.GetTfDataFormat() == input_mkl_md.data.format) {
      ret = true;
    }

    return ret;
  }

 public:
  void Compute(OpKernelContext* context) override {
    const Tensor& input_tensor = MklGetInput(context, 0);
    const Tensor& sizes = MklGetInput(context, 1);

    MklDnnShape mkl_shape_input;
    GetMklShape(context, kInputSlotIdx, &mkl_shape_input);
    bool input_in_mkl_format = mkl_shape_input.IsMklTensor();
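    // For an MKL-layout input the element count comes from its TensorFlow
    // shape; otherwise use the tensor's own NumElements().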
    const int64 nelems = input_in_mkl_format
                             ? mkl_shape_input.GetTfShape().num_elements()
                             : input_tensor.NumElements();

    // Preliminary validation of sizes.
    OP_REQUIRES(context, IsLegacyVector(sizes.shape()),
                errors::InvalidArgument("sizes input must be 1-D, not shape ",
                                        sizes.shape().DebugString()));

    // Compute the output shape.  Determine product of specified
    // dimensions, and find the index of the unspecified one.
    TensorShape shape;
    int64 product = 1;
    int unknown_index = -1;
    switch (sizes.dtype()) {
      case DT_INT32:
        OP_REQUIRES_OK(context, ValidateSizes<int32>(sizes, &product,
                                                     &unknown_index, &shape));
        break;
      case DT_INT64:
        OP_REQUIRES_OK(context, ValidateSizes<int64>(sizes, &product,
                                                     &unknown_index, &shape));
        break;
      default:
        context->CtxFailure(errors::InvalidArgument(
            "desired shape must be a DT_INT32 or DT_INT64 vector, not a ",
            DataTypeString(sizes.dtype())));
        return;
    }
    if (unknown_index != -1) {
      OP_REQUIRES(
          context, product > 0,
          errors::InvalidArgument("Reshape cannot infer the missing input size "
                                  "for an empty tensor unless all specified "
                                  "input sizes are non-zero"));
      const int64 missing = nelems / product;
      OP_REQUIRES(
          context, product * missing == nelems,
          errors::InvalidArgument(
              "Input to reshape is a tensor with ", nelems,
              " values, but the requested shape requires a multiple of ",
              product));
      shape.set_dim(unknown_index, missing);
    }
    OP_REQUIRES(
        context, shape.num_elements() == nelems,
        errors::InvalidArgument("Input to reshape is a tensor with ", nelems,
                                " values, but the requested shape has ",
                                shape.num_elements()));

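    // If the input carries an MKL layout, we either forward it unchanged
    // (same shape), skip the reorder and only rewrite its MklDnnShape, or
    // reorder it back to TensorFlow layout into a plain output tensor. A
    // non-MKL input is simply copied with the new shape.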
    if (input_in_mkl_format) {
      TensorShape& shape_to = shape;
      TensorShape shape_from = mkl_shape_input.GetTfShape();
      if (shape_from == shape_to) {
        CopyMklTensorInToOut(context, kInputSlotIdx, kOutputSlotIdx);
        return;
      } else {
        try {
          auto cpu_engine = engine(engine::cpu, 0);
          MklDnnData<T> dnn_data_input(&cpu_engine);
          // Reshape is just a logical view change operation for a tensor.
          // It does not change the underlying layout. But MKLDNN may keep
          // tensor data in a layout different from the one specified by
          // TensorFlow. If MKLDNN maintains the input tensor in a layout
          // different from the one TensorFlow expects, we need to reorder
          // the tensor and then put it in the shape expected by TensorFlow.
          // But if MKLDNN has kept the input tensor in the layout TensorFlow
          // expects, we do not need to reorder the tensor contents; we just
          // need to update the MklDnnShape object associated with the input
          // tensor to reflect the shape change expected by Reshape.
          if (!SkipReorder(mkl_shape_input, shape_to)) {
            // If dimensions that are being expanded or collapsed are not
            // maintained contiguously by MKLDNN, then we use reorder.

            // Get Mkl layout of input tensor.
            auto input_mkl_md = mkl_shape_input.GetMklLayout();
            // Set input Mkl layout as the user layout.
            dnn_data_input.SetUsrMem(input_mkl_md, &input_tensor);
            // Get expected Tensorflow layout of input tensor.
            auto output_tf_md = mkl_shape_input.GetTfLayout();
            auto output_tf_pd =
                memory::primitive_desc(output_tf_md, cpu_engine);

            Tensor* output_tensor = nullptr;
            MklShape mkl_shape_output;
            mkl_shape_output.SetMklTensor(false);
            // We allocate output tensor in the shape expected by Reshape.
            AllocateOutputSetMklShape(context, kOutputSlotIdx, &output_tensor,
                                      shape_to, mkl_shape_output);

            // Insert a reorder between the Mkl layout and the TensorFlow
            // layout if needed. If a reorder is not needed but a reshape is
            // (since shape_from != shape_to), then we just copy the input
            // tensor to the output tensor with the target shape (we cannot
            // forward the Mkl layout in that case because the shape has
            // changed).
            std::vector<primitive> net;
            if (dnn_data_input.CheckReorderToOpMem(output_tf_pd, output_tensor,
                                                   &net)) {
              stream(stream::kind::eager).submit(net).wait();
            } else {
              output_tensor->CopyFrom(input_tensor, shape_to);
            }
            return;
          } else {
            // If dimensions that are being expanded or collapsed are
            // maintained contiguously by MKLDNN, then we skip the reorder,
            // just update the MklDnnShape object for the TensorFlow tensor,
            // and forward the TensorFlow tensor as is to the output.
            auto output_dims = TFShapeToMklDnnDims(shape_to);
            auto output_strides = CalculateTFStrides(output_dims);
            auto output_tf_md = MklDnnData<T>::CreateBlockedMemDesc(
                output_dims, output_strides);
            auto output_tf_pd =
                memory::primitive_desc(output_tf_md, cpu_engine);

            // Set MklDnnShape
            MklDnnShape mkl_shape_output;
            mkl_shape_output.SetMklTensor(true);
            mkl_shape_output.SetMklLayout(&output_tf_pd);
            mkl_shape_output.SetElemType(MklDnnType<T>());
            mkl_shape_output.SetTfLayout(output_dims.size(), output_dims,
                                         memory::format::blocked);

            // We now simply forward the input Mkl tensor to the output and
            // set the new MklDnnShape object on the output.
            ForwardMklTensorInToOutWithMklShape(
                context, kInputSlotIdx, kOutputSlotIdx, mkl_shape_output);
            return;
          }
        } catch (mkldnn::error& e) {
          string error_msg = "Status: " + std::to_string(e.status) +
                             ", message: " + string(e.message) + ", in file " +
                             string(__FILE__) + ":" + std::to_string(__LINE__);
          OP_REQUIRES_OK(
              context,
              errors::Aborted("Operation received an exception:", error_msg));
        }
      }
    } else {
      // If input tensor is not in Mkl format, then just copy Tensorflow tensor
      // to output with specified shape.
      CopyTfTensorInToOutWithShape(context, kInputSlotIdx, kOutputSlotIdx,
                                   shape);
    }
  }

#endif  // INTEL_MKL_ML

 private:
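  // Slot indices of the data input and output tensors handled by this kernel.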
  const int kInputSlotIdx = 0;
  const int kOutputSlotIdx = 0;

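  // Parses the requested shape from 'sizes': accumulates the product of the
  // specified (non-negative) dimensions into 'product', records the index of
  // the single -1 entry, if any, in 'unknown_index', and builds the partial
  // output shape in 'shape'.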
  template <typename Tshape>
  Status ValidateSizes(const Tensor& sizes, int64* product, int* unknown_index,
                       TensorShape* shape) {
    *product = 1;
    *unknown_index = -1;
    const int64 num_dims = sizes.NumElements();
    auto Svec = sizes.flat<Tshape>();
    for (int d = 0; d < num_dims; ++d) {
      const Tshape size = Svec(d);
      if (size == -1) {
        if (*unknown_index != -1) {
          return errors::InvalidArgument(
              "Only one input size may be -1, not both ", *unknown_index,
              " and ", d);
        }
        *unknown_index = d;
        shape->AddDim(1);
      } else if (size < 0) {
        return errors::InvalidArgument("Size ", d,
                                       " must be non-negative, not ", size);
      } else {
        shape->AddDim(size);
        (*product) *= size;
      }
    }
    return Status::OK();
  }
};

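// Register _MklReshape on CPU for float inputs, accepting the requested
// shape as either an int32 or an int64 vector kept in host memory.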
#define REGISTER_MKL_CPU(T)                                         \
  REGISTER_KERNEL_BUILDER(Name("_MklReshape")                       \
                              .Device(DEVICE_CPU)                   \
                              .HostMemory("shape")                  \
                              .TypeConstraint<T>("T")               \
                              .TypeConstraint<int32>("Tshape")      \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklReshapeOp<CPUDevice, T>);              \
  REGISTER_KERNEL_BUILDER(Name("_MklReshape")                       \
                              .Device(DEVICE_CPU)                   \
                              .HostMemory("shape")                  \
                              .TypeConstraint<T>("T")               \
                              .TypeConstraint<int64>("Tshape")      \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklReshapeOp<CPUDevice, T>);
TF_CALL_float(REGISTER_MKL_CPU);
#undef REGISTER_MKL_CPU
}  // namespace tensorflow

#endif  // INTEL_MKL