Home | History | Annotate | Download | only in aot
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // Generated by the tf_library build rule.  DO NOT EDIT!
     17 //
     18 // This file contains a test and benchmark for the function generated by
     19 // tfcompile.  All tokens of the form `{{TFCOMPILE_*}}` must be rewritten to
     20 // real values before this file can be compiled.
     21 //
     22 //    TFCOMPILE_HEADER    : Path to the header file generated by tfcompile.
     23 //    TFCOMPILE_CPP_CLASS : Name of the C++ class generated by tfcompile.
     24 //    TFCOMPILE_NAME      : Name for tests and benchmarks.
     25 //
     26 // The tf_library bazel macro in tfcompile.bzl performs the token rewriting, and
     27 // generates a cc_test rule for you.
     28 
     29 // These macros must be defined before eigen files are included.
     30 #define EIGEN_USE_THREADS
     31 #define EIGEN_USE_CUSTOM_THREAD_POOL
     32 
     33 // clang-format off
     34 #include "{{TFCOMPILE_HEADER}}"  // NOLINT(whitespace/braces)
     35 // clang-format on
     36 
     37 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     38 #include "tensorflow/core/platform/cpu_info.h"
     39 #include "tensorflow/core/platform/test.h"
     40 #include "tensorflow/core/platform/test_benchmark.h"
     41 
     42 // Macros that expand to tokens based on the entry point name.
     43 // clang-format off
     44 #define CPP_CLASS {{TFCOMPILE_CPP_CLASS}}  // NOLINT(whitespace/braces)
     45 #define TEST_NAME {{TFCOMPILE_NAME}}Test   // NOLINT(whitespace/braces)
     46 #define BM_NAME   BM_{{TFCOMPILE_NAME}}    // NOLINT(whitespace/braces)
     47 // clang-format on
     48 
     49 namespace tensorflow {
     50 namespace tfcompile {
     51 namespace {
     52 
     53 void zero_buffers(void** bufs, const intptr_t* sizes, size_t n) {
     54   for (int i = 0; i < n; ++i) {
     55     if (sizes[i] != -1) {
     56       memset(bufs[i], 0, sizes[i]);
     57     }
     58   }
     59 }
     60 
     61 // Trivial test that runs the generated function to ensure it doesn't crash.
     62 TEST(TEST_NAME, NoCrash) {
     63   Eigen::ThreadPool pool(port::NumSchedulableCPUs());
     64   Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());
     65 
     66   CPP_CLASS computation;
     67   computation.set_thread_pool(&device);
     68   zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs);
     69 
     70   EXPECT_TRUE(computation.Run());
     71 }
     72 
     73 // Simple benchmark that repeatedly runs the generated function.
     74 void BM_NAME(int iters) {
     75   testing::StopTiming();
     76 
     77   Eigen::ThreadPool pool(port::NumSchedulableCPUs());
     78   Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());
     79 
     80   CPP_CLASS computation;
     81   computation.set_thread_pool(&device);
     82   zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs);
     83 
     84   testing::StartTiming();
     85   while (--iters) {
     86     computation.Run();
     87   }
     88   testing::StopTiming();
     89 }
     90 BENCHMARK(BM_NAME);
     91 
     92 }  // namespace
     93 }  // namespace tfcompile
     94 }  // namespace tensorflow
     95