/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/clusters/utils.h"

namespace tensorflow {
namespace grappler {

constexpr int kOpsPerMac = 2;
constexpr char kConst[] = "Const";
constexpr char kConv2d[] = "Conv2D";
constexpr char kConv2dBackpropFilter[] = "Conv2DBackpropFilter";
constexpr char kConv2dBackpropInput[] = "Conv2DBackpropInput";
constexpr char kMatMul[] = "MatMul";
constexpr char kSparseMatMul[] = "SparseMatMul";
constexpr char kPlaceholder[] = "Placeholder";
constexpr char kIdentity[] = "Identity";
constexpr char kRefIdentity[] = "RefIdentity";
constexpr char kNoOp[] = "NoOp";
constexpr char kReshape[] = "Reshape";
constexpr char kRecv[] = "_Recv";
constexpr char kSend[] = "_Send";
constexpr char kBatchMatMul[] = "BatchMatMul";
constexpr char kVariable[] = "Variable";
constexpr char kVariableV2[] = "VariableV2";
constexpr char kRank[] = "Rank";
constexpr char kShape[] = "Shape";
constexpr char kSize[] = "Size";
constexpr char kStopGradient[] = "StopGradient";
constexpr char kPreventGradient[] = "PreventGradient";

static const Costs::Duration kMinComputeTime(1);

namespace {

string GetDataFormat(const OpInfo& op_features) {
  string data_format = "NHWC";  // Default format.
  if (op_features.attr().find("data_format") != op_features.attr().end()) {
    data_format = op_features.attr().at("data_format").s();
  }
  return data_format;
}

Padding GetPadding(const OpInfo& op_features) {
  if (op_features.attr().find("padding") != op_features.attr().end() &&
      op_features.attr().at("padding").s() == "VALID") {
    return Padding::VALID;
  }
  return Padding::SAME;  // Default padding.
}

std::vector<int64> GetStrides(const OpInfo& op_features) {
  if (op_features.attr().find("strides") != op_features.attr().end()) {
    const auto strides = op_features.attr().at("strides").list().i();
    return {strides[0], strides[1], strides[2], strides[3]};
  }
  return {1, 1, 1, 1};
}

int64 GetOutputSize(const int64 input, const int64 filter, const int64 stride,
                    const Padding& padding) {
  // Logic for calculating output shape is from GetWindowedOutputSizeVerbose()
  // function in third_party/tensorflow/core/framework/common_shape_fns.cc.
  if (padding == Padding::VALID) {
    return (input - filter + stride) / stride;
  } else {  // SAME.
    return (input + stride - 1) / stride;
  }
}
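
// Illustrative worked example (not part of the original code path): for
// input = 10, filter = 3, stride = 2 the formulas above give
//   VALID: (10 - 3 + 2) / 2 = 4 output positions,
//   SAME:  (10 + 2 - 1) / 2 = 5 output positions,
// i.e. floor((input - filter) / stride) + 1 and ceil(input / stride).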

// Return a minimum shape if the shape is unknown. If known, return the original
// shape.
TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape,
                                      int rank, bool* found_unknown_shapes) {
  auto shape = original_shape;
  if (shape.unknown_rank() || shape.dim_size() < rank) {
    *found_unknown_shapes = true;
    TensorShapeProto::Dim dim;
    VLOG(2) << "Use minimum shape because the rank is unknown or too small.";
    // The size of each dimension is at least 1, if unknown.
    dim.set_size(1);
    for (int i = 0; i < rank; i++) {
      *shape.add_dim() = dim;
    }
  } else {
    for (int i = 0; i < shape.dim_size(); i++) {
      if (shape.dim(i).size() < 0) {
        *found_unknown_shapes = true;
        VLOG(2) << "Use minimum dim size 1 because the shape is unknown.";
        // The size of each dimension is at least 1, if unknown.
        shape.mutable_dim(i)->set_size(1);
      }
    }
  }
  return shape;
}
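
// Illustrative example: with rank = 4, an unknown-rank shape becomes
// [1, 1, 1, 1], while a rank-4 shape with an unknown batch dimension such as
// [-1, 224, 224, 3] becomes [1, 224, 224, 3]. In both cases
// *found_unknown_shapes is set so callers can mark their estimate inaccurate.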

// Return the output element count of a binary element-wise op considering
// broadcasting.
int64 CwiseOutputElementCount(const TensorShapeProto& input_shape_1,
                              const TensorShapeProto& input_shape_2) {
  bool found_unknown_shapes = false;
  int rank = std::max(1, input_shape_1.dim_size());
  TensorShapeProto output_shape =
      MaybeGetMinimumShape(input_shape_1, rank, &found_unknown_shapes);

  if (input_shape_1.dim_size() == input_shape_2.dim_size()) {
    auto shape_1 =
        MaybeGetMinimumShape(input_shape_1, rank, &found_unknown_shapes);
    auto shape_2 =
        MaybeGetMinimumShape(input_shape_2, rank, &found_unknown_shapes);
    if (shape_1.dim_size() == shape_2.dim_size()) {
      for (int i = 0; i < shape_1.dim_size(); i++) {
        output_shape.mutable_dim(i)->set_size(
            std::max(shape_1.dim(i).size(), shape_2.dim(i).size()));
      }
    }
  }

  int64 count = 1;
  for (int i = 0; i < output_shape.dim_size(); i++) {
    count *= output_shape.dim(i).size();
  }
  return count;
}
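
// Illustrative example: for equal-rank inputs [4, 1, 8] and [4, 5, 1], the
// broadcast output shape is [4, 5, 8], i.e. 160 elements. Inputs of unequal
// rank are not broadcast here; the element count of the first input is used.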

}  // namespace

OpLevelCostEstimator::OpLevelCostEstimator() {
  // Syntactic sugar to build and return a lambda that takes an OpContext and
  // returns a cost.
  typedef Costs (OpLevelCostEstimator::*CostImpl)(const OpContext& op_context)
      const;
  auto wrap = [this](CostImpl impl) -> std::function<Costs(const OpContext&)> {
    return [this, impl](const OpContext& op_context) {
      return (this->*impl)(op_context);
    };
  };

  device_cost_impl_ = {
      {kConv2d, wrap(&OpLevelCostEstimator::PredictConv2D)},
      {kConv2dBackpropFilter,
       wrap(&OpLevelCostEstimator::PredictConv2DBackpropFilter)},
      {kConv2dBackpropInput,
       wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput)},
      {kMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
      {kSparseMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
      {kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)},

      {kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)},

      {kPlaceholder, wrap(&OpLevelCostEstimator::PredictIdentity)},
      {kIdentity, wrap(&OpLevelCostEstimator::PredictIdentity)},
      {kRefIdentity, wrap(&OpLevelCostEstimator::PredictIdentity)},
      {kStopGradient, wrap(&OpLevelCostEstimator::PredictIdentity)},
      {kPreventGradient, wrap(&OpLevelCostEstimator::PredictIdentity)},
      {kReshape, wrap(&OpLevelCostEstimator::PredictIdentity)},
      {kRecv, wrap(&OpLevelCostEstimator::PredictIdentity)},
      {kSend, wrap(&OpLevelCostEstimator::PredictIdentity)},

      {kConst, wrap(&OpLevelCostEstimator::PredictVariable)},
      {kVariable, wrap(&OpLevelCostEstimator::PredictVariable)},
      {kVariableV2, wrap(&OpLevelCostEstimator::PredictVariable)},

      {kRank, wrap(&OpLevelCostEstimator::PredictMetadata)},
      {kShape, wrap(&OpLevelCostEstimator::PredictMetadata)},
      {kSize, wrap(&OpLevelCostEstimator::PredictMetadata)}};

  elementwise_ops_ = {
      // Unary ops alphabetically sorted
      {"Acos", Eigen::internal::functor_traits<
                   Eigen::internal::scalar_acos_op<float>>::Cost},
      {"Asin", Eigen::internal::functor_traits<
                   Eigen::internal::scalar_asin_op<float>>::Cost},
      {"Atan", Eigen::internal::functor_traits<
                   Eigen::internal::scalar_atan_op<float>>::Cost},
      {"Atan2", Eigen::internal::functor_traits<
                    Eigen::internal::scalar_quotient_op<float>>::Cost +
                    Eigen::internal::functor_traits<
                        Eigen::internal::scalar_atan_op<float>>::Cost},
      {"Ceil", Eigen::internal::functor_traits<
                   Eigen::internal::scalar_ceil_op<float>>::Cost},
      {"Cos", Eigen::internal::functor_traits<
                  Eigen::internal::scalar_cos_op<float>>::Cost},
      {"Erf", 1},
      {"Erfc", 1},
      {"Exp", Eigen::internal::functor_traits<
                  Eigen::internal::scalar_exp_op<float>>::Cost},
      {"Expm1", Eigen::internal::functor_traits<
                    Eigen::internal::scalar_expm1_op<float>>::Cost},
      {"Floor", Eigen::internal::functor_traits<
                    Eigen::internal::scalar_floor_op<float>>::Cost},
      {"Inv", Eigen::internal::functor_traits<
                  Eigen::internal::scalar_inverse_op<float>>::Cost},
      {"InvGrad", 1},
      {"Lgamma", 1},
      {"Log", Eigen::internal::functor_traits<
                  Eigen::internal::scalar_log_op<float>>::Cost},
      {"Log1p", Eigen::internal::functor_traits<
                    Eigen::internal::scalar_log1p_op<float>>::Cost},
      {"Neg", Eigen::internal::functor_traits<
                  Eigen::internal::scalar_opposite_op<float>>::Cost},
      {"Reciprocal", Eigen::internal::functor_traits<
                         Eigen::internal::scalar_inverse_op<float>>::Cost},
      {"Rint", 1},
      {"Round", Eigen::internal::functor_traits<
                    Eigen::internal::scalar_round_op<float>>::Cost},
      {"Rsqrt", Eigen::internal::functor_traits<
                    Eigen::internal::scalar_rsqrt_op<float>>::Cost},
      {"Sqrt", Eigen::internal::functor_traits<
                   Eigen::internal::scalar_sqrt_op<float>>::Cost},
      {"Square", Eigen::internal::functor_traits<
                     Eigen::internal::scalar_square_op<float>>::Cost},
      {"Tanh", Eigen::internal::functor_traits<
                   Eigen::internal::scalar_tanh_op<float>>::Cost},
      {"Relu", Eigen::internal::functor_traits<
                   Eigen::internal::scalar_max_op<float>>::Cost},
      {"Sigmoid", Eigen::internal::functor_traits<
                      Eigen::internal::scalar_sigmoid_op<float>>::Cost},
      {"Sign", Eigen::internal::functor_traits<
                   Eigen::internal::scalar_sign_op<float>>::Cost},
      {"Sin", Eigen::internal::functor_traits<
                  Eigen::internal::scalar_sin_op<float>>::Cost},
      {"Tan", Eigen::internal::functor_traits<
                  Eigen::internal::scalar_tan_op<float>>::Cost},
      // Binary ops alphabetically sorted
      {"Add", Eigen::internal::functor_traits<
                  Eigen::internal::scalar_sum_op<float>>::Cost},
      {"ApproximateEqual", 1},
      {"Div", Eigen::internal::functor_traits<
                  Eigen::internal::scalar_quotient_op<float>>::Cost},
      {"Equal", 1},
      {"FloorDiv", Eigen::internal::functor_traits<
                       Eigen::internal::scalar_quotient_op<float>>::Cost},
      {"FloorMod", Eigen::internal::functor_traits<
                       Eigen::internal::scalar_mod_op<float>>::Cost},
      {"Greater", 1},
      {"GreaterEqual", 1},
      {"Less", 1},
      {"LessEqual", 1},
      {"LogicalAnd", Eigen::internal::functor_traits<
                         Eigen::internal::scalar_boolean_and_op>::Cost},
      {"LogicalNot", 1},
      {"LogicalOr", Eigen::internal::functor_traits<
                        Eigen::internal::scalar_boolean_or_op>::Cost},
      {"Maximum", Eigen::internal::functor_traits<
                      Eigen::internal::scalar_max_op<float>>::Cost},
      {"Minimum", Eigen::internal::functor_traits<
                      Eigen::internal::scalar_min_op<float>>::Cost},
      {"Mod", Eigen::internal::functor_traits<
                  Eigen::internal::scalar_mod_op<float>>::Cost},
      {"Mul", Eigen::internal::functor_traits<
                  Eigen::internal::scalar_product_op<float>>::Cost},
      {"NotEqual", 1},
      {"QuantizedAdd", Eigen::internal::functor_traits<
                           Eigen::internal::scalar_sum_op<float>>::Cost},
      {"QuantizedMul", Eigen::internal::functor_traits<
                           Eigen::internal::scalar_product_op<float>>::Cost},
      {"RealDiv", Eigen::internal::functor_traits<
                      Eigen::internal::scalar_quotient_op<float>>::Cost},
      {"SquaredDifference", 1},
      {"Sub", Eigen::internal::functor_traits<
                  Eigen::internal::scalar_difference_op<float>>::Cost},
      {"TruncateDiv", Eigen::internal::functor_traits<
                          Eigen::internal::scalar_quotient_op<float>>::Cost},
      {"TruncateMod", Eigen::internal::functor_traits<
                          Eigen::internal::scalar_mod_op<float>>::Cost}};

  // By default, use sum of memory_time and compute_time for execution_time.
  compute_memory_overlap_ = false;
}

Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const {
  const auto& op_features = op_context.op_info;
  auto it = device_cost_impl_.find(op_features.op());
  if (it == device_cost_impl_.end()) {
    if (elementwise_ops_.find(op_features.op()) != elementwise_ops_.end()) {
      return PredictCwiseOp(op_context);
    }

    VLOG(1) << "Missing accurate estimator for op: " << op_features.op();

    return PredictCostOfAnUnknownOp(op_context);
  }

  std::function<Costs(const OpContext&)> estimator = it->second;
  Costs costs = estimator(op_context);
  VLOG(1) << "Operation " << op_features.op() << " takes "
          << costs.execution_time.count() << " ns.";
  return costs;
}

OpLevelCostEstimator::DeviceInfo OpLevelCostEstimator::GetDeviceInfo(
    const DeviceProperties& device) const {
  double gflops = -1;
  double gb_per_sec = -1;

  if (device.type() == "CPU") {
    // TODO: Check whether vector instructions are available, and refine the
    // performance prediction accordingly.
    // Frequencies are stored in MHz in the DeviceProperties.
    gflops = device.num_cores() * device.frequency() * 1e-3;
    if (gb_per_sec < 0) {
      if (device.bandwidth() > 0) {
        gb_per_sec = device.bandwidth() / 1e6;
      } else {
        gb_per_sec = 32;
      }
    }
  } else if (device.type() == "GPU") {
    const string architecture = device.environment().at("architecture");
    int cores_per_multiprocessor;
    if (architecture < "3") {
      // Fermi
      cores_per_multiprocessor = 32;
    } else if (architecture < "4") {
      // Kepler
      cores_per_multiprocessor = 192;
    } else if (architecture < "6") {
      // Maxwell
      cores_per_multiprocessor = 128;
    } else {
      // Pascal (compute capability version 6) and Volta (compute capability
      // version 7)
      cores_per_multiprocessor = 64;
    }
    gflops = device.num_cores() * device.frequency() * 1e-3 *
             cores_per_multiprocessor * kOpsPerMac;
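    // Illustrative example (hypothetical device): 56 multiprocessors at
    // 1480 MHz on a Pascal-class GPU give
    //   56 * 1480 * 1e-3 * 64 * 2 = 10608.64 gflops, i.e. ~10.6 TFLOPS.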
    if (device.bandwidth() > 0) {
      gb_per_sec = device.bandwidth() / 1e6;
    } else {
      gb_per_sec = 100;
    }
  }
  VLOG(1) << "Device: " << device.type() << " gflops: " << gflops
          << " gb_per_sec: " << gb_per_sec;

  DCHECK_LT(0, gflops) << device.DebugString();
  DCHECK_LT(0, gb_per_sec) << device.DebugString();

  return {gflops, gb_per_sec};
}

Costs OpLevelCostEstimator::PredictCwiseOp(const OpContext& op_context) const {
  const auto& op_features = op_context.op_info;
  bool found_unknown_shapes = false;
  // For unary or binary element-wise operations, the op count equals the
  // element count of the inputs. We use the count of the largest input here
  // to be more robust in case the shape is unknown or only partially known
  // for the other input(s).
  int64 op_count =
      CalculateLargestInputCount(op_features, &found_unknown_shapes);
  // If the output shape is available, also consider the element count
  // calculated from it.
  if (op_features.outputs_size() > 0) {
    op_count =
        std::max(op_count, CalculateTensorElementCount(op_features.outputs(0),
                                                       &found_unknown_shapes));
  }
  // For binary ops, calculate the output shape possibly resulting from
  // broadcasting.
  if (op_features.inputs_size() >= 2) {
    op_count = std::max(op_count,
                        CwiseOutputElementCount(op_features.inputs(0).shape(),
                                                op_features.inputs(1).shape()));
  }

  int op_cost = 1;
  bool is_known_elementwise_op = false;
  auto it = elementwise_ops_.find(op_features.op());
  if (it != elementwise_ops_.end()) {
    op_cost = it->second;
    is_known_elementwise_op = true;
  } else {
    LOG(WARNING) << "Not a cwise op: " << op_features.op();
  }

  Costs costs = PredictOpCountBasedCost(op_count * op_cost, op_features);
  if (found_unknown_shapes || !is_known_elementwise_op) {
    costs.inaccurate = true;
  }
  return costs;
}

Costs OpLevelCostEstimator::PredictCostOfAnUnknownOp(
    const OpContext& op_context) const {
  // Don't assume the operation is element-wise; return a cost based on
  // input/output size, and flag the estimate as inaccurate.
  auto costs = PredictOpCountBasedCost(0, op_context.op_info);
  costs.inaccurate = true;
  return costs;
}

Costs OpLevelCostEstimator::PredictOpCountBasedCost(
    double operations, const OpInfo& op_features) const {
  DeviceInfo device_perf = GetDeviceInfo(op_features.device());
  if (device_perf.gigaops <= 0 || device_perf.gb_per_sec <= 0) {
    VLOG(1) << "BAD DEVICE. Op:" << op_features.op()
            << " device type:" << op_features.device().type()
            << " device model:" << op_features.device().model();
  }

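  // Unit bookkeeping: gigaops is in units of 1e9 ops/sec, i.e. ops per
  // nanosecond, so operations / gigaops yields nanoseconds directly. The same
  // holds for bytes / gb_per_sec below, since 1 GB/sec is 1 byte/ns.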
  Costs::NanoSeconds compute_cost(std::ceil(operations / device_perf.gigaops));
  VLOG(1) << "Op:" << op_features.op() << " GOps:" << operations / 1e9
          << " Execution Time (ns):" << compute_cost.count();

  bool found_unknown_shapes = false;
  const double total_input_size =
      CalculateInputSize(op_features, &found_unknown_shapes);
  const double total_output_size =
      CalculateOutputSize(op_features, &found_unknown_shapes);
  const double total_io_size = total_input_size + total_output_size;

  Costs::NanoSeconds memory_cost(
      std::ceil(total_io_size / device_perf.gb_per_sec));
  VLOG(1) << "Op:" << op_features.op() << " Size (KB):" << (total_io_size) / 1e3
          << " Memory Time (ns):" << memory_cost.count();

  Costs costs;
  costs.compute_time = compute_cost;
  costs.memory_time = memory_cost;
  if (compute_memory_overlap_) {
    costs.execution_time = std::max(compute_cost, memory_cost);
  } else {
    costs.execution_time = compute_cost + memory_cost;
  }
  costs.inaccurate = found_unknown_shapes;
  costs.max_memory = total_output_size;
  return costs;
}

int64 OpLevelCostEstimator::CountConv2DOperations(
    const OpInfo& op_features, bool* found_unknown_shapes) const {
  return CountConv2DOperations(op_features, nullptr, found_unknown_shapes);
}

// Helper to translate the positional arguments into named fields.
OpLevelCostEstimator::ConvolutionDimensions
OpLevelCostEstimator::ConvolutionDimensionsFromInputs(
    const TensorShapeProto& original_image_shape,
    const TensorShapeProto& original_filter_shape, const OpInfo& op_features,
    bool* found_unknown_shapes) {
  VLOG(2) << "op features: " << op_features.DebugString();
  VLOG(2) << "Original image shape: " << original_image_shape.DebugString();
  VLOG(2) << "Original filter shape: " << original_filter_shape.DebugString();
  auto image_shape =
      MaybeGetMinimumShape(original_image_shape, 4, found_unknown_shapes);
  auto filter_shape =
      MaybeGetMinimumShape(original_filter_shape, 4, found_unknown_shapes);
  VLOG(2) << "Image shape: " << image_shape.DebugString();
  VLOG(2) << "Filter shape: " << filter_shape.DebugString();

  int x_index, y_index, channel_index;
  const string& data_format = GetDataFormat(op_features);
  if (data_format == "NCHW") {
    x_index = 2;
    y_index = 3;
    channel_index = 1;
  } else {
    x_index = 1;
    y_index = 2;
    channel_index = 3;
  }
  int64 batch = image_shape.dim(0).size();
  int64 ix = image_shape.dim(x_index).size();
  int64 iy = image_shape.dim(y_index).size();
  int64 iz = image_shape.dim(channel_index).size();
  int64 kx = filter_shape.dim(0).size();
  int64 ky = filter_shape.dim(1).size();
  std::vector<int64> strides = GetStrides(op_features);
  const auto padding = GetPadding(op_features);
  int64 sx = strides[x_index];
  int64 sy = strides[y_index];
  int64 ox = GetOutputSize(ix, kx, sx, padding);
  int64 oy = GetOutputSize(iy, ky, sy, padding);
  int64 oz = filter_shape.dim(3).size();
  // Only check equality when both sizes are known (in other words, when
  // neither is set to a minimum dimension size of 1).
  if (iz != 1 && filter_shape.dim(2).size() != 1) {
    CHECK_EQ(iz, filter_shape.dim(2).size());
  } else {
    iz = std::max<int64>(iz, filter_shape.dim(2).size());
  }
  OpLevelCostEstimator::ConvolutionDimensions conv_dims = {
      batch, ix, iy, iz, kx, ky, oz, ox, oy, sx, sy, padding};

  VLOG(1) << "Batch Size:" << batch;
  VLOG(1) << "Image Dims:" << ix << "," << iy;
  VLOG(1) << "Input Features:" << iz;
  VLOG(1) << "Kernel Dims:" << kx << "," << ky;
  VLOG(1) << "Output Features:" << oz;
  VLOG(1) << "Output Dims:" << ox << "," << oy;
  VLOG(1) << "Strides:" << sx << "," << sy;
  VLOG(1) << "Padding:" << (padding == Padding::VALID ? "VALID" : "SAME");
  return conv_dims;
}

int64 OpLevelCostEstimator::CountConv2DOperations(
    const OpInfo& op_features, ConvolutionDimensions* conv_info,
    bool* found_unknown_shapes) const {
  if (op_features.op() != kConv2d) {
    LOG(ERROR) << "Invalid Operation";
    return 0;
  }
  ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
      op_features.inputs(0).shape(), op_features.inputs(1).shape(), op_features,
      found_unknown_shapes);

  int64 ops = conv_dims.batch;
  ops *= conv_dims.ox * conv_dims.oy;
  ops *= conv_dims.kx * conv_dims.ky;
  ops *= conv_dims.iz * conv_dims.oz;
  ops *= kOpsPerMac;
  VLOG(1) << "Operations for Conv2D " << ops;

  if (conv_info != nullptr) {
    *conv_info = conv_dims;
  }
  return ops;
}
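
// Illustrative example: one 224x224x3 image convolved with a 7x7x3x64 filter
// at stride 2 with SAME padding has ox = oy = (224 + 2 - 1) / 2 = 112, so
//   ops = 1 * (112 * 112) * (7 * 7) * (3 * 64) * kOpsPerMac ~= 2.36e8.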

int64 OpLevelCostEstimator::CountMatMulOperations(
    const OpInfo& op_features, bool* found_unknown_shapes) const {
  return CountMatMulOperations(op_features, nullptr, found_unknown_shapes);
}

// TODO(nishantpatil): Create separate estimator for Sparse Matmul
int64 OpLevelCostEstimator::CountMatMulOperations(
    const OpInfo& op_features, MatMulDimensions* mat_mul,
    bool* found_unknown_shapes) const {
  double ops = 0;

  if (op_features.inputs_size() < 2) {
    LOG(ERROR) << "Need 2 inputs but got " << op_features.inputs_size();
    *found_unknown_shapes = true;
    return 0;
  }

  auto& a_matrix = op_features.inputs(0);
  auto& b_matrix = op_features.inputs(1);

  bool transpose_a = false;
  bool transpose_b = false;

  double m_dim = 0, n_dim = 0, k_dim = 0, k_dim_b = 0;

  for (const auto& item : op_features.attr()) {
    VLOG(1) << "Key:" << item.first
            << " Value:" << SummarizeAttrValue(item.second);
    if (item.first == "transpose_a" && item.second.b() == true)
      transpose_a = true;
    if (item.first == "transpose_b" && item.second.b() == true)
      transpose_b = true;
  }
  VLOG(1) << "transpose_a:" << transpose_a;
  VLOG(1) << "transpose_b:" << transpose_b;
  auto a_matrix_shape =
      MaybeGetMinimumShape(a_matrix.shape(), 2, found_unknown_shapes);
  auto b_matrix_shape =
      MaybeGetMinimumShape(b_matrix.shape(), 2, found_unknown_shapes);
  if (transpose_a) {
    m_dim = a_matrix_shape.dim(1).size();
    k_dim = a_matrix_shape.dim(0).size();
  } else {
    m_dim = a_matrix_shape.dim(0).size();
    k_dim = a_matrix_shape.dim(1).size();
  }
  if (transpose_b) {
    k_dim_b = b_matrix_shape.dim(1).size();
    n_dim = b_matrix_shape.dim(0).size();
  } else {
    k_dim_b = b_matrix_shape.dim(0).size();
    n_dim = b_matrix_shape.dim(1).size();
  }

  VLOG(1) << "M, N, K: " << m_dim << "," << n_dim << "," << k_dim;
  // Only check equality when both sizes are known (in other words, when
  // neither is set to a minimum dimension size of 1).
  if (k_dim_b != 1 && k_dim != 1 && k_dim_b != k_dim) {
    LOG(ERROR) << "Incompatible Matrix dimensions";
    return ops;
  } else {
    // One of k_dim and k_dim_b might be 1 (minimum dimension size).
    k_dim = std::max(k_dim, k_dim_b);
  }

  ops = m_dim * n_dim * k_dim * 2;
  VLOG(1) << "Operations for Matmul: " << ops;

  if (mat_mul != nullptr) {
    mat_mul->m = m_dim;
    mat_mul->n = n_dim;
    mat_mul->k = k_dim;
  }
  return ops;
}
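
// Illustrative example: a 128x256 matrix times a 256x64 matrix (m = 128,
// k = 256, n = 64) counts 2 * 128 * 256 * 64 = 4,194,304 ops, one multiply
// and one add per MAC.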

int64 OpLevelCostEstimator::CountBatchMatMulOperations(
    const OpInfo& op_features, bool* found_unknown_shapes) const {
  if (op_features.op() != kBatchMatMul) {
    LOG(ERROR) << "Invalid Operation: " << op_features.op();
    *found_unknown_shapes = true;
    return 0;
  }
  if (op_features.inputs_size() != 2) {
    LOG(ERROR) << "Expected 2 inputs but got " << op_features.inputs_size();
    *found_unknown_shapes = true;
    return 0;
  }

  double ops = 0;
  const auto& a_input = op_features.inputs(0);
  const auto& b_input = op_features.inputs(1);

  // BatchMatMul requires inputs of at least matrix shape (rank 2).
  // The two most minor dimensions of each input are matrices that
  // need to be multiplied together. The other dimensions determine
  // the number of such MatMuls.  For example, if the BatchMatMul has
  // inputs of shape:
  //   a_input_shape = [2, 3, 4, 5]
  //   b_input_shape = [2, 3, 5, 6]
  // then there are 2*3 = 6 MatMuls of dimensions m = 4, k = 5, n = 6
  // in this BatchMatMul.
  const int matrix_rank = 2;

  bool a_input_shape_unknown = false;
  bool b_input_shape_unknown = false;

  TensorShapeProto a_input_shape = MaybeGetMinimumShape(
      a_input.shape(), std::max(matrix_rank, a_input.shape().dim_size()),
      &a_input_shape_unknown);
  TensorShapeProto b_input_shape = MaybeGetMinimumShape(
      b_input.shape(), std::max(matrix_rank, b_input.shape().dim_size()),
      &b_input_shape_unknown);

  *found_unknown_shapes = a_input_shape_unknown || b_input_shape_unknown ||
                          (a_input.shape().dim_size() < matrix_rank) ||
                          (b_input.shape().dim_size() < matrix_rank);

  // Compute the number of matmuls as the max indicated at each dimension
  // by either input. Note that the shapes do not have to have
  // the same rank due to incompleteness.
  TensorShapeProto* bigger_rank_shape = &a_input_shape;
  TensorShapeProto* smaller_rank_shape = &b_input_shape;
  if (b_input_shape.dim_size() > a_input_shape.dim_size()) {
    bigger_rank_shape = &b_input_shape;
    smaller_rank_shape = &a_input_shape;
  }
  int num_matmuls = 1;
  for (int b_i = 0,
           s_i = smaller_rank_shape->dim_size() - bigger_rank_shape->dim_size();
       b_i < bigger_rank_shape->dim_size() - matrix_rank; ++b_i, ++s_i) {
    int b_dim = bigger_rank_shape->dim(b_i).size();
    int s_dim = 1;
    if (s_i >= 0) {
      s_dim = smaller_rank_shape->dim(s_i).size();
    }
    num_matmuls *= std::max(b_dim, s_dim);
  }

  // Build the MatMul. Note that values are ignored here since we are just
  // counting ops (i.e., only shapes matter).
  OpInfo matmul_op_features;
  matmul_op_features.set_op("MatMul");

  AttrValue transpose_a;
  transpose_a.set_b(false);
  if (op_features.attr().find("adj_x") != op_features.attr().end()) {
    transpose_a.set_b(op_features.attr().at("adj_x").b());
  }
  (*matmul_op_features.mutable_attr())["transpose_a"] = transpose_a;

  AttrValue transpose_b;
  transpose_b.set_b(false);
  if (op_features.attr().find("adj_y") != op_features.attr().end()) {
    transpose_b.set_b(op_features.attr().at("adj_y").b());
  }
  (*matmul_op_features.mutable_attr())["transpose_b"] = transpose_b;

  OpInfo::TensorProperties* a_matrix = matmul_op_features.add_inputs();
  a_matrix->set_dtype(a_input.dtype());
  TensorShapeProto* a_matrix_shape = a_matrix->mutable_shape();
  for (int i = std::max(0, a_input_shape.dim_size() - matrix_rank);
       i < a_input_shape.dim_size(); ++i) {
    *(a_matrix_shape->add_dim()) = a_input_shape.dim(i);
  }

  OpInfo::TensorProperties* b_matrix = matmul_op_features.add_inputs();
  b_matrix->set_dtype(b_input.dtype());
  TensorShapeProto* b_matrix_shape = b_matrix->mutable_shape();
  for (int i = std::max(0, b_input_shape.dim_size() - matrix_rank);
       i < b_input_shape.dim_size(); ++i) {
    *(b_matrix_shape->add_dim()) = b_input_shape.dim(i);
  }

  for (int i = 0; i < num_matmuls; ++i) {
    bool matmul_unknown_shapes = false;
    ops += CountMatMulOperations(matmul_op_features, &matmul_unknown_shapes);
    *found_unknown_shapes |= matmul_unknown_shapes;
  }
  return ops;
}

// TODO(cliffy): Dedup this method and CountConv2DBackpropFilterOperations.
int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations(
    const OpInfo& op_features, ConvolutionDimensions* returned_conv_dims,
    bool* found_unknown_shapes) const {
  int64 ops = 0;

  DCHECK_EQ(kConv2dBackpropInput, op_features.op());

  if (op_features.inputs_size() < 2) {
    *found_unknown_shapes = true;
    return ops;
  }

  TensorShapeProto input_shape;
  if (op_features.inputs(0).has_value()) {
    const TensorProto& value = op_features.inputs(0).value();
    if (value.int64_val_size() > 0) {
      for (int i = 0; i < value.int64_val_size(); ++i) {
        input_shape.add_dim()->set_size(value.int64_val(i));
      }
    } else {
      for (int i = 0; i < value.int_val_size(); ++i) {
        input_shape.add_dim()->set_size(value.int_val(i));
      }
    }
  } else if (op_features.outputs_size() == 1) {
    input_shape = op_features.outputs(0).shape();
  } else {
    // Set the minimum input size that's feasible.
    for (int i = 0; i < 4; ++i) {
      input_shape.add_dim()->set_size(1);
    }
    *found_unknown_shapes = true;
  }

  ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
      input_shape, op_features.inputs(1).shape(), op_features,
      found_unknown_shapes);

  ops = conv_dims.batch;
  ops *= conv_dims.ox * conv_dims.oy;
  ops *= conv_dims.kx * conv_dims.ky;
  ops *= conv_dims.iz * conv_dims.oz;
  ops *= kOpsPerMac;

  VLOG(1) << "Operations for Conv2DBackpropInput " << ops;

  if (returned_conv_dims != nullptr) {
    *returned_conv_dims = conv_dims;
  }
  return ops;
}

int64 OpLevelCostEstimator::CountConv2DBackpropFilterOperations(
    const OpInfo& op_features, ConvolutionDimensions* returned_conv_dims,
    bool* found_unknown_shapes) const {
  int64 ops = 0;
  DCHECK_EQ(kConv2dBackpropFilter, op_features.op());

  TensorShapeProto filter_shape;
  if (op_features.inputs_size() >= 2 && op_features.inputs(1).has_value()) {
    const TensorProto& value = op_features.inputs(1).value();
    if (value.int64_val_size() > 0) {
      for (int i = 0; i < value.int64_val_size(); ++i) {
        filter_shape.add_dim()->set_size(value.int64_val(i));
      }
    } else {
      for (int i = 0; i < value.int_val_size(); ++i) {
        filter_shape.add_dim()->set_size(value.int_val(i));
      }
    }
  } else if (op_features.outputs_size() == 1) {
    filter_shape = op_features.outputs(0).shape();
  } else {
    // Set the minimum filter size that's feasible.
    for (int i = 0; i < 4; ++i) {
      filter_shape.add_dim()->set_size(1);
    }
    *found_unknown_shapes = true;
  }

  if (op_features.inputs_size() < 1) {
    *found_unknown_shapes = true;
    return ops;
  }
  ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
      op_features.inputs(0).shape(), filter_shape, op_features,
      found_unknown_shapes);

  ops = conv_dims.batch;
  ops *= conv_dims.ox * conv_dims.oy;
  ops *= conv_dims.kx * conv_dims.ky;
  ops *= conv_dims.iz * conv_dims.oz;
  ops *= kOpsPerMac;

  VLOG(1) << "Operations for Conv2DBackpropFilter " << ops;

  if (returned_conv_dims != nullptr) {
    *returned_conv_dims = conv_dims;
  }
  return ops;
}

int64 OpLevelCostEstimator::CalculateTensorElementCount(
    const OpInfo::TensorProperties& tensor, bool* found_unknown_shapes) const {
  VLOG(2) << "   with " << tensor.dtype() << " tensor of shape "
          << tensor.shape().DebugString();
  int64 tensor_size = 1;
  int num_dims = std::max(1, tensor.shape().dim_size());
  auto tensor_shape =
      MaybeGetMinimumShape(tensor.shape(), num_dims, found_unknown_shapes);
  for (const auto& dim : tensor_shape.dim()) {
    tensor_size *= dim.size();
  }
  return tensor_size;
}

int64 OpLevelCostEstimator::CalculateTensorSize(
    const OpInfo::TensorProperties& tensor, bool* found_unknown_shapes) const {
  return CalculateTensorElementCount(tensor, found_unknown_shapes) *
         DataTypeSize(BaseType(tensor.dtype()));
}
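
// Illustrative example: a DT_FLOAT tensor of shape [32, 128] is counted as
// 32 * 128 * 4 = 16384 bytes, since DataTypeSize(DT_FLOAT) is 4.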

int64 OpLevelCostEstimator::CalculateInputSize(
    const OpInfo& op_features, bool* found_unknown_shapes) const {
  int64 total_input_size = 0;
  for (auto& input : op_features.inputs()) {
    int64 input_size = CalculateTensorSize(input, found_unknown_shapes);
    total_input_size += input_size;
    VLOG(1) << "Input Size: " << input_size
            << " Total Input Size:" << total_input_size;
  }
  return total_input_size;
}

int64 OpLevelCostEstimator::CalculateLargestInputCount(
    const OpInfo& op_features, bool* found_unknown_shapes) const {
  int64 largest_input_count = 0;
  for (auto& input : op_features.inputs()) {
    int64 input_count =
        CalculateTensorElementCount(input, found_unknown_shapes);
    if (input_count > largest_input_count) {
      largest_input_count = input_count;
    }
    VLOG(1) << "Input Count: " << input_count
            << " Largest Input Count:" << largest_input_count;
  }
  return largest_input_count;
}

int64 OpLevelCostEstimator::CalculateOutputSize(
    const OpInfo& op_features, bool* found_unknown_shapes) const {
  int64 total_output_size = 0;
  // Sum each output's size in bytes, based on its declared data type.
  for (const auto& output : op_features.outputs()) {
    DataType dt = output.dtype();
    const auto& original_output_shape = output.shape();
    int64 output_size = DataTypeSize(BaseType(dt));
    int num_dims = std::max(1, original_output_shape.dim_size());
    auto output_shape = MaybeGetMinimumShape(original_output_shape, num_dims,
                                             found_unknown_shapes);
    for (const auto& dim : output_shape.dim()) {
      output_size *= dim.size();
    }
    total_output_size += output_size;
    VLOG(1) << "Output Size: " << output_size
            << " Total Output Size:" << total_output_size;
  }
  return total_output_size;
}

Costs OpLevelCostEstimator::PredictConv2D(const OpContext& op_context) const {
  const auto& op_features = op_context.op_info;
  bool found_unknown_shapes = false;
  auto costs = PredictOpCountBasedCost(
      CountConv2DOperations(op_features, &found_unknown_shapes), op_features);
  costs.inaccurate = found_unknown_shapes;
  return costs;
}

Costs OpLevelCostEstimator::PredictConv2DBackpropInput(
    const OpContext& op_context) const {
  const auto& op_features = op_context.op_info;
  bool found_unknown_shapes = false;
  auto costs =
      PredictOpCountBasedCost(CountConv2DBackpropInputOperations(
                                  op_features, nullptr, &found_unknown_shapes),
                              op_features);
  costs.inaccurate = found_unknown_shapes;
  return costs;
}

Costs OpLevelCostEstimator::PredictConv2DBackpropFilter(
    const OpContext& op_context) const {
  const auto& op_features = op_context.op_info;
  bool found_unknown_shapes = false;
  auto costs =
      PredictOpCountBasedCost(CountConv2DBackpropFilterOperations(
                                  op_features, nullptr, &found_unknown_shapes),
                              op_features);
  costs.inaccurate = found_unknown_shapes;
  return costs;
}

Costs OpLevelCostEstimator::PredictMatMul(const OpContext& op_context) const {
  const auto& op_features = op_context.op_info;
  bool found_unknown_shapes = false;
  auto costs = PredictOpCountBasedCost(
      CountMatMulOperations(op_features, &found_unknown_shapes), op_features);
  costs.inaccurate = found_unknown_shapes;
  return costs;
}

Costs OpLevelCostEstimator::PredictNoOp(const OpContext& op_context) const {
  const auto& op_features = op_context.op_info;
  VLOG(1) << "Op:" << op_features.op() << " Execution Time 0 (ns)";
  return Costs::ZeroCosts();
}

Costs OpLevelCostEstimator::PredictIdentity(const OpContext& op_context) const {
  const auto& op_features = op_context.op_info;
  VLOG(1) << "Op:" << op_features.op() << " Execution Time 0 (ns)";
  Costs result = Costs::ZeroCosts();
  result.max_memory = CalculateOutputSize(op_features, &result.inaccurate);
  // Assign the minimum amount of time we can represent to the identity op since
  // it tends to be really cheap.
  result.compute_time = kMinComputeTime;
  result.execution_time = result.compute_time;
  return result;
}

Costs OpLevelCostEstimator::PredictVariable(const OpContext& op_context) const {
  const auto& op_features = op_context.op_info;
  VLOG(1) << "Op:" << op_features.op() << " Execution Time 0 (ns)";
  Costs result = Costs::ZeroCosts();
  result.persistent_memory =
      CalculateOutputSize(op_features, &result.inaccurate);

  result.compute_time = kMinComputeTime;
  result.execution_time = result.compute_time;
  return result;
}

Costs OpLevelCostEstimator::PredictBatchMatMul(
    const OpContext& op_context) const {
  const auto& op_features = op_context.op_info;
  bool found_unknown_shapes = false;
  Costs costs = PredictOpCountBasedCost(
      CountBatchMatMulOperations(op_features, &found_unknown_shapes),
      op_features);
  costs.inaccurate = found_unknown_shapes;
  return costs;
}

Costs OpLevelCostEstimator::PredictMetadata(const OpContext& op_context) const {
  const auto& op_features = op_context.op_info;
  Costs costs = Costs::ZeroCosts();
  costs.max_memory = CalculateOutputSize(op_features, &costs.inaccurate);
  // Metadata operations are so cheap we assume they take the minimum amount of
  // time we can represent (1 ns).
  costs.compute_time = kMinComputeTime;
  costs.execution_time = costs.compute_time;

  return costs;
}

}  // end namespace grappler
}  // end namespace tensorflow