1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Implements a quantized eight-bit version of the matmul operation. 17 18 #define EIGEN_USE_THREADS 19 20 #define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK 21 #include "public/gemmlowp.h" 22 #include "tensorflow/core/framework/op_kernel.h" 23 #include "tensorflow/core/framework/tensor.h" 24 #include "tensorflow/core/kernels/meta_support.h" 25 #include "tensorflow/core/kernels/quantization_utils.h" 26 #include "tensorflow/core/kernels/reference_gemm.h" 27 #include "tensorflow/core/lib/core/errors.h" 28 29 namespace tensorflow { 30 31 // We have to break this out as a separate function because there are multiple 32 // combinations of transpose attributes we need to support, and they have to be 33 // compile-time constants to work with the templates used internally. 34 template <bool TransposeA, bool TransposeB, bool TransposeC> 35 void GemmlowpMultiply(OpKernelContext* op_context, const quint8* a_data, 36 const quint8* b_data, qint32* c_data, int m, int n, int k, 37 int offset_a, int offset_b, int lda, int ldb, int ldc) { 38 const uint8* a_data_as_uint8 = &(a_data->value); 39 const uint8* b_data_as_uint8 = &(b_data->value); 40 int32* c_data_as_int32 = &(c_data->value); 41 static const gemmlowp::MapOrder ResultOrder = 42 !TransposeC ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor; 43 static const gemmlowp::MapOrder LhsOrder = 44 !TransposeA ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor; 45 static const gemmlowp::MapOrder RhsOrder = 46 !TransposeB ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor; 47 gemmlowp::MatrixMap<const std::uint8_t, LhsOrder> lhs(a_data_as_uint8, m, k, 48 lda); 49 gemmlowp::MatrixMap<const std::uint8_t, RhsOrder> rhs(b_data_as_uint8, k, n, 50 ldb); 51 gemmlowp::MatrixMap<std::int32_t, ResultOrder> result(c_data_as_int32, m, n, 52 ldc); 53 const std::tuple<> empty_pipeline = {}; 54 auto& worker_threads = 55 *(op_context->device()->tensorflow_cpu_worker_threads()); 56 TensorflowGemmContext context(worker_threads.num_threads, 57 worker_threads.workers); 58 gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t, 59 gemmlowp::DefaultL8R8BitDepthParams>( 60 &context, lhs, rhs, &result, -offset_a, -offset_b, empty_pipeline); 61 // Since gemmlowp uses assembly to write to the output, msan won't detect 62 // the output buffer as written to, so we mark it manually. 63 TF_ANNOTATE_MEMORY_IS_INITIALIZED(c_data_as_int32, m * n * sizeof(int32)); 64 } 65 66 template <class T1, class T2, class Toutput> 67 class QuantizedMatMulOp : public OpKernel { 68 public: 69 explicit QuantizedMatMulOp(OpKernelConstruction* context) 70 : OpKernel(context) { 71 OP_REQUIRES_OK(context, context->GetAttr("transpose_a", &transpose_a_)); 72 OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &transpose_b_)); 73 } 74 75 void Compute(OpKernelContext* context) override { 76 const Tensor& a = context->input(0); 77 const Tensor& b = context->input(1); 78 const float min_a = context->input(2).flat<float>()(0); 79 const float max_a = context->input(3).flat<float>()(0); 80 const float min_b = context->input(4).flat<float>()(0); 81 const float max_b = context->input(5).flat<float>()(0); 82 83 // Make sure that we have valid quantization ranges for the input buffers. 84 // If the difference between the min and max is negative or zero, it makes 85 // it hard to do meaningful intermediate operations on the values. 86 OP_REQUIRES(context, (max_a > min_a), 87 errors::InvalidArgument("max_a must be larger than min_a.")); 88 OP_REQUIRES(context, (max_b > min_b), 89 errors::InvalidArgument("max_b must be larger than min_b.")); 90 const int32 offset_a = FloatToQuantizedUnclamped<T1>(0.0f, min_a, max_a); 91 const int32 offset_b = FloatToQuantizedUnclamped<T2>(0.0f, min_b, max_b); 92 const int32 offset_c = 0; 93 const int32 mult_c = 1; 94 const int32 shift_c = 0; 95 96 // Check that the dimensions of the two matrices are valid. 97 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(a.shape()), 98 errors::InvalidArgument("In[0] is not a matrix")); 99 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(b.shape()), 100 errors::InvalidArgument("In[1] is not a matrix")); 101 Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair; 102 dim_pair[0].first = transpose_a_ ? 0 : 1; 103 dim_pair[0].second = transpose_b_ ? 1 : 0; 104 105 OP_REQUIRES(context, 106 a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second), 107 errors::InvalidArgument( 108 "Matrix size-compatible: In[0]: ", a.shape().DebugString(), 109 ", In[1]: ", b.shape().DebugString())); 110 111 OP_REQUIRES(context, ((shift_c >= 0) && (shift_c <= 31)), 112 errors::InvalidArgument("shift_c must be between 0 and 31, " 113 "inclusive.")); 114 115 int a_dim_remaining = 1 - dim_pair[0].first; 116 int b_dim_remaining = 1 - dim_pair[0].second; 117 TensorShape out_shape( 118 {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)}); 119 Tensor* c = nullptr; 120 OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &c)); 121 CHECK(c); 122 123 const T1* a_data = a.flat<T1>().data(); 124 const T2* b_data = b.flat<T2>().data(); 125 Toutput* c_data = c->flat<Toutput>().data(); 126 127 const bool transpose_c = false; 128 const size_t m = a.dim_size(a_dim_remaining); 129 const size_t n = b.dim_size(b_dim_remaining); 130 const size_t k = a.dim_size(dim_pair[0].first); 131 const size_t lda = a.dim_size(1); 132 const size_t ldb = b.dim_size(1); 133 const size_t ldc = n; 134 135 if (meta::IsSupportedAndEnabled() && std::is_same<T1, quint8>() && 136 std::is_same<T2, quint8>() && std::is_same<Toutput, qint32>() && 137 (offset_c == 0) && (mult_c == 1) && (shift_c == 0) && 138 (transpose_c == false) && (k <= 2048)) { 139 // Gemmlowp/meta code path works on 32 & 64 bit Arm with NEON Simd and 140 // allows optimized quantized 8bit to 32bit gemm. 141 meta::QuantizedGemm(context, transpose_a_, transpose_b_, a_data, b_data, 142 c_data, m, n, k, -offset_a, -offset_b, lda, ldb, ldc); 143 } else if (std::is_same<T1, quint8>() && std::is_same<T2, quint8>() && 144 std::is_same<Toutput, qint32>() && (offset_c == 0) && 145 (mult_c == 1) && (shift_c == 0) && (transpose_c == false)) { 146 // The gemmlowp optimized library only works for a particular set of data 147 // types, so check if we meet those requirements and fall back to a slower 148 // reference implementation if not. 149 if (transpose_a_) { 150 if (transpose_b_) { 151 GemmlowpMultiply<true, true, false>(context, a_data, b_data, c_data, 152 m, n, k, offset_a, offset_b, lda, 153 ldb, ldc); 154 } else { 155 GemmlowpMultiply<true, false, false>(context, a_data, b_data, c_data, 156 m, n, k, offset_a, offset_b, lda, 157 ldb, ldc); 158 } 159 } else { 160 if (transpose_b_) { 161 GemmlowpMultiply<false, true, false>(context, a_data, b_data, c_data, 162 m, n, k, offset_a, offset_b, lda, 163 ldb, ldc); 164 } else { 165 GemmlowpMultiply<false, false, false>(context, a_data, b_data, c_data, 166 m, n, k, offset_a, offset_b, 167 lda, ldb, ldc); 168 } 169 } 170 } else { 171 ReferenceGemm<T1, T2, Toutput>( 172 transpose_a_, transpose_b_, transpose_c, m, n, k, a_data, offset_a, 173 lda, b_data, offset_b, ldb, c_data, shift_c, offset_c, mult_c, ldc); 174 } 175 176 float min_c_value; 177 float max_c_value; 178 QuantizationRangeForMultiplication<T1, T2, Toutput>( 179 min_a, max_a, min_b, max_b, &min_c_value, &max_c_value); 180 Tensor* c_min = nullptr; 181 OP_REQUIRES_OK(context, context->allocate_output(1, {}, &c_min)); 182 c_min->flat<float>()(0) = min_c_value; 183 184 Tensor* c_max = nullptr; 185 OP_REQUIRES_OK(context, context->allocate_output(2, {}, &c_max)); 186 c_max->flat<float>()(0) = max_c_value; 187 } 188 189 private: 190 bool transpose_a_; 191 bool transpose_b_; 192 }; 193 194 REGISTER_KERNEL_BUILDER(Name("QuantizedMatMul") 195 .Device(DEVICE_CPU) 196 .TypeConstraint<quint8>("T1") 197 .TypeConstraint<quint8>("T2") 198 .TypeConstraint<qint32>("Toutput"), 199 QuantizedMatMulOp<quint8, quint8, qint32>); 200 201 } // namespace tensorflow 202