1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <functional> 17 #include <memory> 18 #include <vector> 19 20 #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" 21 #include "tensorflow/core/graph/node_builder.h" 22 #include "tensorflow/core/kernels/ops_testutil.h" 23 #include "tensorflow/core/platform/test_benchmark.h" 24 25 namespace tensorflow { 26 27 static Graph* PTruncatedNormal(int num_batches, int samples_per_batch) { 28 Graph* g = new Graph(OpRegistry::Global()); 29 Tensor shape_t(DT_INT32, TensorShape({2})); 30 shape_t.flat<int32>().setValues({num_batches, samples_per_batch}); 31 32 // Use mean 0 and stdev 1 33 Tensor means_t(DT_FLOAT, TensorShape({num_batches})); 34 means_t.flat<float>().setConstant(0.0); 35 Tensor stdevs_t(DT_FLOAT, TensorShape({num_batches})); 36 stdevs_t.flat<float>().setConstant(1.0); 37 38 Tensor minvals_t(DT_FLOAT, TensorShape({num_batches})); 39 minvals_t.flat<float>().setRandom(); 40 Tensor maxvals_t(DT_FLOAT, TensorShape({num_batches})); 41 maxvals_t.flat<float>().setConstant(5.0); 42 43 Node* ret; 44 TF_CHECK_OK( 45 NodeBuilder(g->NewName("truncatednormal"), "ParameterizedTruncatedNormal") 46 .Input(test::graph::Constant(g, shape_t)) 47 .Input(test::graph::Constant(g, means_t)) 48 .Input(test::graph::Constant(g, stdevs_t)) 49 .Input(test::graph::Constant(g, minvals_t)) 50 .Input(test::graph::Constant(g, maxvals_t)) 51 .Attr("dtype", DT_FLOAT) 52 .Finalize(g, &ret)); 53 return g; 54 } 55 56 static Graph* PTruncatedNormal2SD(int num_batches, int samples_per_batch) { 57 Graph* g = new Graph(OpRegistry::Global()); 58 Tensor shape_t(DT_INT32, TensorShape({2})); 59 shape_t.flat<int32>().setValues({num_batches, samples_per_batch}); 60 61 Tensor means_t(DT_FLOAT, TensorShape({num_batches})); 62 means_t.flat<float>().setConstant(0.0); 63 Tensor stdevs_t(DT_FLOAT, TensorShape({num_batches})); 64 stdevs_t.flat<float>().setConstant(1.0); 65 Tensor minvals_t(DT_FLOAT, TensorShape({num_batches})); 66 minvals_t.flat<float>().setConstant(-2.0); 67 Tensor maxvals_t(DT_FLOAT, TensorShape({num_batches})); 68 maxvals_t.flat<float>().setConstant(2.0); 69 70 Node* ret; 71 TF_CHECK_OK( 72 NodeBuilder(g->NewName("truncatednormal"), "ParameterizedTruncatedNormal") 73 .Input(test::graph::Constant(g, shape_t)) 74 .Input(test::graph::Constant(g, means_t)) 75 .Input(test::graph::Constant(g, stdevs_t)) 76 .Input(test::graph::Constant(g, minvals_t)) 77 .Input(test::graph::Constant(g, maxvals_t)) 78 .Attr("dtype", DT_FLOAT) 79 .Finalize(g, &ret)); 80 return g; 81 } 82 83 static Graph* PTruncatedNormalOneTail(int num_batches, int samples_per_batch) { 84 Graph* g = new Graph(OpRegistry::Global()); 85 Tensor shape_t(DT_INT32, TensorShape({2})); 86 shape_t.flat<int32>().setValues({num_batches, samples_per_batch}); 87 88 Tensor means_t(DT_FLOAT, TensorShape({num_batches})); 89 means_t.flat<float>().setConstant(0.0); 90 Tensor stdevs_t(DT_FLOAT, TensorShape({num_batches})); 91 stdevs_t.flat<float>().setConstant(1.0); 92 Tensor minvals_t(DT_FLOAT, TensorShape({num_batches})); 93 minvals_t.flat<float>().setConstant(2.0); 94 Tensor maxvals_t(DT_FLOAT, TensorShape({num_batches})); 95 maxvals_t.flat<float>().setConstant(std::numeric_limits<float>::infinity()); 96 97 Node* ret; 98 TF_CHECK_OK( 99 NodeBuilder(g->NewName("truncatednormal"), "ParameterizedTruncatedNormal") 100 .Input(test::graph::Constant(g, shape_t)) 101 .Input(test::graph::Constant(g, means_t)) 102 .Input(test::graph::Constant(g, stdevs_t)) 103 .Input(test::graph::Constant(g, minvals_t)) 104 .Input(test::graph::Constant(g, maxvals_t)) 105 .Attr("dtype", DT_FLOAT) 106 .Finalize(g, &ret)); 107 return g; 108 } 109 110 #define BM_PTruncatedNormalDev(DEVICE, B, S) \ 111 static void BM_PTruncatedNormal_##DEVICE##_##B##_##S(int iters) { \ 112 test::Benchmark(#DEVICE, PTruncatedNormal(B, S)).Run(iters); \ 113 testing::ItemsProcessed(static_cast<int64>(B) * S * iters); \ 114 } \ 115 BENCHMARK(BM_PTruncatedNormal_##DEVICE##_##B##_##S); 116 117 #define BM_PTruncatedNormalDev_2SD(DEVICE, B, S) \ 118 static void BM_PTruncatedNormal_2SD_##DEVICE##_##B##_##S(int iters) { \ 119 test::Benchmark(#DEVICE, PTruncatedNormal2SD(B, S)).Run(iters); \ 120 testing::ItemsProcessed(static_cast<int64>(B) * S * iters); \ 121 } \ 122 BENCHMARK(BM_PTruncatedNormal_2SD_##DEVICE##_##B##_##S); 123 124 #define BM_PTruncatedNormalDev_OneTail(DEVICE, B, S) \ 125 static void BM_PTruncatedNormal_OneTail_##DEVICE##_##B##_##S(int iters) { \ 126 test::Benchmark(#DEVICE, PTruncatedNormalOneTail(B, S)).Run(iters); \ 127 testing::ItemsProcessed(static_cast<int64>(B) * S * iters); \ 128 } \ 129 BENCHMARK(BM_PTruncatedNormal_OneTail_##DEVICE##_##B##_##S); 130 131 BM_PTruncatedNormalDev(cpu, 1000, 1000); 132 BM_PTruncatedNormalDev_2SD(cpu, 10000, 100); 133 BM_PTruncatedNormalDev_OneTail(cpu, 10000, 100); 134 BM_PTruncatedNormalDev(gpu, 1000, 1000); 135 BM_PTruncatedNormalDev_2SD(gpu, 10000, 100); 136 BM_PTruncatedNormalDev_OneTail(gpu, 10000, 100); 137 138 } // namespace tensorflow 139