Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
     17 #include "tensorflow/core/framework/tensor.h"
     18 #include "tensorflow/core/platform/test.h"
     19 #include "tensorflow/core/platform/test_benchmark.h"
     20 
     21 namespace tensorflow {
     22 
     23 template <typename T>
     24 static Graph* Matmul(int m, int k, int n, bool transpose_a, bool transpose_b,
     25                      DataType type) {
     26   Graph* g = new Graph(OpRegistry::Global());
     27   Tensor in0(type, transpose_a ? TensorShape({k, m}) : TensorShape({m, k}));
     28   in0.flat<T>().setRandom();
     29   Tensor in1(type, transpose_b ? TensorShape({n, k}) : TensorShape({k, n}));
     30   in1.flat<T>().setRandom();
     31   test::graph::Matmul(g, test::graph::Constant(g, in0),
     32                       test::graph::Constant(g, in1), transpose_a, transpose_b);
     33   return g;
     34 }
     35 
     36 #define BM_MatmulDev(M, K, N, TA, TB, T, TFTYPE, DEVICE)                       \
     37   static void BM_Matmul##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE( \
     38       int iters) {                                                             \
     39     testing::UseRealTime();                                                    \
     40     testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * 2);        \
     41     test::Benchmark(#DEVICE, Matmul<T>(M, K, N, TA, TB, TFTYPE)).Run(iters);   \
     42   }                                                                            \
     43   BENCHMARK(BM_Matmul##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE);
     44 
     45 #define BM_Matmul(M, K, N, TA, TB)                                       \
     46   BM_MatmulDev(M, K, N, TA, TB, float, DT_FLOAT, cpu);                   \
     47   BM_MatmulDev(M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64, cpu); \
     48   BM_MatmulDev(M, K, N, TA, TB, float, DT_FLOAT, gpu);                   \
     49   BM_MatmulDev(M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64, gpu); \
     50 /* Uncomment to enable benchmarks for double/complex128: */              \
     51 // BM_MatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, cpu);                   \
     52 // BM_MatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, cpu); \
     53 // BM_MatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, gpu);                   \
     54 // BM_MatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, gpu);
     55 
     56 // Batch size of 1 included for inference.
     57 // Typical fully connected layers
     58 BM_Matmul(1, 512, 512, false, false);
     59 BM_Matmul(8, 512, 512, false, false);
     60 BM_Matmul(16, 512, 512, false, false);
     61 BM_Matmul(128, 512, 512, false, false);
     62 
     63 BM_Matmul(1, 1024, 1024, false, false);
     64 BM_Matmul(8, 1024, 1024, false, false);
     65 BM_Matmul(16, 1024, 1024, false, false);
     66 BM_Matmul(128, 1024, 1024, false, false);
     67 BM_Matmul(4096, 4096, 4096, false, false);
     68 
     69 // Backward for fully connected layers
     70 BM_Matmul(1, 1024, 1024, false, true);
     71 BM_Matmul(8, 1024, 1024, false, true);
     72 BM_Matmul(16, 1024, 1024, false, true);
     73 BM_Matmul(128, 1024, 1024, false, true);
     74 
     75 // Forward softmax with large output size
     76 BM_Matmul(1, 200, 10000, false, false);
     77 BM_Matmul(8, 200, 10000, false, false);
     78 BM_Matmul(20, 200, 10000, false, false);
     79 BM_Matmul(20, 200, 20000, false, false);
     80 
     81 // Backward softmax with large output size
     82 BM_Matmul(1, 10000, 200, false, true);
     83 BM_Matmul(1, 10000, 200, false, false);
     84 BM_Matmul(8, 10000, 200, false, true);
     85 BM_Matmul(20, 10000, 200, false, true);
     86 BM_Matmul(20, 20000, 200, false, true);
     87 
     88 // Test some matrix-vector multiplies.
     89 BM_Matmul(50, 50, 1, false, false);
     90 BM_Matmul(50, 50, 1, true, false);
     91 BM_Matmul(50, 50, 1, false, true);
     92 BM_Matmul(50, 50, 1, true, true);
     93 BM_Matmul(500, 500, 1, false, false);
     94 BM_Matmul(500, 500, 1, true, false);
     95 BM_Matmul(500, 500, 1, false, true);
     96 BM_Matmul(500, 500, 1, true, true);
     97 BM_Matmul(2000, 2000, 1, false, false);
     98 BM_Matmul(2000, 2000, 1, true, false);
     99 BM_Matmul(2000, 2000, 1, false, true);
    100 BM_Matmul(2000, 2000, 1, true, true);
    101 
    102 // Test some vector-matrix multiplies.
    103 BM_Matmul(1, 50, 50, false, false);
    104 BM_Matmul(1, 50, 50, true, false);
    105 BM_Matmul(1, 50, 50, false, true);
    106 BM_Matmul(1, 50, 50, true, true);
    107 BM_Matmul(1, 500, 500, false, false);
    108 BM_Matmul(1, 500, 500, true, false);
    109 BM_Matmul(1, 500, 500, false, true);
    110 BM_Matmul(1, 500, 500, true, true);
    111 BM_Matmul(1, 2000, 2000, false, false);
    112 BM_Matmul(1, 2000, 2000, true, false);
    113 BM_Matmul(1, 2000, 2000, false, true);
    114 BM_Matmul(1, 2000, 2000, true, true);
    115 
    116 // Test some rank-one products.
    117 BM_Matmul(50, 1, 50, false, false);
    118 BM_Matmul(50, 1, 50, true, false);
    119 BM_Matmul(50, 1, 50, false, true);
    120 BM_Matmul(50, 1, 50, true, true);
    121 BM_Matmul(500, 1, 500, false, false);
    122 BM_Matmul(500, 1, 500, true, false);
    123 BM_Matmul(500, 1, 500, false, true);
    124 BM_Matmul(500, 1, 500, true, true);
    125 BM_Matmul(2000, 1, 2000, false, false);
    126 BM_Matmul(2000, 1, 2000, true, false);
    127 BM_Matmul(2000, 1, 2000, false, true);
    128 BM_Matmul(2000, 1, 2000, true, true);
    129 
    130 }  // end namespace tensorflow
    131