Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // Implements a quantized eight-bit version of the matmul operation.
     17 
     18 #define EIGEN_USE_THREADS
     19 
     20 #define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
     21 #include "public/gemmlowp.h"
     22 #include "tensorflow/core/framework/op_kernel.h"
     23 #include "tensorflow/core/framework/tensor.h"
     24 #include "tensorflow/core/kernels/meta_support.h"
     25 #include "tensorflow/core/kernels/quantization_utils.h"
     26 #include "tensorflow/core/kernels/reference_gemm.h"
     27 #include "tensorflow/core/lib/core/errors.h"
     28 
     29 namespace tensorflow {
     30 
     31 // We have to break this out as a separate function because there are multiple
     32 // combinations of transpose attributes we need to support, and they have to be
     33 // compile-time constants to work with the templates used internally.
     34 template <bool TransposeA, bool TransposeB, bool TransposeC>
     35 void GemmlowpMultiply(OpKernelContext* op_context, const quint8* a_data,
     36                       const quint8* b_data, qint32* c_data, int m, int n, int k,
     37                       int offset_a, int offset_b, int lda, int ldb, int ldc) {
     38   const uint8* a_data_as_uint8 = &(a_data->value);
     39   const uint8* b_data_as_uint8 = &(b_data->value);
     40   int32* c_data_as_int32 = &(c_data->value);
     41   static const gemmlowp::MapOrder ResultOrder =
     42       !TransposeC ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor;
     43   static const gemmlowp::MapOrder LhsOrder =
     44       !TransposeA ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor;
     45   static const gemmlowp::MapOrder RhsOrder =
     46       !TransposeB ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor;
     47   gemmlowp::MatrixMap<const std::uint8_t, LhsOrder> lhs(a_data_as_uint8, m, k,
     48                                                         lda);
     49   gemmlowp::MatrixMap<const std::uint8_t, RhsOrder> rhs(b_data_as_uint8, k, n,
     50                                                         ldb);
     51   gemmlowp::MatrixMap<std::int32_t, ResultOrder> result(c_data_as_int32, m, n,
     52                                                         ldc);
     53   const std::tuple<> empty_pipeline = {};
     54   auto& worker_threads =
     55       *(op_context->device()->tensorflow_cpu_worker_threads());
     56   TensorflowGemmContext context(worker_threads.num_threads,
     57                                 worker_threads.workers);
     58   gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t,
     59                                    gemmlowp::DefaultL8R8BitDepthParams>(
     60       &context, lhs, rhs, &result, -offset_a, -offset_b, empty_pipeline);
     61   // Since gemmlowp uses assembly to write to the output, msan won't detect
     62   // the output buffer as written to, so we mark it manually.
     63   TF_ANNOTATE_MEMORY_IS_INITIALIZED(c_data_as_int32, m * n * sizeof(int32));
     64 }
     65 
     66 template <class T1, class T2, class Toutput>
     67 class QuantizedMatMulOp : public OpKernel {
     68  public:
     69   explicit QuantizedMatMulOp(OpKernelConstruction* context)
     70       : OpKernel(context) {
     71     OP_REQUIRES_OK(context, context->GetAttr("transpose_a", &transpose_a_));
     72     OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &transpose_b_));
     73   }
     74 
     75   void Compute(OpKernelContext* context) override {
     76     const Tensor& a = context->input(0);
     77     const Tensor& b = context->input(1);
     78     const float min_a = context->input(2).flat<float>()(0);
     79     const float max_a = context->input(3).flat<float>()(0);
     80     const float min_b = context->input(4).flat<float>()(0);
     81     const float max_b = context->input(5).flat<float>()(0);
     82 
     83     // Make sure that we have valid quantization ranges for the input buffers.
     84     // If the difference between the min and max is negative or zero, it makes
     85     // it hard to do meaningful intermediate operations on the values.
     86     OP_REQUIRES(context, (max_a > min_a),
     87                 errors::InvalidArgument("max_a must be larger than min_a."));
     88     OP_REQUIRES(context, (max_b > min_b),
     89                 errors::InvalidArgument("max_b must be larger than min_b."));
     90     const int32 offset_a = FloatToQuantizedUnclamped<T1>(0.0f, min_a, max_a);
     91     const int32 offset_b = FloatToQuantizedUnclamped<T2>(0.0f, min_b, max_b);
     92     const int32 offset_c = 0;
     93     const int32 mult_c = 1;
     94     const int32 shift_c = 0;
     95 
     96     // Check that the dimensions of the two matrices are valid.
     97     OP_REQUIRES(context, TensorShapeUtils::IsMatrix(a.shape()),
     98                 errors::InvalidArgument("In[0] is not a matrix"));
     99     OP_REQUIRES(context, TensorShapeUtils::IsMatrix(b.shape()),
    100                 errors::InvalidArgument("In[1] is not a matrix"));
    101     Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
    102     dim_pair[0].first = transpose_a_ ? 0 : 1;
    103     dim_pair[0].second = transpose_b_ ? 1 : 0;
    104 
    105     OP_REQUIRES(context,
    106                 a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
    107                 errors::InvalidArgument(
    108                     "Matrix size-compatible: In[0]: ", a.shape().DebugString(),
    109                     ", In[1]: ", b.shape().DebugString()));
    110 
    111     OP_REQUIRES(context, ((shift_c >= 0) && (shift_c <= 31)),
    112                 errors::InvalidArgument("shift_c must be between 0 and 31, "
    113                                         "inclusive."));
    114 
    115     int a_dim_remaining = 1 - dim_pair[0].first;
    116     int b_dim_remaining = 1 - dim_pair[0].second;
    117     TensorShape out_shape(
    118         {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)});
    119     Tensor* c = nullptr;
    120     OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &c));
    121     CHECK(c);
    122 
    123     const T1* a_data = a.flat<T1>().data();
    124     const T2* b_data = b.flat<T2>().data();
    125     Toutput* c_data = c->flat<Toutput>().data();
    126 
    127     const bool transpose_c = false;
    128     const size_t m = a.dim_size(a_dim_remaining);
    129     const size_t n = b.dim_size(b_dim_remaining);
    130     const size_t k = a.dim_size(dim_pair[0].first);
    131     const size_t lda = a.dim_size(1);
    132     const size_t ldb = b.dim_size(1);
    133     const size_t ldc = n;
    134 
    135     if (meta::IsSupportedAndEnabled() && std::is_same<T1, quint8>() &&
    136         std::is_same<T2, quint8>() && std::is_same<Toutput, qint32>() &&
    137         (offset_c == 0) && (mult_c == 1) && (shift_c == 0) &&
    138         (transpose_c == false) && (k <= 2048)) {
    139       // Gemmlowp/meta code path works on 32 & 64 bit Arm with NEON Simd and
    140       // allows optimized quantized 8bit to 32bit gemm.
    141       meta::QuantizedGemm(context, transpose_a_, transpose_b_, a_data, b_data,
    142                           c_data, m, n, k, -offset_a, -offset_b, lda, ldb, ldc);
    143     } else if (std::is_same<T1, quint8>() && std::is_same<T2, quint8>() &&
    144                std::is_same<Toutput, qint32>() && (offset_c == 0) &&
    145                (mult_c == 1) && (shift_c == 0) && (transpose_c == false)) {
    146       // The gemmlowp optimized library only works for a particular set of data
    147       // types, so check if we meet those requirements and fall back to a slower
    148       // reference implementation if not.
    149       if (transpose_a_) {
    150         if (transpose_b_) {
    151           GemmlowpMultiply<true, true, false>(context, a_data, b_data, c_data,
    152                                               m, n, k, offset_a, offset_b, lda,
    153                                               ldb, ldc);
    154         } else {
    155           GemmlowpMultiply<true, false, false>(context, a_data, b_data, c_data,
    156                                                m, n, k, offset_a, offset_b, lda,
    157                                                ldb, ldc);
    158         }
    159       } else {
    160         if (transpose_b_) {
    161           GemmlowpMultiply<false, true, false>(context, a_data, b_data, c_data,
    162                                                m, n, k, offset_a, offset_b, lda,
    163                                                ldb, ldc);
    164         } else {
    165           GemmlowpMultiply<false, false, false>(context, a_data, b_data, c_data,
    166                                                 m, n, k, offset_a, offset_b,
    167                                                 lda, ldb, ldc);
    168         }
    169       }
    170     } else {
    171       ReferenceGemm<T1, T2, Toutput>(
    172           transpose_a_, transpose_b_, transpose_c, m, n, k, a_data, offset_a,
    173           lda, b_data, offset_b, ldb, c_data, shift_c, offset_c, mult_c, ldc);
    174     }
    175 
    176     float min_c_value;
    177     float max_c_value;
    178     QuantizationRangeForMultiplication<T1, T2, Toutput>(
    179         min_a, max_a, min_b, max_b, &min_c_value, &max_c_value);
    180     Tensor* c_min = nullptr;
    181     OP_REQUIRES_OK(context, context->allocate_output(1, {}, &c_min));
    182     c_min->flat<float>()(0) = min_c_value;
    183 
    184     Tensor* c_max = nullptr;
    185     OP_REQUIRES_OK(context, context->allocate_output(2, {}, &c_max));
    186     c_max->flat<float>()(0) = max_c_value;
    187   }
    188 
    189  private:
    190   bool transpose_a_;
    191   bool transpose_b_;
    192 };
    193 
    194 REGISTER_KERNEL_BUILDER(Name("QuantizedMatMul")
    195                             .Device(DEVICE_CPU)
    196                             .TypeConstraint<quint8>("T1")
    197                             .TypeConstraint<quint8>("T2")
    198                             .TypeConstraint<qint32>("Toutput"),
    199                         QuantizedMatMulOp<quint8, quint8, qint32>);
    200 
    201 }  // namespace tensorflow
    202