/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"

namespace tensorflow {
namespace grappler {

namespace {
// Wrangles the minimum number of proto fields to set up a matrix.
void DescribeMatrix(int rows, int columns, OpInfo* op_features) {
  auto input = op_features->add_inputs();
  auto shape = input->mutable_shape();
  auto shape_rows = shape->add_dim();
  shape_rows->set_size(rows);
  auto shape_columns = shape->add_dim();
  shape_columns->set_size(columns);
  input->set_dtype(DT_FLOAT);
}

void SetCpuDevice(OpInfo* op_features) {
  auto device = op_features->mutable_device();
  device->set_type("CPU");
  device->set_num_cores(10);
  device->set_bandwidth(10000000);  // 10000000 KB/s = 10 GB/s
  device->set_frequency(1000);      // 1000 MHz = 1 GHz
}
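
// Note on the device above: 10 cores at 1 GHz and 10 GB/s of bandwidth amount
// to roughly 1e10 ops/s of compute and 1e10 bytes/s of memory throughput. The
// expected Costs::Duration values in the tests below are consistent with
// those rates, assuming 4-byte floats and durations in nanoseconds (a rough
// cross-check of the expectations, not a statement of the estimator's exact
// formula).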

// Returns an OpContext describing a MatMul with the minimum set of fields
// set up.
OpContext DescribeMatMul(int m, int n, int l, int k) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("MatMul");

  DescribeMatrix(m, l, &op_context.op_info);
  DescribeMatrix(k, n, &op_context.op_info);
  return op_context;
}

// Returns an OpContext describing a MatMul with unknown input shapes.
OpContext DescribeMatMulUnknownShape() {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("MatMul");

  auto input = op_context.op_info.add_inputs();
  auto shape = input->mutable_shape();
  shape->set_unknown_rank(true);

  input = op_context.op_info.add_inputs();
  shape = input->mutable_shape();
  shape->set_unknown_rank(true);

  return op_context;
}

// Wrangles the minimum number of proto fields to set up an input of
// arbitrary rank and type.
void DescribeArbitraryRankInput(const std::vector<int>& dims, DataType dtype,
                                OpInfo* op_features) {
  auto input = op_features->add_inputs();
  input->set_dtype(dtype);
  auto shape = input->mutable_shape();
  for (auto d : dims) {
    shape->add_dim()->set_size(d);
  }
}

// Returns an OpContext describing a BatchMatMul.
OpContext DescribeBatchMatMul(const std::vector<int>& dims_a,
                              const std::vector<int>& dims_b) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("BatchMatMul");

  DescribeArbitraryRankInput(dims_a, DT_FLOAT, &op_context.op_info);
  DescribeArbitraryRankInput(dims_b, DT_FLOAT, &op_context.op_info);
  return op_context;
}

// Wrangles the minimum number of proto fields to set up a 4D Tensor for cost
// estimation purposes.
void DescribeTensor4D(int dim0, int dim1, int dim2, int dim3,
                      OpInfo* op_features) {
  auto input = op_features->add_inputs();
  auto shape = input->mutable_shape();
  shape->add_dim()->set_size(dim0);
  shape->add_dim()->set_size(dim1);
  shape->add_dim()->set_size(dim2);
  shape->add_dim()->set_size(dim3);
  input->set_dtype(DT_FLOAT);
}

// Returns an OpContext describing a Conv2D with the minimum set of fields
// set up.
OpContext DescribeConvolution(int batch, int ix, int iy, int iz1, int iz2,
                              int kx, int ky, int oz) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("Conv2D");

  DescribeTensor4D(batch, ix, iy, iz1, &op_context.op_info);
  DescribeTensor4D(kx, ky, iz2, oz, &op_context.op_info);
  return op_context;
}

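// Returns an OpContext for an op named |op| with two 4D float inputs of
// shape {size1, 1, 1, 1} and {2 * size1, size2, 1, 1}, and a single 4D float
// output of shape {2 * size1, size2, 1, 1}.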
OpContext DescribeOp(const string& op, int size1, int size2) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op(op);

  DescribeTensor4D(size1, 1, 1, 1, &op_context.op_info);
  DescribeTensor4D(2 * size1, size2, 1, 1, &op_context.op_info);

  auto output = op_context.op_info.add_outputs();
  auto shape = output->mutable_shape();
  shape->add_dim()->set_size(2 * size1);
  shape->add_dim()->set_size(size2);
  shape->add_dim()->set_size(1);
  shape->add_dim()->set_size(1);
  output->set_dtype(DT_FLOAT);

  SetCpuDevice(&op_context.op_info);
  return op_context;
}
}  // namespace

class OpLevelCostEstimatorTest : public ::testing::Test {
 protected:
  Costs PredictCosts(const OpContext& op_context) const {
    return estimator_.PredictCosts(op_context);
  }

  int64 CountMatMulOperations(const OpInfo& op_features,
                              bool* found_unknown_shapes) const {
    return estimator_.CountMatMulOperations(op_features, found_unknown_shapes);
  }

  int64 CountBatchMatMulOperations(const OpInfo& op_features,
                                   bool* found_unknown_shapes) const {
    return estimator_.CountBatchMatMulOperations(op_features,
                                                 found_unknown_shapes);
  }

  void SetComputeMemoryOverlap(bool value) {
    estimator_.compute_memory_overlap_ = value;
  }

  OpLevelCostEstimator estimator_;
};

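// "Dummy" is not an op the estimator recognizes, so only memory time is
// charged: the two inputs and one output built by DescribeOp total 5000
// floats (20 KB), which at 10 GB/s is roughly 2000 ns, matching the expected
// durations below (a back-of-the-envelope check under the device assumptions
// above, not the estimator's exact formula).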
TEST_F(OpLevelCostEstimatorTest, DummyExecutionTime) {
  auto cost = PredictCosts(DescribeOp("Dummy", 1000, 1));
  EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
  EXPECT_EQ(Costs::Duration(0), cost.compute_time);
  EXPECT_EQ(Costs::Duration(2000), cost.execution_time);
  EXPECT_TRUE(cost.inaccurate);
}

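// When compute_memory_overlap_ is set, the estimator is expected to report
// execution time as the larger of compute time and memory time rather than
// their sum; this test exercises that path with the same "Dummy" op as above.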
TEST_F(OpLevelCostEstimatorTest, ExecutionTimeSumOrMax) {
  SetComputeMemoryOverlap(true);
  auto cost = PredictCosts(DescribeOp("Dummy", 1000, 1));
  EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
  EXPECT_EQ(Costs::Duration(0), cost.compute_time);
  EXPECT_EQ(Costs::Duration(2000), cost.execution_time);  // max(2000, 200)
  EXPECT_TRUE(cost.inaccurate);
  SetComputeMemoryOverlap(false);  // Set it back to default.
}

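// For a known element-wise op such as "Mul", compute time is charged as well:
// roughly 2000 multiplications at ~1e10 ops/s (10 cores * 1 GHz) is about
// 200 ns, on top of the 2000 ns of memory time, giving the 2200 ns total
// below (again a rough consistency check, not the exact cost model).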
TEST_F(OpLevelCostEstimatorTest, MulExecutionTime) {
  auto cost = PredictCosts(DescribeOp("Mul", 1000, 1));
  EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
  EXPECT_EQ(Costs::Duration(200), cost.compute_time);
  EXPECT_EQ(Costs::Duration(2200), cost.execution_time);
  EXPECT_FALSE(cost.inaccurate);
}

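// With size2 = 2 the second input and the output grow to 4000 floats each, so
// the broadcasted Mul moves 9000 floats (36 KB, ~3600 ns) and performs ~4000
// multiplications (~400 ns) under the same back-of-the-envelope assumptions.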
TEST_F(OpLevelCostEstimatorTest, MulBroadcastExecutionTime) {
  auto cost = PredictCosts(DescribeOp("Mul", 1000, 2));
  EXPECT_EQ(Costs::Duration(3600), cost.memory_time);
  EXPECT_EQ(Costs::Duration(400), cost.compute_time);
  EXPECT_EQ(Costs::Duration(4000), cost.execution_time);
  EXPECT_FALSE(cost.inaccurate);
}

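// Mod is modeled as a more expensive element-wise op than Mul: the 1600 ns of
// compute time expected here, versus 200 ns for Mul on the same shapes,
// implies a per-element cost roughly 8x that of Mul.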
TEST_F(OpLevelCostEstimatorTest, ModExecutionTime) {
  auto cost = PredictCosts(DescribeOp("Mod", 1000, 1));
  EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
  EXPECT_EQ(Costs::Duration(1600), cost.compute_time);
  EXPECT_EQ(Costs::Duration(3600), cost.execution_time);
  EXPECT_FALSE(cost.inaccurate);
}

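// Unknown ranks and negative (i.e. unknown) dimensions should mark the
// prediction as inaccurate, while fully specified shapes should not.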
TEST_F(OpLevelCostEstimatorTest, UnknownOrPartialShape) {
  EXPECT_FALSE(PredictCosts(DescribeMatMul(2, 4, 7, 7)).inaccurate);
  EXPECT_TRUE(PredictCosts(DescribeMatMul(-1, 4, 7, 7)).inaccurate);
  EXPECT_TRUE(PredictCosts(DescribeMatMul(2, 4, -1, 7)).inaccurate);

  EXPECT_FALSE(PredictCosts(DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256))
                   .inaccurate);
  EXPECT_TRUE(PredictCosts(DescribeConvolution(16, -1, 19, 48, 48, 5, 5, 256))
                  .inaccurate);
}

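// A BatchMatMul over fully specified shapes should cost the same as the
// equivalent MatMul multiplied by the batch size; missing or partially known
// shapes should instead flag the estimate as inaccurate.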
TEST_F(OpLevelCostEstimatorTest, BatchMatMul) {
  EXPECT_TRUE(PredictCosts(DescribeBatchMatMul({}, {})).inaccurate);
  EXPECT_TRUE(PredictCosts(DescribeBatchMatMul({2, 4}, {})).inaccurate);
  EXPECT_FALSE(PredictCosts(DescribeBatchMatMul({2, 4}, {4, 2})).inaccurate);
  EXPECT_FALSE(
      PredictCosts(DescribeBatchMatMul({1, 2, 4}, {1, 4, 2})).inaccurate);
  EXPECT_FALSE(
      PredictCosts(DescribeBatchMatMul({2, 4}, {1, 3, 4, 2})).inaccurate);
  bool matmul_inaccurate = false;
  bool batch_matmul_inaccurate = false;
  EXPECT_EQ(
      CountMatMulOperations(DescribeMatMul(2, 2, 4, 4).op_info,
                            &matmul_inaccurate),
      CountBatchMatMulOperations(DescribeBatchMatMul({2, 4}, {4, 2}).op_info,
                                 &batch_matmul_inaccurate));
  EXPECT_EQ(matmul_inaccurate, batch_matmul_inaccurate);
  EXPECT_EQ(10 * CountMatMulOperations(DescribeMatMul(2, 2, 4, 4).op_info,
                                       &matmul_inaccurate),
            CountBatchMatMulOperations(
                DescribeBatchMatMul({10, 2, 4}, {-1, 10, 4, 2}).op_info,
                &batch_matmul_inaccurate));
  EXPECT_NE(matmul_inaccurate, batch_matmul_inaccurate);
  EXPECT_EQ(20 * CountMatMulOperations(DescribeMatMul(2, 2, 4, 4).op_info,
                                       &matmul_inaccurate),
            CountBatchMatMulOperations(
                DescribeBatchMatMul({2, 10, 2, 4}, {-1, 10, 4, 2}).op_info,
                &batch_matmul_inaccurate));
  EXPECT_NE(matmul_inaccurate, batch_matmul_inaccurate);
}

}  // end namespace grappler
}  // end namespace tensorflow