/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef INTEL_MKL

#include <limits>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
#include "tensorflow/core/util/mkl_util.h"

#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"

using mkldnn::concat;
using mkldnn::stream;
#endif

namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;

// List of TensorShape objects. Used in Concat/Split layers.
typedef std::vector<TensorShape> TensorShapeList;

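// NAME_IS_CONCAT_DIM corresponds to the Concat op, which reads the
// concatenation dimension from its "concat_dim" input; NAME_IS_AXIS
// corresponds to ConcatV2, which reads it from its "axis" input.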
enum AxisArgumentName { NAME_IS_AXIS, NAME_IS_CONCAT_DIM };

// TODO(intelft) Check if we can reuse existing EigenConcatOp using Mutable
// reference inputs.
// --------------------------------------------------------------------------
//                      Eigen Concat Op
// --------------------------------------------------------------------------
template <typename Device, typename T, AxisArgumentName AxisArgName>
class EigenConcatBaseOp : public OpKernel {
 public:
  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
      ConstMatrixVector;

  explicit EigenConcatBaseOp(OpKernelConstruction* c) : OpKernel(c) {}

  // Although we modify Compute for this op to accept one extra parameter,
  // we still need an empty Compute here because Compute is a pure virtual
  // function in OpKernel.
  void Compute(OpKernelContext* c) {}

#ifdef INTEL_MKL_ML

  void Compute(OpKernelContext* c, const std::vector<Tensor>& values) {
    const Tensor* concat_dim_tensor;
    const char* axis_attribute_name =
        AxisArgName == NAME_IS_AXIS
            ? "axis"
            : AxisArgName == NAME_IS_CONCAT_DIM ? "concat_dim" : "<invalid>";
    OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor));
    OP_REQUIRES(c, IsLegacyScalar(concat_dim_tensor->shape()),
                errors::InvalidArgument(
                    axis_attribute_name,
                    " tensor should be a scalar integer, but got shape ",
                    concat_dim_tensor->shape().DebugString()));
    const int32 concat_dim =
        internal::SubtleMustCopy(concat_dim_tensor->scalar<int32>()());
    // Instead of accessing values from the context, we use the inputs passed
    // to this Compute overload.
    const int N = values.size();
    const int input_dims = values[0].dims();
    const TensorShape& input_shape = values[0].shape();

    int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
    OP_REQUIRES(c,
                (0 <= axis && axis < input_dims) ||
                    (allow_legacy_scalars() && concat_dim == 0),
                errors::InvalidArgument(
                    "ConcatOp : Expected concatenating dimensions in the range "
                    "[",
                    -input_dims, ", ", input_dims, "), but got ", concat_dim));
    // Note that we reduce the concat of n-dimensional tensors into a two
    // dimensional concat. Assuming the dimensions of any input/output
    // tensor are {x0, x1,...,xn-1, y0, y1,...,ym-1}, where the concat is along
    // the dimension indicated with size y0, we flatten it to {x, y}, where y =
    // Prod_i(yi) and x = ((n > 0) ? Prod_i(xi) : 1).
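    // For example, concatenating inputs of shapes {2, 3, 4} and {2, 5, 4}
    // along axis 1 flattens them to {2, 12} and {2, 20} matrices; the output
    // is then computed as a {2, 32} matrix and interpreted as shape {2, 8, 4}.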
    ConstMatrixVector inputs_flat;
    inputs_flat.reserve(N);
    int64 inputs_flat_dim0 = 1;
    for (int d = 0; d < axis; ++d) {
      inputs_flat_dim0 *= input_shape.dim_size(d);
    }
    int64 output_concat_dim = 0;
    const bool input_is_scalar = IsLegacyScalar(input_shape);
    for (int i = 0; i < N; ++i) {
      const auto in = values[i];
      const bool in_is_scalar = IsLegacyScalar(in.shape());
      OP_REQUIRES(
          c, in.dims() == input_dims || (input_is_scalar && in_is_scalar),
          errors::InvalidArgument(
              "ConcatOp : Ranks of all input tensors should match: shape[0] = ",
              input_shape.DebugString(), " vs. shape[", i,
              "] = ", in.shape().DebugString()));
      for (int j = 0; j < input_dims; ++j) {
        if (j == axis) {
          continue;
        }
        OP_REQUIRES(
            c, in.dim_size(j) == input_shape.dim_size(j),
            errors::InvalidArgument(
                "ConcatOp : Dimensions of inputs should match: shape[0] = ",
                input_shape.DebugString(), " vs. shape[", i,
                "] = ", in.shape().DebugString()));
      }
      if (in.NumElements() > 0) {
        int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0;
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
      }
      // TODO(irving): Remove check once !allow_legacy_scalars().
      output_concat_dim += in.dims() > 0 ? in.dim_size(axis) : 1;
    }

    TensorShape output_shape(input_shape);
    // TODO(irving): Remove rank 0 case once !allow_legacy_scalars().
    if (output_shape.dims() == 0) {
      output_shape.AddDim(output_concat_dim);
    } else {
      output_shape.set_dim(axis, output_concat_dim);
    }
    Tensor* output = nullptr;
    OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
    if (output->NumElements() > 0) {
      int64 output_dim1 = output->NumElements() / inputs_flat_dim0;
      auto output_flat = output->shaped<T, 2>({inputs_flat_dim0, output_dim1});
      ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
    }
  }

#else  // MKL_DNN

  void Compute(OpKernelContext* c, const std::vector<Tensor>& values,
               const TensorShapeList& input_shapes) {
    const Tensor* concat_dim_tensor;
    const char* axis_attribute_name =
        AxisArgName == NAME_IS_AXIS
            ? "axis"
            : AxisArgName == NAME_IS_CONCAT_DIM ? "concat_dim" : "<invalid>";
    OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor));
    OP_REQUIRES(c, IsLegacyScalar(concat_dim_tensor->shape()),
                errors::InvalidArgument(
                    axis_attribute_name,
                    " tensor should be a scalar integer, but got shape ",
                    concat_dim_tensor->shape().DebugString()));
    const int32 concat_dim =
        internal::SubtleMustCopy(concat_dim_tensor->scalar<int32>()());
    // Instead of accessing values from the context, we use the inputs passed
    // to this Compute overload.
    const int N = values.size();
    const int input_dims = input_shapes[0].dims();
    const TensorShape& input_shape = input_shapes[0];

    int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
    OP_REQUIRES(c,
                (0 <= axis && axis < input_dims) ||
                    (allow_legacy_scalars() && concat_dim == 0),
                errors::InvalidArgument(
                    "ConcatOp : Expected concatenating dimensions in the range "
                    "[",
                    -input_dims, ", ", input_dims, "), but got ", concat_dim));
    // Note that we reduce the concat of n-dimensional tensors into a two
    // dimensional concat. Assuming the dimensions of any input/output
    // tensor are {x0, x1,...,xn-1, y0, y1,...,ym-1}, where the concat is along
    // the dimension indicated with size y0, we flatten it to {x, y}, where y =
    // Prod_i(yi) and x = ((n > 0) ? Prod_i(xi) : 1).
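    // For instance, with inputs of shape {2, 3, 4, 5} and axis == 2,
    // inputs_flat_dim0 below evaluates to 2 * 3 = 6, and each input is
    // viewed as a {6, 20} matrix since 4 * 5 = 20 elements remain per row.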
    ConstMatrixVector inputs_flat;
    inputs_flat.reserve(N);
    int64 inputs_flat_dim0 = 1;
    for (int d = 0; d < axis; ++d) {
      inputs_flat_dim0 *= input_shape.dim_size(d);
    }
    int64 output_concat_dim = 0;
    const bool input_is_scalar = IsLegacyScalar(input_shape);
    for (int i = 0; i < N; ++i) {
      const auto in = values[i];
      const bool in_is_scalar = IsLegacyScalar(input_shapes[i]);
      OP_REQUIRES(
          c,
          (input_shapes[i].dims() == input_dims) ||
              (input_is_scalar && in_is_scalar),
          errors::InvalidArgument(
              "ConcatOp : Ranks of all input tensors should match: shape[0] = ",
              input_shape.DebugString(), " vs. shape[", i,
              "] = ", input_shapes[i].DebugString()));
      if (in.NumElements() > 0) {
        int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0;
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
      }
      output_concat_dim +=
          input_shapes[i].dims() > 0 ? input_shapes[i].dim_size(axis) : 1;
    }

    TensorShape output_shape(input_shape);
    if (output_shape.dims() == 0) {
      output_shape.AddDim(output_concat_dim);
    } else {
      output_shape.set_dim(axis, output_concat_dim);
    }
    Tensor* output = nullptr;
    OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
    if (output->NumElements() > 0) {
      int64 output_dim1 = output->NumElements() / inputs_flat_dim0;
      auto output_flat = output->shaped<T, 2>({inputs_flat_dim0, output_dim1});
      ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
    }
  }

#endif
};

#ifdef INTEL_MKL_ML

// --------------------------------------------------------------------------
//                      Mkl Concat Op
// --------------------------------------------------------------------------

template <typename Device, typename T, AxisArgumentName AxisArgName>
class MklConcatOp : public OpKernel {
 private:
  TensorFormat data_format_;
  EigenConcatBaseOp<Device, T, AxisArgName> eigen_concat_op_;

 public:
  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
      ConstMatrixVector;

  explicit MklConcatOp(OpKernelConstruction* c)
      : OpKernel(c), eigen_concat_op_(c) {}

  void Compute(OpKernelContext* context) override {
    MklConcatOpContext mkl_context;

    // Get input tensors.
    OpInputList input_tensors;
    GetMklInputList(context, "values", &input_tensors);
    const int N = input_tensors.size();
    // Get MKL shapes.
    MklShapeList input_shapes(N);
    GetMklShapeList(context, "values", &input_shapes);

    // If this is Concat, then concat_dim is the 0th input.
    // If this is ConcatV2, then axis is the Nth input.
    const Tensor& concat_dim_tensor = AxisArgName == NAME_IS_CONCAT_DIM
                                          ? MklGetInput(context, 0)
                                          : MklGetInput(context, N);
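    // For example, with N == 3, Concat's inputs are ordered
    // {concat_dim, values[0..2]}, while ConcatV2's are {values[0..2], axis},
    // so the dimension tensor sits at input index 0 or N respectively.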

    // Sanity checks
    OP_REQUIRES(
        context, IsLegacyScalar(concat_dim_tensor.shape()),
        errors::InvalidArgument(
            "Concat dim tensor should be a scalar integer, but got shape ",
            concat_dim_tensor.shape().DebugString()));
    int32 concat_dim =
        internal::SubtleMustCopy(concat_dim_tensor.scalar<int32>()());

    MklShape& inpshape0 = input_shapes[0];

    // Check that all tensors are Mkl; if not, we call the Eigen version.
    bool invoke_eigen = false;
    bool is_concat_dim_channel = true;
    if (!AreAllMklTensors(input_shapes)) {
      invoke_eigen = true;
    }

    // Check that the total number of dimensions is 4; if not, call Eigen.
    if (!invoke_eigen) {
      for (auto& s : input_shapes) {
        if (s.GetDimension() != 4) {
          invoke_eigen = true;
          break;
        }
      }
    }

    // Check that concat_dim is the channel dimension; if not, call the Eigen
    // version.
    if (!invoke_eigen) {
      for (auto& s : input_shapes) {
        if (!s.IsMklChannelDim(concat_dim)) {
          invoke_eigen = true;
          is_concat_dim_channel = false;
          break;
        }
      }
    }

    if (invoke_eigen) {
      string msg = std::string("Invoking Eigen version of Concat. Reason: ") +
                   (!is_concat_dim_channel
                        ? std::string("Concat dimension is not channel")
                        : std::string("Not all tensors are in Mkl layout"));
      VLOG(1) << "_MklConcatOp: " << msg;
      CallEigenVersion(context, input_tensors, input_shapes);
      return;
    }

    // For the MKL format, the channel is dimension number 2. So if we are
    // concatenating over channel and _all_ inputs are in MKL format, then we
    // set concat_dim to 2. Since we have reached this point, we know we are
    // concatenating over channel.
    concat_dim = MklDims::C;
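    // (This assumes the MklDims enum in mkl_util.h, which orders the MKL
    // dimensions as W = 0, H = 1, C = 2, N = 3, so MklDims::C is index 2.)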

    // One more sanity check: check that the ranks of all tensors match
    // and that their shapes match except for concat_dim.
    int i = 0;
    for (auto& s : input_shapes) {
      size_t exp_dims = inpshape0.GetDimension();
      OP_REQUIRES(context, s.GetDimension() == exp_dims,
                  errors::InvalidArgument(
                      "_MklConcatOp : Ranks of all input tensors should match:"
                      " input dimensions = ",
                      s.GetDimension(), " vs. expected rank = ", exp_dims));

      for (int d = 0; d < exp_dims; ++d) {
        if (d == concat_dim) {
          continue;
        }

        size_t exp_size = inpshape0.GetSizes()[d];
        OP_REQUIRES(
            context, exp_size == s.GetSizes()[d],
            errors::InvalidArgument("_MklConcatOp : Dimensions of inputs "
                                    "should match: shape[0][",
                                    d, "] = ", exp_size, " vs. shape[", i, "][",
                                    d, "] = ", s.GetSizes()[d]));
      }
      ++i;
    }

    // Use the input MKL layouts instead of creating new layouts.
    int64 output_concat_dim_size = 0;
    for (auto& s : input_shapes) {
      output_concat_dim_size +=
          s.GetDimension() > 0 ? s.GetSizes()[concat_dim] : 1;
    }
    mkl_context.MklCreateInputLayouts(context, input_shapes);
    OP_REQUIRES_OK(context, context->status());

    CHECK_EQ(dnnConcatCreate_F32(&mkl_context.prim_concat, NULL, N,
                                 &mkl_context.lt_inputs[0]),
             E_SUCCESS);

    // Calculate output sizes and strides.
    TensorFormat data_format;
    if (inpshape0.IsTensorInNHWCFormat()) {
      data_format = FORMAT_NHWC;
    } else {
      OP_REQUIRES(
          context, inpshape0.IsTensorInNCHWFormat(),
          errors::InvalidArgument(
              "_MklConcat only supports all inputs in NCHW or NHWC format "));
      data_format = FORMAT_NCHW;
    }

    // Since all tensors are in Mkl layout, we copy sizes from the input
    // tensor.
    mkl_context.out_sizes[MklDims::W] = inpshape0.GetSizes()[MklDims::W];
    mkl_context.out_sizes[MklDims::H] = inpshape0.GetSizes()[MklDims::H];
    mkl_context.out_sizes[MklDims::C] = output_concat_dim_size;
    mkl_context.out_sizes[MklDims::N] = inpshape0.GetSizes()[MklDims::N];
    GetStridesFromSizes(data_format, mkl_context.out_strides,
                        mkl_context.out_sizes);

    // Set output Mkl shape.
    int64 dim = 4;
    MklShape mkl_output_mkl_shape;
    mkl_output_mkl_shape.SetMklTensor(true);
    mkl_output_mkl_shape.SetMklLayout(mkl_context.prim_concat, dnnResourceDst);
    mkl_output_mkl_shape.SetTfLayout(dim, mkl_context.out_sizes,
                                     mkl_context.out_strides);
    mkl_output_mkl_shape.SetTfDimOrder(dim, inpshape0.GetTfToMklDimMap());

    TensorShape mkl_output_tf_shape;
    mkl_output_tf_shape.AddDim(1);
    mkl_output_tf_shape.AddDim(
        dnnLayoutGetMemorySize_F32(
            static_cast<dnnLayout_t>(mkl_output_mkl_shape.GetMklLayout())) /
        sizeof(T));

    Tensor* output = nullptr;
    AllocateOutputSetMklShape(context, 0, &output, mkl_output_tf_shape,
                              mkl_output_mkl_shape);

    // Set destination resource.
    mkl_context.concat_res[dnnResourceDst] =
        const_cast<void*>(static_cast<const void*>(output->flat<T>().data()));

    mkl_context.mkl_tmp_tensors.resize(N);
    mkl_context.MklPrepareConcatInputs(context, input_tensors);
    OP_REQUIRES_OK(context, context->status());

    // Execute primitive.
    CHECK_EQ(dnnExecute_F32(mkl_context.prim_concat, mkl_context.concat_res),
             E_SUCCESS);

    mkl_context.MklCleanup();
    OP_REQUIRES_OK(context, context->status());
  }

 private:
  typedef struct {
    TensorFormat data_format;
    size_t out_sizes[4];
    size_t out_strides[4];
    dnnPrimitive_t prim_concat;
    void* concat_res[dnnResourceNumber];
    std::vector<dnnLayout_t> lt_inputs;
    std::vector<Tensor> mkl_tmp_tensors;

    // Create MKL dnnLayout_t objects for the tensors coming into the layer.
    // We only support the case where all input tensors are in Mkl layout.
    void MklCreateInputLayouts(OpKernelContext* context,
                               MklShapeList& input_shapes) {
      for (auto& is : input_shapes) {
        CHECK_EQ(is.IsMklTensor(), true);
        lt_inputs.push_back((dnnLayout_t)is.GetCurLayout());
      }
    }

    void MklPrepareConcatInputs(OpKernelContext* context,
                                OpInputList& input_tensors) {
      CHECK_EQ(lt_inputs.size(), mkl_tmp_tensors.size());

      for (int i = 0; i < lt_inputs.size(); ++i) {
        dnnPrimitive_t mkl_prim_convert_input;
        dnnLayout_t mkl_lt_internal_input;
        void* mkl_buf_convert_input = nullptr;

        CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
                     &mkl_lt_internal_input, prim_concat,
                     (dnnResourceType_t)(dnnResourceMultipleSrc + i)),
                 E_SUCCESS);

        if (!dnnLayoutCompare_F32(lt_inputs[i], mkl_lt_internal_input)) {
          CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input,
                                           lt_inputs[i], mkl_lt_internal_input),
                   E_SUCCESS);

          AllocTmpBuffer(context, &mkl_tmp_tensors[i], mkl_lt_internal_input,
                         &mkl_buf_convert_input);

          CHECK_EQ(dnnConversionExecute_F32(
                       mkl_prim_convert_input,
                       const_cast<void*>(static_cast<const void*>(
                           input_tensors[i].flat<T>().data())),
                       mkl_buf_convert_input),
                   E_SUCCESS);

          concat_res[dnnResourceMultipleSrc + i] = mkl_buf_convert_input;
          CHECK_EQ(dnnDelete_F32(mkl_prim_convert_input), E_SUCCESS);
        } else {
          concat_res[dnnResourceMultipleSrc + i] = const_cast<void*>(
              static_cast<const void*>(input_tensors[i].flat<T>().data()));
        }

        CHECK_EQ(dnnLayoutDelete_F32(mkl_lt_internal_input), E_SUCCESS);
      }
    }

    void MklCleanup() {
      for (auto& lt : lt_inputs) {
        lt = nullptr;
      }
      CHECK_EQ(dnnDelete_F32(prim_concat), E_SUCCESS);
    }
  } MklConcatOpContext;

  void CallEigenVersion(OpKernelContext* context, const OpInputList& values,
                        const MklShapeList& input_shapes) {
    // Before calling the Eigen version, we need to convert Mkl tensors to TF.
    // First check that the number of input tensors and the number of Mkl
    // shapes match.
    CHECK_EQ(values.size(), input_shapes.size());

    std::vector<Tensor> converted_values;
    for (int i = 0; i < input_shapes.size(); i++) {
      if (input_shapes[i].IsMklTensor()) {
        // If the input tensor is Mkl, then do the conversion.
        Tensor tmp_tensor =
            ConvertMklToTF<T>(context, values[i], input_shapes[i]);
        converted_values.push_back(tmp_tensor);
      } else {
        // If the input tensor is TF already, then we do not need any
        // conversion.
        converted_values.push_back(values[i]);
      }
    }

    // Call Eigen concat.
    eigen_concat_op_.Compute(context, converted_values);

    // Set a dummy Mkl tensor as the output Mkl tensor for this op.
    MklShape mkl_tensor_mkl_shape;
    mkl_tensor_mkl_shape.SetMklTensor(false);
    mkl_tensor_mkl_shape.SetDimensions(4);
    mkl_tensor_mkl_shape.SetTfDimOrder(4);  // Dimensions
    Tensor* mkl_tensor = nullptr;
    TensorShape mkl_tensor_tf_shape;
    mkl_tensor_tf_shape.AddDim(
        SIZE_OF_MKL_SERIAL_DATA(mkl_tensor_mkl_shape.GetDimension()));
    int tf_output_index = 0;
    context->allocate_output(
        GetTensorMetaDataIndex(tf_output_index, context->num_outputs()),
        mkl_tensor_tf_shape, &mkl_tensor);
    mkl_tensor_mkl_shape.SerializeMklShape(
        mkl_tensor->flat<uint8>().data(),
        mkl_tensor->flat<uint8>().size() * sizeof(uint8));
  }

  // Overloaded method that takes the input shapes as a list of TensorShapes.
  void CallEigenVersion(OpKernelContext* context, const OpInputList& values,
                        const TensorShapeList& input_shapes) {
    CHECK_EQ(values.size(), input_shapes.size());

    std::vector<Tensor> converted_values;
    for (int i = 0; i < input_shapes.size(); i++) {
      converted_values.push_back(values[i]);
    }

    // Call Eigen concat.
    eigen_concat_op_.Compute(context, converted_values);

    // Set a dummy Mkl tensor as the output Mkl tensor for this op.
    MklShape mkl_tensor_mkl_shape;
    mkl_tensor_mkl_shape.SetMklTensor(false);
    mkl_tensor_mkl_shape.SetDimensions(4);
    Tensor* mkl_tensor = nullptr;
    TensorShape mkl_tensor_tf_shape;
    mkl_tensor_tf_shape.AddDim(
        SIZE_OF_MKL_SERIAL_DATA(mkl_tensor_mkl_shape.GetDimension()));
    int tf_output_index = 0;
    context->allocate_output(
        GetTensorMetaDataIndex(tf_output_index, context->num_outputs()),
        mkl_tensor_tf_shape, &mkl_tensor);
    mkl_tensor_mkl_shape.SerializeMklShape(
        mkl_tensor->flat<uint8>().data(),
        mkl_tensor->flat<uint8>().size() * sizeof(uint8));
  }
};

#else

// --------------------------------------------------------------------------
//                      Mkl Concat Op
// --------------------------------------------------------------------------

template <typename Device, typename T, AxisArgumentName AxisArgName>
class MklConcatOp : public OpKernel {
 private:
  TensorFormat data_format_;
  EigenConcatBaseOp<Device, T, AxisArgName> eigen_concat_op_;

 public:
  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
      ConstMatrixVector;

  explicit MklConcatOp(OpKernelConstruction* c)
      : OpKernel(c), eigen_concat_op_(c) {}

  void Compute(OpKernelContext* context) override {
    try {
      auto cpu_engine = engine(engine::cpu, 0);
      OpInputList input_tensors;
      GetMklInputList(context, "values", &input_tensors);
      const int N = input_tensors.size();

      // Get Tensor shapes.
      std::vector<MklDnnShape> input_shapes(N);
      GetMklShapeList(context, "values", &input_shapes);

      const Tensor& concat_dim_tensor = (AxisArgName == NAME_IS_CONCAT_DIM)
                                            ? MklGetInput(context, 0)
                                            : MklGetInput(context, N);
      // Sanity checks
      OP_REQUIRES(
          context, IsLegacyScalar(concat_dim_tensor.shape()),
          errors::InvalidArgument(
              "Concat dim tensor should be a scalar integer, but got shape ",
              concat_dim_tensor.shape().DebugString()));
      int32 concat_dim =
          internal::SubtleMustCopy(concat_dim_tensor.scalar<int32>()());

      // Check that the ranks of all tensors match
      // and that their shapes match except for concat_dim.
      int i = 0;
      bool invoke_eigen = false;
      bool are_all_mkl_inputs = true, are_all_tf_inputs = true;
      const TensorShape expected_shape = input_shapes[0].IsMklTensor()
                                             ? input_shapes[0].GetTfShape()
                                             : input_tensors[0].shape();
      size_t expected_dims = expected_shape.dims();

      if (concat_dim < 0) concat_dim = expected_dims + concat_dim;

      for (auto& s : input_shapes) {
        if (s == expected_shape) {
          ++i;
          continue;
        }

        TensorShape s_shape =
            s.IsMklTensor() ? s.GetTfShape() : input_tensors[i].shape();
        size_t s_dims = s_shape.dims();

        OP_REQUIRES(
            context, s_dims == expected_dims,
            errors::InvalidArgument(
                "_MklConcatOp : Ranks of all input tensors should match:"
                " input dimensions = ",
                s_dims, " vs. expected rank = ", expected_dims));

        for (int d = 0; d < expected_dims; ++d) {
          if (d == concat_dim) continue;

          size_t expected_size = expected_shape.dim_size(d);
          size_t s_size = s_shape.dim_size(d);
          OP_REQUIRES(
              context, expected_size == s_size,
              errors::InvalidArgument("_MklConcatOp : Dimensions of inputs "
                                      "should match: shape[0][",
                                      d, "] = ", expected_size, " vs. shape[",
                                      i, "][", d, "] = ", s_size));
        }

        if (s.IsMklTensor())
          are_all_tf_inputs = false;
        else
          are_all_mkl_inputs = false;

        if (s_dims != 4) invoke_eigen = true;
        ++i;
      }

      // Not all inputs are in one format (TF or MKL); this is the mixed
      // input case. We could potentially optimize it by converting the
      // TF-format inputs to Mkl format and avoiding the Eigen fallback,
      // but currently we fall back to Eigen for this case.
      if (!are_all_tf_inputs && !are_all_mkl_inputs) invoke_eigen = true;

      // Call the Eigen library.
      if (invoke_eigen) {
        TensorShapeList tf_input_shapes;
        i = 0;
        for (auto& s : input_shapes) {
          TensorShape s_shape =
              s.IsMklTensor() ? s.GetTfShape() : input_tensors[i].shape();
          tf_input_shapes.push_back(s_shape);
          ++i;
        }
        CallEigenVersion(context, input_tensors, tf_input_shapes);
        return;
      }

      memory::dims dst_dims;
      if (are_all_mkl_inputs)
        dst_dims = TFShapeToMklDnnDims(input_shapes[0].GetTfShape());
      else
        // When all the inputs are in Tensorflow format, we don't know what
        // the input data format is. In that case, we just use an output
        // format that is the same as the input format.
        dst_dims = TFShapeToMklDnnDims(input_tensors[0].shape());

      std::vector<memory::primitive_desc> srcs_pd;
      std::vector<MklDnnData<T>> srcs(N, MklDnnData<T>(&cpu_engine));
      int64 dst_concat_dim_size = 0;
      for (int k = 0; k < N; k++) {
        bool is_mkl_tensor = input_shapes[k].IsMklTensor();
        memory::dims src_dims;

        // The same comment as for dst_dims applies to src_dims.
        src_dims = (is_mkl_tensor)
                       ? TFShapeToMklDnnDims(input_shapes[k].GetTfShape())
                       : TFShapeToMklDnnDims(input_tensors[k].shape());

        dst_concat_dim_size += src_dims[concat_dim];
        auto src_md =
            is_mkl_tensor ? input_shapes[k].GetMklLayout() :
                          // It does not matter what data format we use here
                          // (NHWC or NCHW). We just need to ensure that the
                          // output of Concat uses the same data format as
                          // the input.
                memory::desc(src_dims, MklDnnType<T>(), memory::format::nchw);

        srcs[k].SetUsrMem(src_md, &input_tensors[k]);
        auto src_mpd = srcs[k].GetUsrMemPrimDesc();
        srcs_pd.push_back(src_mpd);
      }
      dst_dims[concat_dim] = dst_concat_dim_size;

      MklDnnData<T> dst(&cpu_engine);
      memory::desc dst_md({}, memory::data_undef, memory::format_undef);
      memory::dims dst_dims_in_nchw;
      if (are_all_mkl_inputs) {
        // Since we are passing a specific format for the destination,
        // we need to have dst_dims in MklDnn order (NCHW).
        auto orig_tf_format = input_shapes[0].GetTfDataFormat();
        dst_dims_in_nchw = MklDnnDimsInNCHW(
            dst_dims, MklDnnDataFormatToTFDataFormat(orig_tf_format));
        // We set the output in the same format as the input to avoid layout
        // conversions. Currently the dst format is simply the input format;
        // see if we can make this choice in a better way.
        dst_md = memory::desc(
            dst_dims_in_nchw, MklDnnType<T>(),
            (memory::format)input_shapes[0].GetMklLayout().data.format);
      } else {
        // Again, the format does not matter here. We just need to make it
        // the same as the input format.
        dst_md = memory::desc(dst_dims, MklDnnType<T>(), memory::format::nchw);
      }

      std::vector<primitive::at> inputs;
      for (int k = 0; k < input_tensors.size(); k++)
        inputs.push_back(srcs[k].GetOpMem());

      // If all inputs are in MKL format, then the meaning of concat_dim
      // needs to change. The value of concat_dim is tied to the input
      // Tensorflow data format (NHWC or NCHW). MklDnn dimensions are in NCHW
      // order, so if the Tensorflow tensors are in NCHW order, the concat_dim
      // semantics are preserved; but if the input tensors are in NHWC order,
      // the semantics need to change. E.g., if we are concatenating over
      // Channel (dimension 3 for NHWC), then since the MklDnn order is NCHW,
      // concat_dim needs to be 1.
      if (are_all_mkl_inputs) concat_dim = input_shapes[0].TfDimIdx(concat_dim);
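      // For instance, for NHWC inputs and concat_dim == 3 (channel),
      // TfDimIdx(3) is expected to return the channel's NCHW position,
      // i.e. 1, per the example in the comment above.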

      auto concat_pd = concat::primitive_desc(dst_md, concat_dim, srcs_pd);

      MklDnnShape dnn_shape_dst;
      TensorShape tf_shape_dst;
      Tensor* dst_tensor = nullptr;
      if (are_all_mkl_inputs) {
        dnn_shape_dst.SetMklTensor(true);
        auto dst_pd = concat_pd.dst_primitive_desc();
        dnn_shape_dst.SetMklLayout(&dst_pd);
        dnn_shape_dst.SetElemType(MklDnnType<T>());
        dnn_shape_dst.SetTfLayout(dst_dims.size(), dst_dims_in_nchw,
                                  input_shapes[0].GetTfDataFormat());
        tf_shape_dst.AddDim((dst_pd.get_size() / sizeof(T)));
      } else {
        dnn_shape_dst.SetMklTensor(false);
        tf_shape_dst = MklDnnDimsToTFShape(dst_dims);
      }
      AllocateOutputSetMklShape(context, 0, &dst_tensor, tf_shape_dst,
                                dnn_shape_dst);
      CHECK_NOTNULL(dst_tensor);

      dst_md =
          dnn_shape_dst.IsMklTensor() ? dnn_shape_dst.GetMklLayout() : dst_md;
      dst.SetUsrMem(dst_md, dst_tensor);

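      // Build the concat primitive from its descriptor and run it eagerly on
      // a stream: it reads each srcs[k] memory and writes the concatenated
      // result into dst's memory.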
      auto concat_op = concat(concat_pd, inputs, dst.GetOpMem());
      std::vector<primitive> net;
      net.push_back(concat_op);
      stream(stream::kind::eager).submit(net).wait();
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception: ", error_msg));
    }
  }

  void CallEigenVersion(OpKernelContext* context, const OpInputList& values,
                        const TensorShapeList& input_shapes) {
    CHECK_EQ(values.size(), input_shapes.size());

    std::vector<Tensor> converted_values;
    for (int i = 0; i < input_shapes.size(); i++)
      converted_values.push_back(values[i]);

    // Call Eigen concat.
    eigen_concat_op_.Compute(context, converted_values, input_shapes);

    // Set a dummy Mkl tensor as the output Mkl tensor for this op.
    MklDnnShape dnn_shape_output;
    dnn_shape_output.SetMklTensor(false);
    dnn_shape_output.SetDimensions(4);
    Tensor* output_tensor = nullptr;
    TensorShape tf_shape_output;
    tf_shape_output.AddDim(dnn_shape_output.GetSerializeBufferSize());
    context->allocate_output(GetTensorMetaDataIndex(0, context->num_outputs()),
                             tf_shape_output, &output_tensor);
    dnn_shape_output.SerializeMklDnnShape(
        output_tensor->flat<uint8>().data(),
        output_tensor->flat<uint8>().size() * sizeof(uint8));
  }
};

#endif

/* Use optimized concat for float type only */
#define REGISTER_MKL_CPU(type)                                              \
  REGISTER_KERNEL_BUILDER(Name("_MklConcat")                                \
                              .Device(DEVICE_CPU)                           \
                              .TypeConstraint<type>("T")                    \
                              .HostMemory("concat_dim")                     \
                              .Label(mkl_op_registry::kMklOpLabel),         \
                          MklConcatOp<CPUDevice, type, NAME_IS_CONCAT_DIM>) \
  REGISTER_KERNEL_BUILDER(Name("_MklConcatV2")                              \
                              .Device(DEVICE_CPU)                           \
                              .TypeConstraint<type>("T")                    \
                              .TypeConstraint<int32>("Tidx")                \
                              .HostMemory("axis")                          \
                              .Label(mkl_op_registry::kMklOpLabel),         \
                          MklConcatOp<CPUDevice, type, NAME_IS_AXIS>)

TF_CALL_float(REGISTER_MKL_CPU);
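// TF_CALL_float expands REGISTER_MKL_CPU(float) only, so these MKL concat
// kernels are registered solely for the float type.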

#undef REGISTER_MKL_CPU
}  // namespace tensorflow

#endif  // INTEL_MKL