/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/meta_support.h"

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"

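// The gemmlowp meta fastpath is only compiled in when gemmlowp reports ARM
// NEON support (32- or 64-bit), it has not been disabled explicitly via
// TENSORFLOW_DISABLE_META, and the target is not an Apple platform.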
#if (defined(GEMMLOWP_NEON_32) || defined(GEMMLOWP_NEON_64)) && \
    !defined(TENSORFLOW_DISABLE_META) && !defined(__APPLE__)
#define TENSORFLOW_USE_META (1)
#endif

namespace tensorflow {
namespace meta {

namespace {

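// Global knobs for the meta fastpath. g_num_threads == 0 means "use the
// device's TensorFlow CPU worker thread count". g_enabled gates IsEnabled().
// g_use_local_context selects a gemmlowp-owned worker pool instead of the
// TensorFlow thread pool.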
int g_num_threads = 0;
bool g_enabled = true;
bool g_use_local_context = false;

#ifdef TENSORFLOW_USE_META

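// Scratch buffer shared by the meta kernels: 2 MiB of usable space plus
// kAlignment bytes of slack so the start of the buffer can be realigned to a
// 32-byte boundary.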
const int kAlignment = 32;
const int kScratchSize = 2048 * 1024 + kAlignment;

class Scratch : public ResourceBase {
 public:
  Scratch() : scratch_(new uint8_t[kScratchSize]) {
    // Align the usable region to a 32-byte boundary. The Scratch object owns
    // the underlying buffer.
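    // For example, if scratch_.get() % 32 == 29, the aligned pointer is
    // scratch_.get() + 3; if the buffer is already aligned, the pointer
    // advances by a full 32 bytes, which the kAlignment slack allows for.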
    scratch_32_aligned_ =
        scratch_.get() + kAlignment -
        (reinterpret_cast<uintptr_t>(scratch_.get()) % kAlignment);
  }

  uint8_t* buffer() { return scratch_32_aligned_; }

  string DebugString() { return "MetaGemmScratchResource"; }

 private:
  // Array form of unique_ptr so that the buffer allocated with new[] is
  // released with delete[].
  std::unique_ptr<uint8_t[]> scratch_;
  uint8_t* scratch_32_aligned_;
};

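// Returns the 32-byte-aligned scratch buffer, creating it on first use and
// caching it in the device's resource manager (container "MetaGemm", name
// "ScratchBuffer") so that meta kernels on the same device share one
// allocation.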
uint8_t* GetScratch(OpKernelContext* context) {
  Scratch* scratch = nullptr;
  std::function<Status(Scratch**)> creator = [](Scratch** resource) {
    *resource = new Scratch();
    return Status::OK();
  };
  Status s = context->resource_manager()->LookupOrCreate(
      "MetaGemm", "ScratchBuffer", &scratch, creator);
  if (!s.ok()) {
    context->CtxFailureWithWarning(s);
    return nullptr;
  }
  return scratch->buffer();
}

gemmlowp::WorkersPool* GetWorkersPool() {
  static gemmlowp::WorkersPool* pool = new gemmlowp::WorkersPool();
  return pool;
}

mutex& GetMutex() {
  static mutex mu(LINKER_INITIALIZED);
  return mu;
}

int GetWorkersCount(OpKernelContext* tf_context) {
  if (g_num_threads == 0) {
    return tf_context->device()->tensorflow_cpu_worker_threads()->num_threads;
  }
  return g_num_threads;
}

typedef gemmlowp::meta::SimpleContext<gemmlowp::WorkersPool> LocalContext;

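// Dispatches to a gemmlowp meta GEMM executor based on the output shape: very
// small outputs (m <= 4) use the LHS-cache-friendly executor with a 1/8/8
// kernel configuration; otherwise a 2/4/8 configuration is used, packing
// whichever operand is cache-friendlier (RHS when m >= n, LHS when m < n).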
template <typename Context, typename Params>
void MultiThreadGemm(Context* context, const Params& params) {
  if (params.m <= 4) {
    gemmlowp::meta::MultiThreadGemm<
        Context, gemmlowp::meta::GemmExecutorPackLHSCacheFriendly<>, Params, 1,
        8, 8>(context, params);
  } else {
    if (params.m >= params.n) {
      gemmlowp::meta::MultiThreadGemm<
          Context, gemmlowp::meta::GemmExecutorPackRHSCacheFriendly<>, Params,
          2, 4, 8>(context, params);
    } else {
      gemmlowp::meta::MultiThreadGemm<
          Context, gemmlowp::meta::GemmExecutorPackLHSCacheFriendly<>, Params,
          2, 4, 8>(context, params);
    }
  }
}

template <typename LeftStream, typename RightStream>
void QuantizedGemmImpl(OpKernelContext* tf_context, const quint8* a_data,
                       const quint8* b_data, qint32* c_data, int m, int n,
                       int k, int offset_a, int offset_b, int lda, int ldb,
                       int ldc) {
  typedef gemmlowp::meta::GemmParams<
      uint8_t, int32_t, LeftStream, RightStream,
      gemmlowp::meta::QuantizedStaticPreprocessedAsInt32,
      gemmlowp::meta::RowMajor>
      Params;
  Params params;

  params.m = m;
  params.n = n;
  params.k = k;

  params.lhs = reinterpret_cast<const uint8_t*>(&(a_data->value));
  params.rhs = reinterpret_cast<const uint8_t*>(&(b_data->value));
  params.result = reinterpret_cast<int32_t*>(&(c_data->value));
  params.scratch = CHECK_NOTNULL(GetScratch(tf_context));

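  // The zero-point offsets are folded into per-stream sums using
  //   sum_k (a - offset_a) * (b - offset_b)
  //       = sum_k a * b - offset_b * sum_k a - offset_a * sum_k b
  //         + k * offset_a * offset_b,
  // so the left stream's sums are scaled by offset_b and carry the constant
  // k * offset_a * offset_b term, while the right stream's sums are scaled by
  // offset_a.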
  params.left_stream.count = k;
  params.left_stream.stride = lda;
  params.left_stream.multiplicative_sum_offset = offset_b;
  params.left_stream.additive_sum_offset = k * offset_a * offset_b;

  params.right_stream.count = k;
  params.right_stream.stride = ldb;
  params.right_stream.multiplicative_sum_offset = offset_a;
  params.right_stream.additive_sum_offset = 0;

  params.fused_kernel.kernel.count = k;
  params.fused_kernel.output_stream.stride = ldc * sizeof(int32_t);

  if (g_use_local_context) {
    LocalContext local_context(GetWorkersCount(tf_context), GetWorkersPool());
    MultiThreadGemm<LocalContext, Params>(&local_context, params);
  } else {
    auto& workers = *(tf_context->device()->tensorflow_cpu_worker_threads());
    TensorflowGemmContext context(workers.num_threads, workers.workers);
    MultiThreadGemm<TensorflowGemmContext, Params>(&context, params);
  }
}

template <typename Params, int kernel_size>
void MultiThreadTransform1D(OpKernelContext* tf_context, const Params& params) {
  if (g_use_local_context) {
    LocalContext local_context(GetWorkersCount(tf_context), GetWorkersPool());
    gemmlowp::meta::MultiThreadTransform1D<LocalContext, Params, kernel_size>(
        &local_context, params);
  } else {
    auto& workers = *(tf_context->device()->tensorflow_cpu_worker_threads());
    TensorflowGemmContext context(workers.num_threads, workers.workers);
    gemmlowp::meta::MultiThreadTransform1D<TensorflowGemmContext, Params,
                                           kernel_size>(&context, params);
  }
}

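// For an N-bit quantized type, the affine mapping between a float range
// [min, max] and the integer range [0, 2^N - 1] uses
//   range_scale = (max - min) / (2^N - 1),
// e.g. (max - min) / 255 for uint8. CalculateOneOverRangeScale returns the
// reciprocal, guarding against min == max.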
template <typename QuantizedType>
double CalculateRangeScale(float min, float max) {
  const int bits = sizeof(QuantizedType) * 8;
  return static_cast<double>(max - min) /
         ((static_cast<int64_t>(1) << bits) - 1);
}

template <typename QuantizedType>
double CalculateOneOverRangeScale(float min, float max) {
  if (min == max) {
    return 0.0;
  }
  const int bits = sizeof(QuantizedType) * 8;
  return static_cast<double>((static_cast<int64_t>(1) << bits) - 1) /
         (max - min);
}

#endif  // TENSORFLOW_USE_META

}  // namespace

void SetNumThreads(int num_threads) { g_num_threads = num_threads; }

int GetNumThreads() { return g_num_threads; }

void SetUseLocalContext(bool use_local_context) {
  g_use_local_context = use_local_context;
}

bool GetUseLocalContext() { return g_use_local_context; }

bool IsSupported() {
#if defined(TENSORFLOW_USE_META)
  return true;
#else
  return false;
#endif
}

bool IsEnabled() { return g_enabled; }

void SetEnabled(bool enabled) { g_enabled = enabled; }

bool IsSupportedAndEnabled() { return IsSupported() && IsEnabled(); }

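// Public entry point for the quantized GEMM fastpath. Selects row- or
// column-major streams for each operand based on the transpose flags and runs
// the multi-threaded gemmlowp meta kernel; a library-wide mutex serializes
// calls into the meta kernels.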
void QuantizedGemm(OpKernelContext* tf_context, bool transpose_a,
                   bool transpose_b, const quint8* a_data, const quint8* b_data,
                   qint32* c_data, int m, int n, int k, int offset_a,
                   int offset_b, int lda, int ldb, int ldc) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  if (transpose_a) {
    if (transpose_b) {
      QuantizedGemmImpl<gemmlowp::meta::ColumnMajorWithSum,
                        gemmlowp::meta::RowMajorWithSum>(
          tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
          ldb, ldc);
    } else {
      QuantizedGemmImpl<gemmlowp::meta::ColumnMajorWithSum,
                        gemmlowp::meta::ColumnMajorWithSum>(
          tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
          ldb, ldc);
    }
  } else {
    if (transpose_b) {
      QuantizedGemmImpl<gemmlowp::meta::RowMajorWithSum,
                        gemmlowp::meta::RowMajorWithSum>(
          tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
          ldb, ldc);
    } else {
      QuantizedGemmImpl<gemmlowp::meta::RowMajorWithSum,
                        gemmlowp::meta::ColumnMajorWithSum>(
          tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
          ldb, ldc);
    }
  }
#else
  LOG(FATAL) << "QuantizedGemm: Meta fastpath not supported.";
#endif
}

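// Maps int32 values quantized over [input_min, input_max] to uint8 values
// quantized over [output_min, output_max]; conceptually a dequantize followed
// by a quantize, fused into one pass over the data.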
void Requantize(OpKernelContext* tf_context, const qint32* input, int count,
                float input_min, float input_max, float output_min,
                float output_max, quint8* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<int32_t, uint8_t,
                                            gemmlowp::meta::Requantize>
      Params;

  Params params;
  params.input = reinterpret_cast<const int32_t*>(input);
  params.output = reinterpret_cast<uint8_t*>(output);
  params.kernel.count = count;
  params.kernel.input_range_min = input_min;
  params.kernel.output_range_min = output_min;
  params.kernel.input_range_scale =
      CalculateRangeScale<int32_t>(input_min, input_max);
  params.kernel.one_over_output_range_scale =
      CalculateOneOverRangeScale<uint8_t>(output_min, output_max);
  params.kernel.input_range_offset =
      static_cast<float>(std::numeric_limits<int32_t>::lowest());

  // After adding output_range_offset, the value is cast from float to uint8.
  // The float-to-int/uint cast on NEON rounds toward zero. To keep the
  // rounding consistent with Eigen, which rounds to nearest, add 0.5f and
  // exploit the fact that this kernel only operates on non-negative values.
  // TODO(maciekc): fix the actual kernel in gemmlowp/meta.
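  // For example, 3.7f becomes 4.2f and truncates to 4, while 3.2f becomes
  // 3.7f and truncates to 3, matching round-to-nearest for non-negative
  // inputs.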
  params.kernel.output_range_offset =
      static_cast<float>(std::numeric_limits<uint8_t>::lowest()) + 0.5f;

  MultiThreadTransform1D<Params, 16>(tf_context, params);
#else
  LOG(FATAL) << "Requantize: Meta fastpath not supported.";
#endif
}

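// Converts uint8 values quantized over [range_min, range_max] back to floats;
// under TensorFlow's quantization scheme this is approximately
// range_min + value * (range_max - range_min) / 255.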
void Dequantize(OpKernelContext* tf_context, const quint8* input, int count,
                float range_min, float range_max, float* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<uint8_t, float,
                                            gemmlowp::meta::Dequantize>
      Params;

  Params params;
  params.input = reinterpret_cast<const uint8_t*>(input);
  params.output = reinterpret_cast<float*>(output);
  params.kernel.count = count;
  params.kernel.range_min = range_min;
  params.kernel.range_scale =
      CalculateRangeScale<uint8_t>(range_min, range_max);
  params.kernel.range_offset =
      static_cast<float>(std::numeric_limits<uint8_t>::lowest());

  MultiThreadTransform1D<Params, 16>(tf_context, params);
#else
  LOG(FATAL) << "Dequantize: Meta fastpath not supported.";
#endif
}

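// Converts floats to uint8 values quantized over [range_min, range_max]; the
// inverse of Dequantize, approximately
// round((input - range_min) * 255 / (range_max - range_min)).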
void Quantize(OpKernelContext* tf_context, const float* input, int count,
              float range_min, float range_max, quint8* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<float, uint8_t,
                                            gemmlowp::meta::Quantize>
      Params;

  Params params;
  params.input = reinterpret_cast<const float*>(input);
  params.output = reinterpret_cast<uint8_t*>(output);
  params.kernel.count = count;
  params.kernel.range_min = range_min;
  params.kernel.range_scale =
      CalculateOneOverRangeScale<uint8_t>(range_min, range_max);

  // After adding range_offset, the value is cast from float to uint8. The
  // float-to-int/uint cast on NEON rounds toward zero. To keep the rounding
  // consistent with Eigen, which rounds to nearest, add 0.5f and exploit the
  // fact that this kernel only operates on non-negative values (see the
  // worked example in Requantize above).
  // TODO(maciekc): fix the actual kernel in gemmlowp/meta.
  params.kernel.range_offset =
      static_cast<float>(std::numeric_limits<uint8_t>::lowest()) + 0.5f;

  MultiThreadTransform1D<Params, 16>(tf_context, params);
#else
  LOG(FATAL) << "Quantize: Meta fastpath not supported.";
#endif
}

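// Adds a quantized bias vector of bias_count elements to each row of the
// input (input_count is expected to be a multiple of bias_count), producing
// int32 output quantized over [output_min, output_max].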
void QuantizedBiasAdd(OpKernelContext* tf_context, const quint8* input,
                      int input_count, const quint8* bias, int bias_count,
                      float input_min, float input_max, float bias_min,
                      float bias_max, float output_min, float output_max,
                      qint32* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<uint8_t, int32_t,
                                            gemmlowp::meta::BiasAdd<uint8_t>>
      Params;

  Params params;
  params.input = reinterpret_cast<const uint8_t*>(input);
  params.output = reinterpret_cast<int32_t*>(output);
  params.kernel.bias = reinterpret_cast<const uint8_t*>(bias);
  params.kernel.count = bias_count;
  params.kernel.rows = input_count / bias_count;
  params.kernel.input_range_min = input_min;
  params.kernel.bias_range_min = bias_min;
  params.kernel.input_range_scale =
      CalculateRangeScale<uint8_t>(input_min, input_max);
  params.kernel.bias_range_scale =
      CalculateRangeScale<uint8_t>(bias_min, bias_max);
  params.kernel.input_range_offset = 0;
  params.kernel.bias_range_offset = 0;
  params.kernel.output_range_min = output_min;
  params.kernel.one_over_output_range_scale =
      CalculateOneOverRangeScale<int32_t>(output_min, output_max);
  params.kernel.output_range_offset =
      static_cast<float>(std::numeric_limits<int32_t>::lowest());

  // TODO(maciekc): add multithreading to bias add.
  // Right now this kernel does not support multi-threaded execution.
  gemmlowp::meta::Transform1D<Params, 16>(params);
#else
  LOG(FATAL) << "QuantizedBiasAdd: Meta fastpath not supported.";
#endif
}

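// Clamps each uint8 element of the input to the range [clamp_min, clamp_max].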
void Clamp(OpKernelContext* tf_context, const quint8* input, int count,
           quint8 clamp_min, quint8 clamp_max, quint8* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<uint8_t, uint8_t,
                                            gemmlowp::meta::MinMax<uint8_t>>
      Params;

  Params params;
  params.input = reinterpret_cast<const uint8_t*>(input);
  params.output = reinterpret_cast<uint8_t*>(output);
  params.kernel.count = count;
  params.kernel.min = clamp_min;
  params.kernel.max = clamp_max;

  MultiThreadTransform1D<Params, 16>(tf_context, params);
#else
  LOG(FATAL) << "Clamp: Meta fastpath not supported.";
#endif
}

}  // namespace meta
}  // namespace tensorflow