Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
     17 #include "tensorflow/core/framework/tensor.h"
     18 #include "tensorflow/core/platform/test.h"
     19 #include "tensorflow/core/platform/test_benchmark.h"
     20 
     21 namespace tensorflow {
     22 
     23 template <typename T>
     24 static Graph* BatchMatmul(int b, int m, int k, int n, bool adjoint_a,
     25                           bool adjoint_b, DataType type) {
     26   Graph* g = new Graph(OpRegistry::Global());
     27   Tensor in0(type, adjoint_a ? TensorShape({b, k, m}) : TensorShape({b, m, k}));
     28   in0.flat<T>().setRandom();
     29   Tensor in1(type, adjoint_b ? TensorShape({b, n, k}) : TensorShape({b, k, n}));
     30   in1.flat<T>().setRandom();
     31   test::graph::BatchMatmul(g, test::graph::Constant(g, in0),
     32                            test::graph::Constant(g, in1), adjoint_a, adjoint_b);
     33   return g;
     34 }
     35 
     36 #define BM_BatchMatmulDev(B, M, K, N, TA, TB, T, TFTYPE, DEVICE)                  \
     37   static void                                                                     \
     38       BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE( \
     39           int iters) {                                                            \
     40     testing::UseRealTime();                                                       \
     41     testing::ItemsProcessed(static_cast<int64>(iters) * B * M * K * N * 2);       \
     42     test::Benchmark(#DEVICE, BatchMatmul<T>(B, M, K, N, TA, TB, TFTYPE))          \
     43         .Run(iters);                                                              \
     44   }                                                                               \
     45   BENCHMARK(                                                                      \
     46       BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE);
     47 
     48 #define BM_BatchMatmul(B, M, K, N, TA, TB) \
     49   BM_BatchMatmulDev(B, M, K, N, TA, TB, float, DT_FLOAT, cpu);
     50 // BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64,
     51 // cpu);
     52 //  BM_BatchMatmulDev(B, M, K, N, TA, TB, float, DT_FLOAT, gpu);
     53 /* Uncomment to enable benchmarks for double & complex types: */
     54 // BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64,
     55 // gpu);
     56 // BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, cpu); \
     57 // BM_BatchMatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, cpu);
     58 // \
     59 // BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, gpu); \
     60 // BM_BatchMatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, gpu);
     61 
     62 // Typical fully connected layers
     63 BM_BatchMatmul(1, 1, 1024, 1024, false, false);
     64 BM_BatchMatmul(1, 8, 1024, 1024, false, false);
     65 BM_BatchMatmul(1, 16, 1024, 1024, false, false);
     66 BM_BatchMatmul(1, 128, 1024, 1024, false, false);
     67 BM_BatchMatmul(2, 1, 1024, 1024, false, false);
     68 BM_BatchMatmul(2, 8, 1024, 1024, false, false);
     69 BM_BatchMatmul(2, 16, 1024, 1024, false, false);
     70 BM_BatchMatmul(2, 128, 1024, 1024, false, false);
     71 BM_BatchMatmul(8, 1, 1024, 1024, false, false);
     72 BM_BatchMatmul(8, 8, 1024, 1024, false, false);
     73 BM_BatchMatmul(8, 16, 1024, 1024, false, false);
     74 BM_BatchMatmul(8, 128, 1024, 1024, false, false);
     75 BM_BatchMatmul(32, 1, 1024, 1024, false, false);
     76 BM_BatchMatmul(32, 8, 1024, 1024, false, false);
     77 BM_BatchMatmul(32, 16, 1024, 1024, false, false);
     78 BM_BatchMatmul(32, 128, 1024, 1024, false, false);
     79 
     80 // Square matmul.
     81 BM_BatchMatmul(1, 32, 32, 32, false, false);
     82 BM_BatchMatmul(1, 128, 128, 128, false, false);
     83 BM_BatchMatmul(1, 256, 256, 256, false, false);
     84 BM_BatchMatmul(1, 1024, 1024, 1024, false, false);
     85 BM_BatchMatmul(1, 2048, 2048, 2048, false, false);
     86 BM_BatchMatmul(2, 32, 32, 32, false, false);
     87 BM_BatchMatmul(2, 128, 128, 128, false, false);
     88 BM_BatchMatmul(2, 256, 256, 256, false, false);
     89 BM_BatchMatmul(2, 1024, 1024, 1024, false, false);
     90 BM_BatchMatmul(2, 2048, 2048, 2048, false, false);
     91 BM_BatchMatmul(4, 32, 32, 32, false, false);
     92 BM_BatchMatmul(4, 128, 128, 128, false, false);
     93 BM_BatchMatmul(4, 256, 256, 256, false, false);
     94 BM_BatchMatmul(4, 1024, 1024, 1024, false, false);
     95 BM_BatchMatmul(4, 2048, 2048, 2048, false, false);
     96 BM_BatchMatmul(8, 32, 32, 32, false, false);
     97 BM_BatchMatmul(8, 128, 128, 128, false, false);
     98 BM_BatchMatmul(8, 256, 256, 256, false, false);
     99 BM_BatchMatmul(8, 1024, 1024, 1024, false, false);
    100 BM_BatchMatmul(8, 2048, 2048, 2048, false, false);
    101 BM_BatchMatmul(32, 32, 32, 32, false, false);
    102 BM_BatchMatmul(32, 128, 128, 128, false, false);
    103 BM_BatchMatmul(32, 256, 256, 256, false, false);
    104 BM_BatchMatmul(32, 1024, 1024, 1024, false, false);
    105 BM_BatchMatmul(32, 2048, 2048, 2048, false, false);
    106 
    107 // Matrix-vector multiplies.
    108 BM_BatchMatmul(1, 10000, 200, 1, false, false);
    109 BM_BatchMatmul(8, 10000, 200, 1, false, false);
    110 BM_BatchMatmul(32, 10000, 200, 1, false, false);
    111 BM_BatchMatmul(1, 10000, 200, 1, true, false);
    112 BM_BatchMatmul(8, 10000, 200, 1, true, false);
    113 BM_BatchMatmul(32, 10000, 200, 1, true, false);
    114 BM_BatchMatmul(1, 10000, 200, 1, false, true);
    115 BM_BatchMatmul(8, 10000, 200, 1, false, true);
    116 BM_BatchMatmul(32, 10000, 200, 1, false, true);
    117 BM_BatchMatmul(1, 10000, 200, 1, true, true);
    118 BM_BatchMatmul(8, 10000, 200, 1, true, true);
    119 BM_BatchMatmul(32, 10000, 200, 1, true, true);
    120 
    121 // Vector-matrix multiplies.
    122 BM_BatchMatmul(1, 1, 200, 10000, false, false);
    123 BM_BatchMatmul(8, 1, 200, 10000, false, false);
    124 BM_BatchMatmul(32, 1, 200, 10000, false, false);
    125 BM_BatchMatmul(1, 1, 200, 10000, true, false);
    126 BM_BatchMatmul(8, 1, 200, 10000, true, false);
    127 BM_BatchMatmul(32, 1, 200, 10000, true, false);
    128 BM_BatchMatmul(1, 1, 200, 10000, false, true);
    129 BM_BatchMatmul(8, 1, 200, 10000, false, true);
    130 BM_BatchMatmul(32, 1, 200, 10000, false, true);
    131 BM_BatchMatmul(1, 1, 200, 10000, true, true);
    132 BM_BatchMatmul(8, 1, 200, 10000, true, true);
    133 BM_BatchMatmul(32, 1, 200, 10000, true, true);
    134 
    135 }  // end namespace tensorflow
    136