/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.
#ifdef INTEL_MKL
#ifndef INTEL_MKL_ML

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/tensor_format.h"

#include "mkldnn.h"
#include "mkldnn_types.h"
#include "tensorflow/core/platform/default/logging.h"
#include "tensorflow/core/util/mkl_util.h"

#include "mkldnn.hpp"
using mkldnn::prop_kind;
using mkldnn::softmax_forward;
using mkldnn::stream;

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename Device, typename T>
class MklSoftmaxOp : public OpKernel {
 public:
  ~MklSoftmaxOp() {}

  explicit MklSoftmaxOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    try {
      auto cpu_engine = engine(engine::cpu, 0);
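      // The engine above names the execution device (CPU instance 0); all
      // memory objects and primitives below are created against it.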

      // Grab the 0-th input tensor of the OpKernelContext "context".
      size_t src_idx = 0;
      const Tensor& src_tensor = MklGetInput(context, src_idx);

      // Get the MklDnnShape metadata carried alongside the input.
      MklDnnShape src_mkl_shape;
      GetMklShape(context, src_idx, &src_mkl_shape);

      // src_dims holds the dimensions of src_tensor; the dims of dst will be
      // the same as src_dims.
      auto src_tf_shape = src_mkl_shape.IsMklTensor()
                              ? src_mkl_shape.GetTfShape()
                              : src_tensor.shape();
      auto src_dims = TFShapeToMklDnnDims(src_tf_shape);
      auto output_dims = src_dims;
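      // Illustrative example (hypothetical shapes): a 2D logits input of TF
      // shape [8, 10] yields src_dims = {8, 10}, and hence output_dims =
      // {8, 10} as well.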

      // Create MklDnnData wrappers (defined in mkl_util.h) around the
      // MKL-DNN memory for src and dst.
      MklDnnData<T> src(&cpu_engine);
      MklDnnData<T> dst(&cpu_engine);

      // If the input is in MKL layout, simply grab that layout; otherwise,
      // construct a TF layout for it. In the TF case, the input shape
      // (src_dims) is required in MKL-DNN order, but the layout itself is
      // TensorFlow's.
      auto src_md =
          src_mkl_shape.IsMklTensor()
              ? src_mkl_shape.GetMklLayout()
              : memory::desc(src_dims, MklDnnType<T>(), memory::format::nc);
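      // For T = float and the hypothetical src_dims = {8, 10} from above,
      // the TF branch builds the equivalent of:
      //   memory::desc({8, 10}, memory::data_type::f32, memory::format::nc)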

      // src: set the user memory and the op memory descriptor. The two calls
      // below (both from mkl_util.h) map the TF "src_tensor" onto the MKL
      // tensor object "src". The data format is "nc" for both src and dst,
      // since the src and dst buffers are always 2D.
      src.SetUsrMem(src_md, &src_tensor);
      src.SetOpMemDesc(src_dims, memory::format::nc);

      // Create the softmax forward op descriptor and its primitive
      // descriptor.
      int axis = 1;  // axis along which softmax is applied
      auto softmax_fwd_desc = softmax_forward::desc(prop_kind::forward_scoring,
                                                    src.GetOpMemDesc(), axis);
      auto softmax_fwd_pd =
          softmax_forward::primitive_desc(softmax_fwd_desc, cpu_engine);
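      // With the 2D "nc" layout, axis = 1 normalizes over the channel
      // (class) dimension: for the hypothetical {8, 10} input, each of the
      // 8 output rows sums to 1.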

      // Prepare the output tensor and its shape metadata.
      Tensor* output_tensor = nullptr;
      MklDnnShape output_mkl_shape;
      TensorShape output_tf_shape;  // shape of the output TF tensor
      // The MKL-DNN softmax output layout is the same as the input layout.
      auto dst_pd = src.GetUsrMemPrimDesc();

      // If the input is in MKL shape, the output is in MKL shape as well;
      // if the input is in TF shape, so is the output.
      if (src_mkl_shape.IsMklTensor()) {
        output_mkl_shape.SetMklTensor(true);
        output_mkl_shape.SetMklLayout(&dst_pd);
        output_mkl_shape.SetElemType(MklDnnType<T>());
        output_mkl_shape.SetTfLayout(output_dims.size(), output_dims,
                                     memory::format::nc);
        output_tf_shape.AddDim((dst_pd.get_size() / sizeof(T)));
      } else {  // output is in TF shape as well
        output_mkl_shape.SetMklTensor(false);
        output_tf_shape = MklDnnDimsToTFShape(output_dims);
      }
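      // Note on the MKL branch above: get_size() returns bytes, so the TF
      // shape is the flat element count of the MKL buffer (e.g. [80] for
      // the hypothetical {8, 10} f32 input); the logical shape is kept in
      // output_mkl_shape.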
      // Allocate the output with MKL or TF shape, per the above.
      AllocateOutputSetMklShape(context, 0, &output_tensor, output_tf_shape,
                                output_mkl_shape);

      // output_dims and src_dims are the same, so src_md is reused for dst.
      dst.SetUsrMem(src_md, output_tensor);

      // Finally, create the softmax primitive from the primitive descriptor
      // and the src and dst memory.
      auto softmax_fwd =
          softmax_forward(softmax_fwd_pd, src.GetOpMem(), dst.GetOpMem());

      // Execute the net by pushing it to a stream; the following three
      // steps are common to all MKL-DNN ops.
      std::vector<primitive> net;
      net.push_back(softmax_fwd);
      stream(stream::kind::eager).submit(net).wait();
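      // wait() blocks until the submitted primitives have finished, so
      // output_tensor holds the softmax result once control reaches here.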
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }
};

/* Register DNN kernels for the supported operations and types; currently
 * only Softmax with f32. */
#define REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES(type)          \
  REGISTER_KERNEL_BUILDER(Name("_MklSoftmax")                       \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklSoftmaxOp<CPUDevice, type>);
TF_CALL_float(REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES);
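
// Illustrative note: registering another type would, in principle, just be
// one more instantiation, e.g.
//   TF_CALL_double(REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES);
// (hypothetical; it assumes the MKL-DNN softmax primitive supports the type).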

}  // namespace tensorflow

#endif  // INTEL_MKL_ML
#endif  // INTEL_MKL