/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"

namespace tensorflow {
namespace grappler {

namespace {
// Wrangles the minimum number of proto fields to set up a matrix.
void DescribeMatrix(int rows, int columns, OpInfo* op_features) {
  auto input = op_features->add_inputs();
  auto shape = input->mutable_shape();
  auto shape_rows = shape->add_dim();
  shape_rows->set_size(rows);
  auto shape_columns = shape->add_dim();
  shape_columns->set_size(columns);
  input->set_dtype(DT_FLOAT);
}

// Describes a CPU device with 10 cores, 10 GB/s of memory bandwidth, and a
// 1 GHz clock.
void SetCpuDevice(OpInfo* op_features) {
  auto device = op_features->mutable_device();
  device->set_type("CPU");
  device->set_num_cores(10);
  device->set_bandwidth(10000000);  // 10000000 KB/s = 10 GB/s
  device->set_frequency(1000);      // 1000 MHz = 1 GHz
}

// Returns an OpContext for MatMul with the minimum set of fields set up.
OpContext DescribeMatMul(int m, int n, int l, int k) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("MatMul");

  DescribeMatrix(m, l, &op_context.op_info);
  DescribeMatrix(k, n, &op_context.op_info);
  return op_context;
}

// Returns an OpContext for MatMul with unknown input shapes.
OpContext DescribeMatMulUnknownShape() {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("MatMul");

  auto input = op_context.op_info.add_inputs();
  auto shape = input->mutable_shape();
  shape->set_unknown_rank(true);

  input = op_context.op_info.add_inputs();
  shape = input->mutable_shape();
  shape->set_unknown_rank(true);

  return op_context;
}

// Wrangles the minimum number of proto fields to set up an input of
// arbitrary rank and type.
void DescribeArbitraryRankInput(const std::vector<int>& dims, DataType dtype,
                                OpInfo* op_features) {
  auto input = op_features->add_inputs();
  input->set_dtype(dtype);
  auto shape = input->mutable_shape();
  for (auto d : dims) {
    shape->add_dim()->set_size(d);
  }
}

// Returns an OpContext for a BatchMatMul with the given input shapes.
OpContext DescribeBatchMatMul(const std::vector<int>& dims_a,
                              const std::vector<int>& dims_b) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("BatchMatMul");

  DescribeArbitraryRankInput(dims_a, DT_FLOAT, &op_context.op_info);
  DescribeArbitraryRankInput(dims_b, DT_FLOAT, &op_context.op_info);
  return op_context;
}
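
// For orientation, the helpers above compose into OpInfo protos like the
// following. For example, DescribeMatMul(2, 4, 7, 7) builds roughly this
// text proto (a sketch for readers of these tests; exact field ordering may
// differ):
//
//   op: "MatMul"
//   device { type: "CPU" num_cores: 10 bandwidth: 10000000 frequency: 1000 }
//   inputs { dtype: DT_FLOAT shape { dim { size: 2 } dim { size: 7 } } }
//   inputs { dtype: DT_FLOAT shape { dim { size: 7 } dim { size: 4 } } }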

// Wrangles the minimum number of proto fields to set up a 4D Tensor for cost
// estimation purposes.
void DescribeTensor4D(int dim0, int dim1, int dim2, int dim3,
                      OpInfo* op_features) {
  auto input = op_features->add_inputs();
  auto shape = input->mutable_shape();
  shape->add_dim()->set_size(dim0);
  shape->add_dim()->set_size(dim1);
  shape->add_dim()->set_size(dim2);
  shape->add_dim()->set_size(dim3);
  input->set_dtype(DT_FLOAT);
}

// Returns an OpContext for Conv2D with the minimum set of fields set up.
OpContext DescribeConvolution(int batch, int ix, int iy, int iz1, int iz2,
                              int kx, int ky, int oz) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("Conv2D");

  DescribeTensor4D(batch, ix, iy, iz1, &op_context.op_info);
  DescribeTensor4D(kx, ky, iz2, oz, &op_context.op_info);
  return op_context;
}

// Returns an OpContext for an elementwise op with a size1 x 1 x 1 x 1 input,
// a (2 * size1) x size2 x 1 x 1 input, and a matching output.
OpContext DescribeOp(const string& op, int size1, int size2) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op(op);

  DescribeTensor4D(size1, 1, 1, 1, &op_context.op_info);
  DescribeTensor4D(2 * size1, size2, 1, 1, &op_context.op_info);

  auto output = op_context.op_info.add_outputs();
  auto shape = output->mutable_shape();
  shape->add_dim()->set_size(2 * size1);
  shape->add_dim()->set_size(size2);
  shape->add_dim()->set_size(1);
  shape->add_dim()->set_size(1);
  output->set_dtype(DT_FLOAT);

  SetCpuDevice(&op_context.op_info);
  return op_context;
}
}  // namespace

class OpLevelCostEstimatorTest : public ::testing::Test {
 protected:
  Costs PredictCosts(const OpContext& op_context) const {
    return estimator_.PredictCosts(op_context);
  }

  int64 CountMatMulOperations(const OpInfo& op_features,
                              bool* found_unknown_shapes) const {
    return estimator_.CountMatMulOperations(op_features, found_unknown_shapes);
  }

  int64 CountBatchMatMulOperations(const OpInfo& op_features,
                                   bool* found_unknown_shapes) const {
    return estimator_.CountBatchMatMulOperations(op_features,
                                                 found_unknown_shapes);
  }

  void SetComputeMemoryOverlap(bool value) {
    estimator_.compute_memory_overlap_ = value;
  }

  OpLevelCostEstimator estimator_;
};

TEST_F(OpLevelCostEstimatorTest, DummyExecutionTime) {
  auto cost = PredictCosts(DescribeOp("Dummy", 1000, 1));
  EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
  EXPECT_EQ(Costs::Duration(0), cost.compute_time);
  EXPECT_EQ(Costs::Duration(2000), cost.execution_time);
  EXPECT_TRUE(cost.inaccurate);
}

TEST_F(OpLevelCostEstimatorTest, ExecutionTimeSumOrMax) {
  SetComputeMemoryOverlap(true);
  auto cost = PredictCosts(DescribeOp("Dummy", 1000, 1));
  EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
  EXPECT_EQ(Costs::Duration(0), cost.compute_time);
  EXPECT_EQ(Costs::Duration(2000), cost.execution_time);  // max(2000, 200)
  EXPECT_TRUE(cost.inaccurate);
  SetComputeMemoryOverlap(false);  // Set it back to default.
}
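
// A back-of-the-envelope check of the numbers in the elementwise tests above
// and below (a sketch of the intended arithmetic, not a statement about the
// estimator's internals): DescribeOp(op, 1000, 1) describes one 1000-element
// and one 2000-element DT_FLOAT input plus a 2000-element output, i.e.
// 5000 * 4 = 20000 bytes. Moving 20000 bytes at the 10 GB/s device bandwidth
// takes roughly 2000 ns, matching the expected memory_time. For a known
// elementwise op such as Mul, the ~2000 output elements at 10 cores * 1 GHz
// give roughly 200 ns of compute_time, and with no compute/memory overlap the
// expected execution_time is their 2200 ns sum.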

TEST_F(OpLevelCostEstimatorTest, MulExecutionTime) {
  auto cost = PredictCosts(DescribeOp("Mul", 1000, 1));
  EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
  EXPECT_EQ(Costs::Duration(200), cost.compute_time);
  EXPECT_EQ(Costs::Duration(2200), cost.execution_time);
  EXPECT_FALSE(cost.inaccurate);
}

TEST_F(OpLevelCostEstimatorTest, MulBroadcastExecutionTime) {
  auto cost = PredictCosts(DescribeOp("Mul", 1000, 2));
  EXPECT_EQ(Costs::Duration(3600), cost.memory_time);
  EXPECT_EQ(Costs::Duration(400), cost.compute_time);
  EXPECT_EQ(Costs::Duration(4000), cost.execution_time);
  EXPECT_FALSE(cost.inaccurate);
}

TEST_F(OpLevelCostEstimatorTest, ModExecutionTime) {
  auto cost = PredictCosts(DescribeOp("Mod", 1000, 1));
  EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
  EXPECT_EQ(Costs::Duration(1600), cost.compute_time);
  EXPECT_EQ(Costs::Duration(3600), cost.execution_time);
  EXPECT_FALSE(cost.inaccurate);
}

TEST_F(OpLevelCostEstimatorTest, UnknownOrPartialShape) {
  EXPECT_FALSE(PredictCosts(DescribeMatMul(2, 4, 7, 7)).inaccurate);
  EXPECT_TRUE(PredictCosts(DescribeMatMul(-1, 4, 7, 7)).inaccurate);
  EXPECT_TRUE(PredictCosts(DescribeMatMul(2, 4, -1, 7)).inaccurate);

  EXPECT_FALSE(PredictCosts(DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256))
                   .inaccurate);
  EXPECT_TRUE(PredictCosts(DescribeConvolution(16, -1, 19, 48, 48, 5, 5, 256))
                  .inaccurate);
}

// CountBatchMatMulOperations is expected to count one MatMul worth of
// operations per batch element; an unknown (-1) batch dimension still yields
// a count but flags the estimate as inaccurate.
TEST_F(OpLevelCostEstimatorTest, BatchMatMul) {
  EXPECT_TRUE(PredictCosts(DescribeBatchMatMul({}, {})).inaccurate);
  EXPECT_TRUE(PredictCosts(DescribeBatchMatMul({2, 4}, {})).inaccurate);
  EXPECT_FALSE(PredictCosts(DescribeBatchMatMul({2, 4}, {4, 2})).inaccurate);
  EXPECT_FALSE(
      PredictCosts(DescribeBatchMatMul({1, 2, 4}, {1, 4, 2})).inaccurate);
  EXPECT_FALSE(
      PredictCosts(DescribeBatchMatMul({2, 4}, {1, 3, 4, 2})).inaccurate);
  bool matmul_inaccurate = false;
  bool batch_matmul_inaccurate = false;
  EXPECT_EQ(
      CountMatMulOperations(DescribeMatMul(2, 2, 4, 4).op_info,
                            &matmul_inaccurate),
      CountBatchMatMulOperations(DescribeBatchMatMul({2, 4}, {4, 2}).op_info,
                                 &batch_matmul_inaccurate));
  EXPECT_EQ(matmul_inaccurate, batch_matmul_inaccurate);
  EXPECT_EQ(10 * CountMatMulOperations(DescribeMatMul(2, 2, 4, 4).op_info,
                                       &matmul_inaccurate),
            CountBatchMatMulOperations(
                DescribeBatchMatMul({10, 2, 4}, {-1, 10, 4, 2}).op_info,
                &batch_matmul_inaccurate));
  EXPECT_NE(matmul_inaccurate, batch_matmul_inaccurate);
  EXPECT_EQ(20 * CountMatMulOperations(DescribeMatMul(2, 2, 4, 4).op_info,
                                       &matmul_inaccurate),
            CountBatchMatMulOperations(
                DescribeBatchMatMul({2, 10, 2, 4}, {-1, 10, 4, 2}).op_info,
                &batch_matmul_inaccurate));
  EXPECT_NE(matmul_inaccurate, batch_matmul_inaccurate);
}

}  // end namespace grappler
}  // end namespace tensorflow