/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <deque>
#include <unordered_set>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
#include "tensorflow/core/grappler/utils/frame.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace grappler {
namespace {

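// Name suffixes for the nodes added by this optimizer. Every generated node
// name ends with "-LayoutOptimizer" (kSuffix), so IsNodeByLayoutOptimizer()
// and the IsNodeType() helpers below can recognize them by name.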
const char kSuffix[] = "LayoutOptimizer";
const char kPermNHWCToNCHW[] = "PermConstNHWCToNCHW";
const char kPermNCHWToNHWC[] = "PermConstNCHWToNHWC";
const char kTransposeNHWCToNCHW[] = "TransposeNHWCToNCHW";
const char kTransposeNCHWToNHWC[] = "TransposeNCHWToNHWC";
const char kDimMapNHWCToNCHW[] = "DimMapNHWCToNCHW";
const char kDimMapNCHWToNHWC[] = "DimMapNCHWToNHWC";
const char kVecPermuteNHWCToNCHW[] = "VecPermuteNHWCToNCHW";
const char kVecPermuteNCHWToNHWC[] = "VecPermuteNCHWToNHWC";
const char kReshapeNHWCToNCHW[] = "ReshapeNHWCToNCHW";
const char kReshapeConst[] = "ReshapeConst";

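// Ops that carry a data_format attribute and whose layout this optimizer can
// switch from NHWC to NCHW directly.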
std::set<string> GetOpsFormatSupported() {
  std::set<string> ops_format_supported = {
      "AvgPool",
      "AvgPoolGrad",
      "Conv2D",
      "Conv2DBackpropFilter",
      "Conv2DBackpropInput",
      "BiasAdd",
      "BiasAddGrad",
      "DepthwiseConv2dNative",
      "DepthwiseConv2dNativeBackpropInput",
      "DepthwiseConv2dNativeBackpropFilter",
      "FusedBatchNorm",
      "FusedBatchNormV2",
      "FusedBatchNormGrad",
      "FusedBatchNormGradV2",
      "FusedConv2DBiasActivation",
      "MaxPool",
      "MaxPoolV2",
      "MaxPoolGrad",
      "MaxPoolGradGrad",
      "MaxPoolGradV2",
      "MaxPoolGradGradV2",
      "SpaceToDepth",
      "DepthToSpace"};
  return ops_format_supported;
}

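// Ops whose semantics do not depend on the layout of their 4-D inputs. They
// are converted only when they appear downstream of a node this pass has
// already converted to NCHW (see AgnosticNodeProcessor::IsNodeAfterNCHWToNHWC).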
std::set<string> GetOpsFormatAgnostic() {
  std::set<string> ops_format_agnostic = {"Abs",
                                          "Add",
                                          "AddN",
                                          "AddV2",
                                          "Acos",
                                          "Acosh",
                                          "All",
                                          "Angle",
                                          "Any",
                                          "ApproximateEqual",
                                          "Asin",
                                          "Asinh",
                                          "Atan",
                                          "Atan2",
                                          "Atanh",
                                          "Betainc",
                                          "Bitcast",
                                          "Cast",
                                          "Ceil",
                                          "CheckNumerics",
                                          "Complex",
                                          "ComplexAbs",
                                          "Concat",
                                          "ConcatV2",
                                          "Conj",
                                          "Cos",
                                          "Cosh",
                                          "Digamma",
                                          "Div",
                                          "Elu",
                                          "EluGrad",
                                          "Enter",
                                          "Equal",
                                          "Erf",
                                          "Erfc",
                                          "Exit",
                                          "Exp",
                                          "Expm1",
                                          "FakeQuantWithMinMaxVars",
                                          "FakeQuantWithMinMaxArgs",
                                          "Fill",
                                          "Floor",
                                          "FloorDiv",
                                          "FloorMod",
                                          "Greater",
                                          "GreaterEqual",
                                          "GuaranteeConst",
                                          "HistogramSummary",
                                          "Identity",
                                          "IdentityN",
                                          "Igamma",
                                          "Igammac",
                                          "Imag",
                                          "Inv",
                                          "InvGrad",
                                          "IsFinite",
                                          "IsInf",
                                          "IsNan",
                                          "Less",
                                          "LessEqual",
                                          "Lgamma",
                                          "Log",
                                          "LogicalAnd",
                                          "LogicalNot",
                                          "LogicalOr",
                                          "Log1p",
                                          "Max",
                                          "Maximum",
                                          "Mean",
                                          "Merge",
                                          "Min",
                                          "Minimum",
                                          "Mod",
                                          "Mul",
                                          "Neg",
                                          "NextIteration",
                                          "NotEqual",
                                          "OnesLike",
                                          "Pad",
                                          "PreventGradient",
                                          "Prod",
                                          "Polygamma",
                                          "QuantizeAndDequantizeV2",
                                          "QuantizeAndDequantizeV3",
                                          "Pow",
                                          "Real",
                                          "RealDiv",
                                          "Reciprocal",
                                          "ReciprocalGrad",
                                          "Relu",
                                          "Relu6",
                                          "Relu6Grad",
                                          "ReluGrad",
                                          "Rint",
                                          "Select",
                                          "Selu",
                                          "SeluGrad",
                                          "Shape",
                                          "ShapeN",
                                          "Sigmoid",
                                          "SigmoidGrad",
                                          "Sign",
                                          "Sin",
                                          "Sinh",
                                          "Slice",
                                          "Snapshot",
                                          "Softplus",
                                          "SoftplusGrad",
                                          "Split",
                                          "SplitV",
                                          "StridedSlice",
                                          "StridedSliceGrad",
                                          "Switch",
                                          "Tile",
                                          "TruncateDiv",
                                          "TruncateMod",
                                          "ReverseV2",
                                          "Round",
                                          "Rsqrt",
                                          "RsqrtGrad",
                                          "Sqrt",
                                          "SqrtGrad",
                                          "Square",
                                          "SquaredDifference",
                                          "Squeeze",
                                          "StopGradient",
                                          "Sub",
                                          "Sum",
                                          "Tan",
                                          "Tanh",
                                          "TanhGrad",
                                          "ZerosLike",
                                          "Zeta"};
  return ops_format_agnostic;
}

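// Helpers for recognizing, by the name suffixes defined above, the nodes this
// optimizer has inserted (permutation constants, transposes, dim-map and
// vec-permute ops), plus small op-classification predicates.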
bool EndWith(const string& str, const string& ending) {
  if (str.size() < ending.size()) return false;
  if (str.substr(str.size() - ending.size(), ending.size()) == ending)
    return true;
  return false;
}

bool IsNodeByLayoutOptimizer(const string& node_name) {
  const string suffix = kSuffix;
  return EndWith(node_name, suffix);
}

bool IsNodeType(const string& node_name, const string& type) {
  const string suffix = strings::StrCat(type, "-", kSuffix);
  return EndWith(node_name, suffix);
}

bool IsTransposeNHWCToNCHW(const string& node_name) {
  return IsNodeType(node_name, kTransposeNHWCToNCHW);
}

bool IsTransposeNCHWToNHWC(const string& node_name) {
  return IsNodeType(node_name, kTransposeNCHWToNHWC);
}

bool IsDimMapNHWCToNCHW(const string& node_name) {
  return IsNodeType(node_name, kDimMapNHWCToNCHW);
}

bool IsDimMapNCHWToNHWC(const string& node_name) {
  return IsNodeType(node_name, kDimMapNCHWToNHWC);
}

bool IsVecPermuteNHWCToNCHW(const string& node_name) {
  return IsNodeType(node_name, kVecPermuteNHWCToNCHW);
}

bool IsVecPermuteNCHWToNHWC(const string& node_name) {
  return IsNodeType(node_name, kVecPermuteNCHWToNHWC);
}

bool IsConcat(const NodeDef& node) {
  const auto op = node.op();
  return op == "Concat" || op == "ConcatV2";
}

bool IsConcatV1(const NodeDef& node) {
  const auto op = node.op();
  return op == "Concat";
}

bool IsMaxPoolV2(const NodeDef& node) {
  const auto& op = node.op();
  return op == "MaxPoolV2";
}

bool IsMaxPoolGradV1(const NodeDef& node) {
  const auto& op = node.op();
  return op == "MaxPoolGrad";
}

bool IsMaxPoolGradV2(const NodeDef& node) {
  const auto& op = node.op();
  return op == "MaxPoolGradV2";
}

bool IsMaxPoolGradGradV1(const NodeDef& node) {
  const auto& op = node.op();
  return op == "MaxPoolGradGrad";
}

bool IsMaxPoolGradGradV2(const NodeDef& node) {
  const auto& op = node.op();
  return op == "MaxPoolGradGradV2";
}

bool IsUnaryGrad(const NodeDef& node) {
  bool is_unary_grad =
      IsEluGrad(node) || IsInvGrad(node) || IsReciprocalGrad(node) ||
      IsRelu6Grad(node) || IsReluGrad(node) || IsRsqrtGrad(node) ||
      IsSeluGrad(node) || IsSigmoidGrad(node) || IsSoftplusGrad(node) ||
      IsSoftsignGrad(node) || IsSqrtGrad(node) || IsTanhGrad(node);
  return is_unary_grad;
}

bool IsComparisonOp(const NodeDef& node) {
  bool is_compare = IsApproximateEqual(node) || IsEqual(node) ||
                    IsGreater(node) || IsGreaterEqual(node) || IsLess(node) ||
                    IsLessEqual(node) || IsNotEqual(node);
  return is_compare;
}

bool IsReduceOp(const NodeDef& node) {
  return IsSum(node) || IsMean(node) || IsProd(node) || IsMax(node) ||
         IsMin(node) || IsAll(node) || IsAny(node);
}

bool IsBinaryOp(const NodeDef& node) {
  bool is_binary =
      IsAdd(node) || IsAtan2(node) || IsComparisonOp(node) || IsComplex(node) ||
      IsDiv(node) || IsFloorDiv(node) || IsIgamma(node) || IsIgammac(node) ||
      IsLogicalAnd(node) || IsLogicalOr(node) || IsMaximum(node) ||
      IsMinimum(node) || IsMod(node) || IsMul(node) || IsPolygamma(node) ||
      IsPow(node) || IsRealDiv(node) || IsSquaredDifference(node) ||
      IsSub(node) || IsTruncateDiv(node) || IsTruncateMod(node) || IsZeta(node);
  return is_binary;
}

std::vector<int> NonControlInputs(const NodeDef& node) {
  std::vector<int> pos;
  for (int i = 0; i < node.input_size(); i++) {
    if (!IsControlInput(node.input(i))) {
      pos.push_back(i);
    }
  }
  return pos;
}

std::vector<int> DataInputPosConcat(const NodeDef& node) {
  int n = node.attr().at("N").i();
  std::vector<int> input_pos;
  int start = (IsConcatV1(node)) ? 1 : 0;
  int end = start + n;
  for (int i = start; i < end; i++) {
    input_pos.push_back(i);
  }
  return input_pos;
}

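// Returns the positions of the inputs that carry layout-sensitive data, as
// opposed to parameters such as an axis, shape, or permutation vector.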
std::vector<int> DataInputPos(const NodeDef& node) {
  if (IsSplit(node) || IsHistogramSummary(node)) {
    return {1};
  }
  if (IsStridedSliceGrad(node)) {
    return {4};
  }
  if (IsBinaryOp(node) || IsUnaryGrad(node)) {
    return {0, 1};
  }
  if (IsBetainc(node) || IsSelect(node)) {
    return {0, 1, 2};
  }
  if (IsShapeN(node) || IsIdentityN(node) || IsAddN(node) || IsMerge(node)) {
    return NonControlInputs(node);
  }
  if (IsConcat(node)) {
    return DataInputPosConcat(node);
  }
  if (node.input_size() > 0 && !IsControlInput(node.input(0))) {
    return {0};
  }
  return {};
}

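// Returns true if output_port of the given node lives in host memory, or if
// no kernel is registered for the node on its device.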
bool IsHostMemory(const NodeDef& node, int output_port) {
  DeviceNameUtils::ParsedName parsed_name;
  if (DeviceNameUtils::ParseFullName(node.device(), &parsed_name)) {
    DeviceType device_type(parsed_name.type);
    Status s = FindKernelDef(device_type, node, nullptr, nullptr);
    if (s.ok()) {
      tensorflow::MemoryTypeVector in_mtypes;
      tensorflow::MemoryTypeVector out_mtypes;
      s = tensorflow::MemoryTypesForNode(OpRegistry::Global(), device_type,
                                         node, &in_mtypes, &out_mtypes);
      if (s.ok()) {
        if (out_mtypes[output_port] == HOST_MEMORY) {
          return true;
        }
      }
    } else {
      return true;
    }
  }
  return false;
}

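// Shared state (graph, node map, inferred shapes, virtual placement) and
// helpers for creating constant nodes and naming generated nodes, used by all
// of the processors below.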
class GraphProcessor {
 public:
  GraphProcessor(const GraphProperties& graph_properties,
                 const VirtualPlacer& virtual_placer,
                 const std::unordered_set<string>& nodes_to_preserve,
                 GraphDef* graph, NodeMap* node_map)
      : graph_properties_(graph_properties),
        virtual_placer_(virtual_placer),
        nodes_to_preserve_(nodes_to_preserve),
        graph_(graph),
        node_map_(node_map) {}

 protected:
  NodeDef* AddNodePermConst(const string& name, const string& device,
                            const std::vector<int>& permutation) {
    NodeDef* node = graph_->add_node();
    node_map_->AddNode(name, node);
    node->set_name(name);
    node->set_op("Const");
    AttrValue attr_data_type;
    attr_data_type.set_type(DT_INT32);
    node->mutable_attr()->insert({"dtype", attr_data_type});
    AttrValue attr_tensor;
    Tensor tensor(DT_INT32, TensorShape({4}));
    for (int i = 0; static_cast<size_t>(i) < permutation.size(); i++) {
      tensor.flat<int>()(i) = permutation[i];
    }
    tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
    node->mutable_attr()->insert({"value", attr_tensor});
    string device_name;
    if (device.empty()) {
      device_name = virtual_placer_.get_canonical_device_name(*node);
    } else {
      device_name = device;
    }
    node->set_device(device_name);
    return node;
  }

  NodeDef* AddNodeConstScalar(const string& name, const string& device,
                              DataType dtype, int value) {
    NodeDef* node = graph_->add_node();
    node_map_->AddNode(name, node);
    node->set_name(name);
    node->set_op("Const");
    AttrValue attr_data_type;
    attr_data_type.set_type(dtype);
    node->mutable_attr()->insert({"dtype", attr_data_type});
    AttrValue attr_tensor;
    Tensor tensor(dtype, TensorShape({}));
    tensor.scalar<int>()() = value;
    tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
    node->mutable_attr()->insert({"value", attr_tensor});
    string device_name;
    if (device.empty()) {
      device_name = virtual_placer_.get_canonical_device_name(*node);
    } else {
      device_name = device;
    }
    node->set_device(device_name);
    return node;
  }

  string LayoutOptimizerNode(const string& base_name) {
    return strings::StrCat(base_name, "-", kSuffix);
  }

  const GraphProperties& graph_properties_;
  const VirtualPlacer& virtual_placer_;
  const std::unordered_set<string>& nodes_to_preserve_;
  GraphDef* graph_;
  NodeMap* node_map_;
};

struct OptimizeContext {
  OptimizeContext(GraphDef* graph, NodeDef* node, NodeMap* node_map,
                  const GraphProperties& graph_properties,
                  const VirtualPlacer& virtual_placer,
                  const std::unordered_set<string>& nodes_to_preserve,
                  bool is_in_frame)
      : graph(graph),
        node(node),
        node_map(node_map),
        graph_properties(graph_properties),
        virtual_placer(virtual_placer),
        nodes_to_preserve(nodes_to_preserve),
        is_in_frame(is_in_frame) {}
  GraphDef* graph;
  NodeDef* node;
  NodeMap* node_map;
  const GraphProperties& graph_properties;
  const VirtualPlacer& virtual_placer;
  const std::unordered_set<string>& nodes_to_preserve;
  bool is_in_frame;
};

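// Rewrites a single layout-sensitive node: updates its data_format and the
// attributes that depend on it (ksize, strides, dilations, paddings, shapes),
// then wraps its 4-D inputs and outputs in NHWC<->NCHW transposes. Subclasses
// override which inputs and outputs to touch and how parameter inputs are
// permuted.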
class NodeProcessor : public GraphProcessor {
 public:
  explicit NodeProcessor(const OptimizeContext& opt_cxt)
      : GraphProcessor(opt_cxt.graph_properties, opt_cxt.virtual_placer,
                       opt_cxt.nodes_to_preserve, opt_cxt.graph,
                       opt_cxt.node_map),
        node_(opt_cxt.node),
        is_in_frame_(opt_cxt.is_in_frame) {}
  virtual ~NodeProcessor() {}
  virtual Status ConvertNode() {
    if (ShouldProcess()) {
      UpdateAttrDataFormat();
      UpdateAttrKSize();
      UpdateAttrStrides();
      UpdateAttrDilations();
      UpdateAttrExplicitPaddings();
      UpdateAttrShape();
      TF_RETURN_IF_ERROR(AddLayoutTransposeToInputs());
      TF_RETURN_IF_ERROR(AddLayoutTransposeToOutputs());
      TF_RETURN_IF_ERROR(CustomizedProcessing());
    }
    return Status::OK();
  }

 protected:
  bool IsPortDimsN(const NodeDef& node, int port, int n) const {
    if (node.attr().find("_output_shapes") != node.attr().end()) {
      if (node.attr().at("_output_shapes").list().shape_size() > port) {
        auto shape = node.attr().at("_output_shapes").list().shape(port);
        if (shape.unknown_rank()) {
          return false;
        }
        if (shape.dim_size() == n) {
          return true;
        }
      }
    }
    return false;
  }

  bool IsPortZeroDimsN(const NodeDef& node, int n) const {
    return IsPortDimsN(node, 0, n);
  }

  bool IsPortZeroDimsFour(const NodeDef& node) const {
    return NodeProcessor::IsPortZeroDimsN(node, 4) ||
           IsTransposeNCHWToNHWC(node.name());
  }

  bool IsPortDimsFour(const NodeDef& node, int port) const {
    return NodeProcessor::IsPortDimsN(node, port, 4) ||
           IsTransposeNCHWToNHWC(node.name());
  }

  bool IsNHWC() const {
    if (node_->attr().find("data_format") != node_->attr().end()) {
      if (node_->attr().at("data_format").s().compare("NHWC") == 0) {
        return true;
      }
    }
    return false;
  }

  bool HasOutputs() const {
    auto outputs = node_map_->GetOutputs(node_->name());
    return !outputs.empty();
  }

  Status HasAttribute(const NodeDef& node, const string& attr) const {
    if (node.attr().find(attr) == node.attr().end()) {
      return Status(error::INVALID_ARGUMENT,
                    strings::StrCat("Missing attribute ", attr));
    }
    return Status::OK();
  }

  bool MustPreserve() const {
    return nodes_to_preserve_.find(node_->name()) != nodes_to_preserve_.end();
  }

  bool IsOnGPU() const {
    string device_name;
    if (node_->device().empty()) {
      device_name = virtual_placer_.get_canonical_device_name(*node_);
    } else {
      device_name = node_->device();
    }
    string device;
    string not_used;
    if (DeviceNameUtils::SplitDeviceName(device_name, &not_used, &device) &&
        str_util::StrContains(str_util::Lowercase(device),
                              str_util::Lowercase(DEVICE_GPU))) {
      return true;
    }
    return false;
  }

  virtual bool ShouldProcess() const {
    return !MustPreserve() && IsNHWC() && IsPortZeroDimsFour(*node_) &&
           HasOutputs() && IsOnGPU();
  }

  virtual void UpdateAttrShape() {
    if (node_->attr().find("_output_shapes") != node_->attr().end()) {
      for (const auto& pos : GetOutputPos()) {
        auto shape = node_->mutable_attr()
                         ->at("_output_shapes")
                         .mutable_list()
                         ->mutable_shape(pos);
        if (shape->dim_size() == 4) {
          int64 h = shape->dim(1).size();
          int64 w = shape->dim(2).size();
          int64 c = shape->dim(3).size();
          shape->mutable_dim(1)->set_size(c);
          shape->mutable_dim(2)->set_size(h);
          shape->mutable_dim(3)->set_size(w);
        }
      }
    }
  }

  Status UpdateAttrValueOfInput(int input_index, bool permute) {
    auto input_node = node_map_->GetNode(node_->input(input_index));
    // We create a copy of the input node so that the original, which might be
    // used elsewhere, is left unmodified. Note that the copy also keeps the
    // control dependency input when this node is inside a loop, to ensure
    // added_node is in the same frame as node_.
    NodeDef* added_node = graph_->add_node();
    *added_node = *input_node;
    string base_name = strings::StrCat(node_->name(), "-", input_index);
    string node_name = LayoutOptimizerNode(base_name);
    added_node->set_name(node_name);
    *node_->mutable_input(input_index) = node_name;
    node_map_->AddNode(node_name, added_node);
    node_map_->AddOutput(node_name, node_->name());
    return UpdateAttrValue(added_node, permute);
  }

  virtual std::vector<int> GetInputPos() const { return {0}; }

  virtual std::set<int> GetOutputPos() const {
    // For most nodes, no need to process control nodes or nodes that use an
    // output other than the first output: only the first output is of
    // 4D NCHW/NHWC format and thus relevant here.
    std::set<int> output_pos = {0};
    return output_pos;
  }

  virtual Status AddLayoutTransposeToInputs() {
    std::vector<int> input_pos = GetInputPos();
    for (const auto& pos : input_pos) {
      string node_name = LayoutOptimizerNode(
          strings::StrCat(node_->name(), "-", pos, "-", kTransposeNHWCToNCHW));
      DataType dtype =
          graph_properties_.GetInputProperties(node_->name())[pos].dtype();
      auto input_node = node_map_->GetNode(node_->input(pos));
      TF_RETURN_IF_ERROR(HasAttribute(*input_node, "_output_shapes"));
      string const_name = GetOrAddNodePermNHWCToNCHW(pos);
      int output_pos;
      ParseNodeName(node_->input(pos), &output_pos);
      AddNodeTranspose(
          node_name, node_->input(pos), const_name, dtype,
          input_node->attr().at("_output_shapes").list().shape(output_pos),
          true);
      node_map_->UpdateOutput(NodeName(node_->input(pos)), node_->name(),
                              node_name);
      node_map_->AddOutput(node_name, node_->name());
      *node_->mutable_input(pos) = node_name;
    }
    return Status::OK();
  }

  Status AddTransformToOutputs(const string& op) {
    auto outputs = node_map_->GetOutputs(node_->name());
    string const_name = GetOrAddNodePermNCHWToNHWC();
    int output_count = 0;
    for (const auto& output : outputs) {
      int connections = 0;
      int connections_removed = 0;
      for (int i = 0; i < output->input_size(); i++) {
        auto& input = *output->mutable_input(i);
        int input_port;
        string input_name = ParseNodeName(input, &input_port);
        auto output_pos = GetOutputPos();
        if (input_name == node_->name()) {
          connections++;
          if (output_pos.find(input_port) != output_pos.end()) {
            connections_removed++;
            string added_node_base_name =
                strings::StrCat(node_->name(), "-", output_count, "-", i);
            string added_node_name;
            DataType dtype =
                graph_properties_.GetOutputProperties(node_->name())[input_port]
                    .dtype();
            if (op == "Transpose") {
              added_node_name = LayoutOptimizerNode(strings::StrCat(
                  added_node_base_name, "-", kTransposeNCHWToNHWC));
              TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes"));
              AddNodeTranspose(
                  added_node_name, input, const_name, dtype,
                  node_->attr().at("_output_shapes").list().shape(input_port),
                  false);
            } else if (op == "DataFormatVecPermute") {
              added_node_name = LayoutOptimizerNode(strings::StrCat(
                  added_node_base_name, "-", kVecPermuteNCHWToNHWC));
              AddNodeDataFormatOp(added_node_name, input, op, dtype, false);
            } else {
              return errors::InvalidArgument("Unsupported op type: ", op);
            }
            input = added_node_name;
            node_map_->AddOutput(node_->name(), added_node_name);
            node_map_->AddOutput(added_node_name, output->name());
          }
        }
      }
      if (connections == connections_removed) {
        node_map_->RemoveOutput(node_->name(), output->name());
      }
      output_count++;
    }
    return Status::OK();
  }

  virtual Status AddLayoutTransposeToOutputs() {
    return AddTransformToOutputs("Transpose");
  }

  virtual Status CustomizedProcessing() { return Status::OK(); }

  Status UpdateOrTransformParamInput(int param_index, const string& op,
                                     DataType dtype) {
    auto param_node = node_map_->GetNode(node_->input(param_index));
    bool permute = (op == "DataFormatVecPermute") ? true : false;
    if (IsConstant(*param_node)) {
      TF_RETURN_IF_ERROR(UpdateAttrValueOfInput(param_index, permute));
    } else {
      AddDataFormatTranformToParamInput(op, param_index, dtype);
    }
    return Status::OK();
  }

  NodeDef* node_;
  bool is_in_frame_;

 private:
  void UpdateAttrKSize() {
    if (node_->attr().find("ksize") != node_->attr().end()) {
      auto list = node_->mutable_attr()->at("ksize").mutable_list();
      UpdateTuple(list);
    }
  }

  void UpdateAttrStrides() {
    if (node_->attr().find("strides") != node_->attr().end()) {
      auto list = node_->mutable_attr()->at("strides").mutable_list();
      UpdateTuple(list);
    }
  }

  void UpdateAttrDilations() {
    if (node_->attr().find("dilations") != node_->attr().end()) {
      auto list = node_->mutable_attr()->at("dilations").mutable_list();
      UpdateTuple(list);
    }
  }

  void UpdateAttrExplicitPaddings() {
    if (node_->attr().find("explicit_paddings") != node_->attr().end()) {
      auto list = node_->mutable_attr()->at("explicit_paddings").mutable_list();
      int size = list->i_size();
      if (size == 8) {
        int64 height_before = list->i(2);
        int64 height_after = list->i(3);
        int64 width_before = list->i(4);
        int64 width_after = list->i(5);
        list->set_i(2, 0);
        list->set_i(3, 0);
        list->set_i(4, height_before);
        list->set_i(5, height_after);
        list->set_i(6, width_before);
        list->set_i(7, width_after);
      } else if (size != 0) {
        LOG(ERROR) << "Cannot handle explicit_paddings attribute of size "
                   << size;
      }
    }
  }

  void UpdateAttrDataFormat() {
    if (node_->attr().find("data_format") != node_->attr().end()) {
      if (node_->attr().at("data_format").s().compare("NHWC") == 0) {
        string* data_format =
            node_->mutable_attr()->at("data_format").mutable_s();
        *data_format = "NCHW";
      }
    }
  }

  Status UpdateAttrValue(NodeDef* node, bool permute) {
    TF_RETURN_IF_ERROR(HasAttribute(*node, "value"));
    Tensor tensor;
    auto success =
        tensor.FromProto(node->mutable_attr()->at({"value"}).tensor());
    if (!success) {
      LOG(ERROR) << "Failed to parse TensorProto.";
    }

    if (permute) {
      if (tensor.dims() == 1) {
        if (tensor.flat<int>().size() == 4) {
          int c = tensor.flat<int>()(3);
          tensor.flat<int>()(3) = tensor.flat<int>()(2);
          tensor.flat<int>()(2) = tensor.flat<int>()(1);
          tensor.flat<int>()(1) = c;
        } else {
          return Status(error::INVALID_ARGUMENT,
                        strings::StrCat("Unsupported tensor size: ",
                                        tensor.flat<int>().size()));
        }
      } else if (tensor.dims() == 2) {
        for (int i = 0; i < 2; i++) {
          int c = tensor.matrix<int>()(3, i);
          tensor.matrix<int>()(3, i) = tensor.matrix<int>()(2, i);
          tensor.matrix<int>()(2, i) = tensor.matrix<int>()(1, i);
          tensor.matrix<int>()(1, i) = c;
        }
      } else {
        return Status(
            error::INVALID_ARGUMENT,
            strings::StrCat("Unsupported dimension size: ", tensor.dims()));
      }
    } else {
      for (int i = 0; i < tensor.flat<int>().size(); i++) {
        int value = tensor.flat<int>()(i);
        value = (value >= 0) ? value : value + 4;
        if (value == 1 || value == 2) {
          value = value + 1;
        } else if (value == 3) {
          value = 1;
        }
        tensor.flat<int>()(i) = value;
      }
    }

    if (tensor.dims() == 0) {
      tensor.AsProtoField(node->mutable_attr()->at({"value"}).mutable_tensor());
    } else {
      tensor.AsProtoTensorContent(
          node->mutable_attr()->at({"value"}).mutable_tensor());
    }
    return Status::OK();
  }

  NodeDef* AddNodeTranspose(const string& node_name, const string& input_name,
                            const string& const_name, DataType data_type,
                            const TensorShapeProto& input_shape,
                            bool NHWCToNCHW) {
    NodeDef* node = graph_->add_node();
    node_map_->AddNode(node_name, node);
    node->set_name(node_name);
    *node->add_input() = input_name;
    *node->add_input() = const_name;
    node->set_op("Transpose");
    node->set_device(node_->device());
    AttrValue attr_data_type;
    attr_data_type.set_type(data_type);
    node->mutable_attr()->insert({"T", attr_data_type});
    AttrValue attr_data_type_perm;
    attr_data_type_perm.set_type(DT_INT32);
    node->mutable_attr()->insert({"Tperm", attr_data_type_perm});
    if (!input_shape.unknown_rank()) {
      AttrValue attr_output_shape;
      auto output_shape = attr_output_shape.mutable_list()->add_shape();
      if (NHWCToNCHW) {
        output_shape->add_dim()->set_size(input_shape.dim(0).size());
        output_shape->add_dim()->set_size(input_shape.dim(3).size());
        output_shape->add_dim()->set_size(input_shape.dim(1).size());
        output_shape->add_dim()->set_size(input_shape.dim(2).size());
      } else {
        output_shape->add_dim()->set_size(input_shape.dim(0).size());
        output_shape->add_dim()->set_size(input_shape.dim(2).size());
        output_shape->add_dim()->set_size(input_shape.dim(3).size());
        output_shape->add_dim()->set_size(input_shape.dim(1).size());
      }
      node->mutable_attr()->insert({"_output_shapes", attr_output_shape});
    }
    return node;
  }

  NodeDef* AddNodePermNHWCToNCHW(const string& base_name,
                                 const string& depended_node,
                                 const string& device) {
    string name =
        LayoutOptimizerNode(strings::StrCat(base_name, "-", kPermNHWCToNCHW));
    auto const_node = AddNodePermConst(name, device, {0, 3, 1, 2});
    // This is to ensure the transpose node and the const node are in the
    // same frame.
    *const_node->add_input() = AsControlDependency(depended_node);
    return const_node;
  }

  NodeDef* AddNodePermNCHWToNHWC(const string& base_name,
                                 const string& depended_node,
                                 const string& device) {
    auto const_node = AddNodePermConst(
        LayoutOptimizerNode(strings::StrCat(base_name, "-", kPermNCHWToNHWC)),
        device, {0, 2, 3, 1});
    // This is to ensure the transpose node and the const node are in the same
    // frame.
    *const_node->add_input() = AsControlDependency(depended_node);
    return const_node;
  }

  string GetOrAddNodePermNHWCToNCHW(int pos) {
    string const_name;
    if (is_in_frame_) {
      string base_name = strings::StrCat(node_->name(), "-", pos);
      string input = NodeName(node_->input(pos));
      string depended_node;
      if (!IsTransposeNCHWToNHWC(input)) {
        depended_node = input;
      } else {
        auto input_node = node_map_->GetNode(input);
        depended_node = NodeName(input_node->input(0));
      }
      auto const_node =
          AddNodePermNHWCToNCHW(base_name, depended_node, node_->device());
      const_name = const_node->name();
    } else {
      const_name = LayoutOptimizerNode(kPermNHWCToNCHW);
    }
    return const_name;
  }

  string GetOrAddNodePermNCHWToNHWC() {
    string const_name;
    if (is_in_frame_) {
      auto const_node =
          AddNodePermNCHWToNHWC(node_->name(), node_->name(), node_->device());
      const_name = const_node->name();
    } else {
      const_name = LayoutOptimizerNode(kPermNCHWToNHWC);
    }
    return const_name;
  }

  void UpdateTuple(AttrValue_ListValue* list) {
    int64 h = list->i(1);
    int64 w = list->i(2);
    int64 c = list->i(3);
    list->set_i(1, c);
    list->set_i(2, h);
    list->set_i(3, w);
  }

  bool IsInputOnHost(const string& input_name) const {
    string device = node_->device();
    DeviceNameUtils::ParsedName parsed_name;
    if (DeviceNameUtils::ParseFullName(device, &parsed_name)) {
      if (parsed_name.type != "CPU") {
        NodeDef* input = node_map_->GetNode(input_name);
        int port;
        ParseNodeName(input_name, &port);
        if (IsHostMemory(*input, port)) {
          return true;
        }
      }
    }
    return false;
  }

  NodeDef* AddNodeDataFormatOp(const string& name, const string& input_name,
                               const string& op, DataType dtype,
                               bool nhwc_to_nchw) {
    NodeDef* added_node = graph_->add_node();
    added_node->set_name(name);
    added_node->set_op(op);
    node_map_->AddNode(added_node->name(), added_node);
    added_node->set_device(node_->device());
    // The inputs of a DataFormat op could be in host memory for ops such as
    // Reshape. In such cases, run the kernel on the host too.
    if (IsInputOnHost(input_name)) {
      AttrValue attr_kernel;
      attr_kernel.set_s("host");
      added_node->mutable_attr()->insert({"_kernel", attr_kernel});
    }
    AttrValue attr_data_type;
    attr_data_type.set_type(dtype);
    added_node->mutable_attr()->insert({"T", attr_data_type});
    string src_format = (nhwc_to_nchw) ? "NHWC" : "NCHW";
    string dst_format = (nhwc_to_nchw) ? "NCHW" : "NHWC";
    AttrValue attr_format;
    attr_format.set_s(src_format);
    added_node->mutable_attr()->insert({"src_format", attr_format});
    attr_format.set_s(dst_format);
    added_node->mutable_attr()->insert({"dst_format", attr_format});
    *added_node->add_input() = input_name;
    return added_node;
  }

  void AddDataFormatTranformToParamInput(const string& op, int input_pos,
                                         DataType dtype) {
    string suffix = (op == "DataFormatVecPermute") ? kVecPermuteNHWCToNCHW
                                                   : kDimMapNHWCToNCHW;
    string name = LayoutOptimizerNode(
        strings::StrCat(node_->name(), "-", input_pos, "-", suffix));
    auto added_node =
        AddNodeDataFormatOp(name, node_->input(input_pos), op, dtype, true);
    *node_->mutable_input(input_pos) = added_node->name();
    node_map_->UpdateOutput(NodeName(added_node->input(0)), node_->name(),
                            added_node->name());
    node_map_->AddOutput(added_node->name(), node_->name());
  }
};

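// Per-op processors for format-dependent ops. Each one overrides which input
// positions get a transpose and, where needed, permutes extra parameter
// inputs such as output shapes, ksize, or strides.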
class AvgPoolGradProcessor : public NodeProcessor {
 public:
  explicit AvgPoolGradProcessor(const OptimizeContext& opt_cxt)
      : NodeProcessor(opt_cxt) {}

 protected:
  std::vector<int> GetInputPos() const override { return {1}; }
  Status CustomizedProcessing() override {
    return UpdateOrTransformParamInput(0, "DataFormatVecPermute", DT_INT32);
  }
};

class BiasAddGradProcessor : public NodeProcessor {
 public:
  explicit BiasAddGradProcessor(const OptimizeContext& opt_cxt)
      : NodeProcessor(opt_cxt) {}

 protected:
  bool ShouldProcess() const override {
    if (MustPreserve()) {
      return false;
    }
    if (!IsOnGPU()) {
      return false;
    }
    auto input = node_map_->GetNode(node_->input(0));
    if (input) {
      int port;
      ParseNodeName(node_->input(0), &port);
      if (IsNHWC() && IsPortDimsFour(*input, port)) {
        return true;
      }
    }
    return false;
  }

  Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
};

class Conv2DProcessor : public NodeProcessor {
 public:
  Conv2DProcessor(const OptimizeContext& opt_cxt, bool no_gemm)
      : NodeProcessor(opt_cxt), no_gemm_(no_gemm) {}

 protected:
  bool ShouldProcess() const override {
    return !MustPreserve() && IsNHWC() && IsPortZeroDimsFour(*node_) &&
           HasOutputs() && (!IsGemmUsed() || no_gemm_) && IsOnGPU();
  }

  TensorShapeProto GetShape(const string& input_name) const {
    string node_name;
    int output_pos;
    node_name = ParseNodeName(input_name, &output_pos);
    NodeDef* node = node_map_->GetNode(node_name);
    if (node->attr().find("_output_shapes") != node->attr().end()) {
      return node->attr().at("_output_shapes").list().shape(output_pos);
    }
    TensorShapeProto shape;
    return shape;
  }

  bool IsStrideOne() const {
    if (node_->attr().find("strides") != node_->attr().end()) {
      auto list = node_->attr().at("strides").list();
      return list.i(1) == 1 && list.i(2) == 1;
    }
    return false;
  }

  bool IsValidPadding() const {
    if (node_->attr().find("padding") != node_->attr().end()) {
      auto padding = node_->attr().at("padding").s();
      return padding == "VALID";
    }
    return false;
  }

  // The logic inside this function is based on the internal implementation of
  // Conv2D, Conv2DBackpropInput, and Conv2DBackpropFilter ops, and thus
  // needs to be updated accordingly if the internal implementation changes.
  bool IsGemmUsed(const TensorShapeProto& filter_shape,
                  const TensorShapeProto& input_shape) const {
    if (filter_shape.dim_size() == 4) {
      if (filter_shape.dim(0).size() == 1 && filter_shape.dim(1).size() == 1 &&
          IsStrideOne()) {
        return true;
      }
    }
    if (input_shape.dim_size() == 4 && filter_shape.dim_size() == 4) {
      if (input_shape.dim(1).size() == filter_shape.dim(0).size() &&
          input_shape.dim(2).size() == filter_shape.dim(1).size() &&
          IsValidPadding()) {
        return true;
      }
    }
    return false;
  }

  virtual bool IsGemmUsed() const {
    auto filter_shape = GetShape(node_->input(1));
    auto input_shape = GetShape(node_->input(0));
    return IsGemmUsed(filter_shape, input_shape);
  }

  bool no_gemm_;
};

class Conv2DBackpropFilterProcessor : public Conv2DProcessor {
 public:
  Conv2DBackpropFilterProcessor(const OptimizeContext& opt_cxt, bool no_gemm)
      : Conv2DProcessor(opt_cxt, no_gemm) {}

 protected:
  bool IsGemmUsed() const override {
    auto filter_shape = GetShape(node_->name());
    auto input_shape = GetShape(node_->input(0));
    return Conv2DProcessor::IsGemmUsed(filter_shape, input_shape);
  }

  std::vector<int> GetInputPos() const override { return {0, 2}; }

  Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
  // No need to update output shape, as it is always of shape
  // [filter_height, filter_width, in_channels, out_channels], regardless of
  // whether NCHW or NHWC is used.
  void UpdateAttrShape() override {}
};

class Conv2DBackpropInputProcessor : public Conv2DProcessor {
 public:
  Conv2DBackpropInputProcessor(const OptimizeContext& opt_cxt, bool no_gemm)
      : Conv2DProcessor(opt_cxt, no_gemm) {}

 protected:
  bool IsGemmUsed() const override {
    auto filter_shape = GetShape(node_->input(1));
    auto input_shape = GetShape(node_->name());
    return Conv2DProcessor::IsGemmUsed(filter_shape, input_shape);
  }

  std::vector<int> GetInputPos() const override { return {2}; }

  Status CustomizedProcessing() override {
    return UpdateOrTransformParamInput(0, "DataFormatVecPermute", DT_INT32);
  }
};

class FusedBatchNormGradProcessor : public NodeProcessor {
 public:
  explicit FusedBatchNormGradProcessor(const OptimizeContext& opt_cxt)
      : NodeProcessor(opt_cxt) {}

 protected:
  bool ShouldProcess() const override {
    return NodeProcessor::ShouldProcess() && IsTraining();
  }

  std::vector<int> GetInputPos() const override { return {0, 1}; }

 private:
  bool IsTraining() const {
    if (node_->attr().find("is_training") != node_->attr().end()) {
      if (node_->attr().at("is_training").b()) {
        return true;
      }
    }
    return false;
  }
};

class MaxPoolGradProcessor : public NodeProcessor {
 public:
  explicit MaxPoolGradProcessor(const OptimizeContext& opt_cxt)
      : NodeProcessor(opt_cxt) {}

 protected:
  std::vector<int> GetInputPos() const override { return {0, 1, 2}; }
};

class MaxPoolGradV2Processor : public MaxPoolGradProcessor {
 public:
  explicit MaxPoolGradV2Processor(const OptimizeContext& opt_cxt)
      : MaxPoolGradProcessor(opt_cxt) {}

 protected:
  Status CustomizedProcessing() override {
    for (int i = 3; i <= 4; i++) {
      TF_RETURN_IF_ERROR(
          UpdateOrTransformParamInput(i, "DataFormatVecPermute", DT_INT32));
    }
    return Status::OK();
  }
};

class MaxPoolV2Processor : public NodeProcessor {
 public:
  explicit MaxPoolV2Processor(const OptimizeContext& opt_cxt)
      : NodeProcessor(opt_cxt) {}

 protected:
  bool ShouldProcess() const override {
    // We check the shape of the data input (input 0) rather than this node's
    // output shape, because shape inference for MaxPoolV2 cannot infer the
    // output shape when ksize or strides is not constant.
    auto data_input = node_map_->GetNode(node_->input(0));
    int port;
    ParseNodeName(node_->input(0), &port);
    return !MustPreserve() && IsNHWC() && IsPortDimsFour(*data_input, port) &&
           HasOutputs() && IsOnGPU();
  }

  Status CustomizedProcessing() override {
    for (int i = 1; i <= 2; i++) {
      TF_RETURN_IF_ERROR(
          UpdateOrTransformParamInput(i, "DataFormatVecPermute", DT_INT32));
    }
    return Status::OK();
  }
};

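// Processor for format-agnostic ops. Such a node is converted only if one of
// its data inputs is (transitively, through other format-agnostic ops)
// produced by an NCHW-to-NHWC conversion inserted earlier by this pass.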
class AgnosticNodeProcessor : public NodeProcessor {
 public:
  explicit AgnosticNodeProcessor(const OptimizeContext& opt_cxt)
      : NodeProcessor(opt_cxt) {}

 protected:
  bool ShouldProcess() const override {
    return !MustPreserve() && IsPortZeroDimsFour(*node_) && HasOutputs() &&
           IsNodeAfterNCHWToNHWC() && IsOnGPU();
  }

  bool IsNodeAfterNCHWToNHWC(const NodeDef& node) const {
    std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
    std::deque<NodeDef*> queue;
    auto data_node_pos = DataInputPos(node);
    std::unordered_set<string> visited;
    for (const auto& pos : data_node_pos) {
      auto input_node = node_map_->GetNode(node.input(pos));
      queue.push_back(input_node);
      visited.insert(input_node->name());
    }
    // The code will exit this while loop in one iteration in most cases, as the
    // graph is already topologically sorted.
    while (!queue.empty()) {
      NodeDef* current_node = queue.front();
      queue.pop_front();
      if (IsTransposeNCHWToNHWC(current_node->name()) ||
          IsDimMapNCHWToNHWC(current_node->name()) ||
          IsVecPermuteNCHWToNHWC(current_node->name())) {
        return true;
      }
      // We only continue searching if the path is connected through
      // format-agnostic nodes.
      if (ops_format_agnostic.find(current_node->op()) !=
          ops_format_agnostic.end()) {
        auto current_node_pos = DataInputPos(*current_node);
        for (const auto& pos : current_node_pos) {
          auto input_node = node_map_->GetNode(current_node->input(pos));
          if (visited.find(input_node->name()) == visited.end()) {
            queue.push_back(input_node);
            visited.insert(input_node->name());
          }
        }
      }
    }
    return false;
  }

  bool IsNodeAfterNCHWToNHWC() const { return IsNodeAfterNCHWToNHWC(*node_); }
};

class AddNProcessor : public AgnosticNodeProcessor {
 public:
  explicit AddNProcessor(const OptimizeContext& opt_cxt)
      : AgnosticNodeProcessor(opt_cxt) {}

 protected:
  std::vector<int> GetInputPos() const override {
    return NonControlInputs(*node_);
  }
};

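// Binary (broadcasting) ops are processed when they combine a 4-D operand
// with a scalar, a vector, or another 4-D operand; only the 4-D inputs get
// transposed. The helpers below build a {1, num_channels, 1, 1} shape
// constant and a Reshape node used when handling a vector operand.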
   1289 class BinaryOpProcessor : public AgnosticNodeProcessor {
   1290  public:
   1291   explicit BinaryOpProcessor(const OptimizeContext& opt_cxt)
   1292       : AgnosticNodeProcessor(opt_cxt) {}
   1293 
   1294  protected:
   1295   bool ShouldProcess() const override {
   1296     return !MustPreserve() && IsPortZeroDimsFour(*node_) && HasOutputs() &&
   1297            IsNodeAfterNCHWToNHWC() &&
   1298            (IsNDOperateWithMD(4, 0) || IsNDOperateWithMD(4, 1) ||
   1299             IsNDOperateWithMD(4, 4) || IsNDOperateWithMD(0, 4) ||
   1300             IsNDOperateWithMD(1, 4)) &&
   1301            IsOnGPU();
   1302   }
   1303 
   1304   std::vector<int> GetInputPos() const override {
   1305     std::vector<int> input_pos;
   1306     auto input0 = node_map_->GetNode(node_->input(0));
   1307     auto input1 = node_map_->GetNode(node_->input(1));
   1308     int input0_port;
   1309     ParseNodeName(node_->input(0), &input0_port);
   1310     int input1_port;
   1311     ParseNodeName(node_->input(1), &input1_port);
   1312     if (IsPortDimsFour(*input0, input0_port)) {
   1313       input_pos.push_back(0);
   1314     }
   1315     if (IsPortDimsFour(*input1, input1_port)) {
   1316       input_pos.push_back(1);
   1317     }
   1318     return input_pos;
   1319   }
   1320 
   1321   bool IsNDOperateWithMD(int n, int m) const {
   1322     auto input0 = node_map_->GetNode(node_->input(0));
   1323     auto input1 = node_map_->GetNode(node_->input(1));
   1324     int input0_port;
   1325     ParseNodeName(node_->input(0), &input0_port);
   1326     int input1_port;
   1327     ParseNodeName(node_->input(1), &input1_port);
   1328 
   1329     if (input0 && input1) {
   1330       bool input0_is_n = (n == 4) ? IsPortDimsFour(*input0, input0_port)
   1331                                   : IsPortDimsN(*input0, input0_port, n);
   1332       bool input1_is_m = (m == 4) ? IsPortDimsFour(*input1, input1_port)
   1333                                   : IsPortDimsN(*input1, input1_port, m);
   1334       return input0_is_n && input1_is_m;
   1335     }
   1336     return false;
   1337   }
   1338 
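          // Adds a Const node holding the shape {1, num_channels, 1, 1}, used
          // as the shape operand of the Reshape added below so that a channel
          // vector broadcasts correctly against an NCHW tensor.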
   1339   NodeDef* AddNodeShapeConst(const string& name, int num_channels,
   1340                              const string& depended_node) {
   1341     NodeDef* node = graph_->add_node();
   1342     node_map_->AddNode(name, node);
   1343     node->set_name(name);
   1344     node->set_op("Const");
   1345     node->set_device(node_->device());
   1346     AttrValue attr_data_type;
   1347     attr_data_type.set_type(DT_INT32);
   1348     node->mutable_attr()->insert({"dtype", attr_data_type});
   1349 
   1350     AttrValue attr_tensor;
   1351     Tensor tensor(DT_INT32, TensorShape({4}));
   1352     std::vector<int> shape = {1, num_channels, 1, 1};
   1353     for (int i = 0; i < static_cast<int>(shape.size()); i++) {
   1354       tensor.flat<int>()(i) = shape[i];
   1355     }
   1356     tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
   1357     node->mutable_attr()->insert({"value", attr_tensor});
   1358     if (is_in_frame_) {
   1359       // This is to ensure the transpose node and the const node are in the
   1360       // same frame.
   1361       *node->add_input() = AsControlDependency(depended_node);
   1362     }
   1363     return node;
   1364   }
   1365 
   1366   NodeDef* AddNodeReshape(const string& node_name, const string& input_name,
   1367                           const string& shape_const_node_name,
   1368                           DataType data_type) {
   1369     NodeDef* node = graph_->add_node();
   1370     node_map_->AddNode(node_name, node);
   1371     node->set_name(node_name);
   1372     *node->add_input() = input_name;
   1373     *node->add_input() = shape_const_node_name;
   1374     node->set_op("Reshape");
   1375     node->set_device(node_->device());
   1376 
   1377     AttrValue attr_type_indices;
   1378     attr_type_indices.set_type(DT_INT32);
   1379     node->mutable_attr()->insert({"Tshape", attr_type_indices});
   1380 
   1381     AttrValue attr_type_params;
   1382     attr_type_params.set_type(data_type);
   1383     node->mutable_attr()->insert({"T", attr_type_params});
   1384     return node;
   1385   }
   1386 
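          // If one operand is a 1-D vector (broadcast along the channel
          // dimension in NHWC), reshape it to {1, C, 1, 1} so the broadcast
          // still matches the channel dimension once the other operand is in
          // NCHW.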
   1387   Status CustomizedProcessing() override {
   1388     int vector_index = -1;
   1389     if (IsNDOperateWithMD(4, 1)) {
   1390       vector_index = 1;
   1391     } else if (IsNDOperateWithMD(1, 4)) {
   1392       vector_index = 0;
   1393     }
   1394     if (vector_index != -1) {
   1395       string base_name = strings::StrCat(node_->name(), "-", vector_index);
   1396       string reshape_node_name = LayoutOptimizerNode(
   1397           strings::StrCat(base_name, "-", kReshapeNHWCToNCHW));
   1398       string shape_const_node_name =
   1399           LayoutOptimizerNode(strings::StrCat(base_name, "-", kReshapeConst));
   1400       auto input_node = node_map_->GetNode(node_->input(vector_index));
   1401       TF_RETURN_IF_ERROR(HasAttribute(*input_node, "_output_shapes"));
   1402       int port;
   1403       ParseNodeName(node_->input(vector_index), &port);
   1404       int vector_size = input_node->attr()
   1405                             .at("_output_shapes")
   1406                             .list()
   1407                             .shape(port)
   1408                             .dim(0)
   1409                             .size();
   1410       AddNodeShapeConst(shape_const_node_name, vector_size,
   1411                         NodeName(node_->input(vector_index)));
   1412       TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
   1413       AddNodeReshape(reshape_node_name, node_->input(vector_index),
   1414                      shape_const_node_name, node_->attr().at("T").type());
   1415       node_map_->AddOutput(shape_const_node_name, reshape_node_name);
   1416       node_map_->UpdateOutput(NodeName(node_->input(vector_index)),
   1417                               node_->name(), reshape_node_name);
   1418       node_map_->AddOutput(reshape_node_name, node_->name());
   1419       *node_->mutable_input(vector_index) = reshape_node_name;
   1420     }
   1421     return Status::OK();
   1422   }
   1423 };
   1424 
   1425 class ConcatProcessor : public AgnosticNodeProcessor {
   1426  public:
   1427   explicit ConcatProcessor(const OptimizeContext& opt_cxt)
   1428       : AgnosticNodeProcessor(opt_cxt) {
   1429     // For Concat, the concat axis is the first input; for ConcatV2, it
   1430     // is the last input. Note that if there are control inputs, the
   1431     // number of inputs is larger than the integer attribute N.
   1432     int n = node_->attr().at("N").i();
   1433     axis_node_pos_ = (IsConcatV1(*node_)) ? 0 : n;
   1434   }
   1435 
   1436  protected:
   1437   std::vector<int> GetInputPos() const override {
   1438     return DataInputPosConcat(*node_);
   1439   }
   1440 
   1441   Status CustomizedProcessing() override {
   1442     DataType dtype =
   1443         (IsConcatV1(*node_)) ? DT_INT32 : node_->attr().at("Tidx").type();
   1444     return UpdateOrTransformParamInput(axis_node_pos_, "DataFormatDimMap",
   1445                                        dtype);
   1446   }
   1447 
   1448   int axis_node_pos_;
   1449 };
   1450 
   1451 class FillProcessor : public AgnosticNodeProcessor {
   1452  public:
   1453   explicit FillProcessor(const OptimizeContext& opt_cxt)
   1454       : AgnosticNodeProcessor(opt_cxt) {}
   1455 
   1456  protected:
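          // No input needs a layout Transpose: input 0 (dims) is permuted in
          // CustomizedProcessing, and input 1 is the scalar fill value.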
   1457   std::vector<int> GetInputPos() const override { return {}; }
   1458 
   1459   Status CustomizedProcessing() override {
   1460     DataType dtype = node_->attr().at("index_type").type();
   1461     return UpdateOrTransformParamInput(0, "DataFormatVecPermute", dtype);
   1462   }
   1463 };
   1464 
   1465 class HistogramSummaryProcessor : public AgnosticNodeProcessor {
   1466  public:
   1467   explicit HistogramSummaryProcessor(const OptimizeContext& opt_cxt)
   1468       : AgnosticNodeProcessor(opt_cxt) {}
   1469 
   1470  protected:
   1471   bool ShouldProcess() const override {
   1472     auto input1 = node_map_->GetNode(node_->input(1));
   1473     int port;
   1474     ParseNodeName(node_->input(1), &port);
   1475     return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
   1476            IsPortDimsFour(*input1, port) && IsOnGPU();
   1477   }
   1478 
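          // Only input 1 (values) carries the 4-D tensor; the output is a
          // serialized summary, so no output transpose is needed.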
   1479   std::vector<int> GetInputPos() const override { return {1}; }
   1480 
   1481   Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
   1482 };
   1483 
   1484 class IdentityNProcessor : public AgnosticNodeProcessor {
   1485  public:
   1486   explicit IdentityNProcessor(const OptimizeContext& opt_cxt)
   1487       : AgnosticNodeProcessor(opt_cxt) {
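            // Collect the positions of 4-D inputs that are fed, directly or
            // through format-agnostic ops, by an NCHW-to-NHWC conversion; only
            // these inputs and the matching outputs are converted.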
   1488     std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
   1489     for (int i = 0; i < node_->input_size(); i++) {
   1490       auto input = node_map_->GetNode(node_->input(i));
   1491       int port;
   1492       ParseNodeName(node_->input(i), &port);
   1493       // Skip control input.
   1494       if (port != -1) {
   1495         bool is_agnostic =
   1496             ops_format_agnostic.find(input->op()) != ops_format_agnostic.end();
   1497         if (IsPortDimsFour(*input, port) &&
   1498             ((IsNodeAfterNCHWToNHWC(*input) && is_agnostic) ||
   1499              IsTransposeNCHWToNHWC(input->name()))) {
   1500           input_pos_.push_back(i);
   1501         }
   1502       }
   1503     }
   1504   }
   1505 
   1506  protected:
   1507   bool ShouldProcess() const override {
   1508     return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
   1509            IsOnGPU();
   1510   }
   1511 
   1512   std::vector<int> GetInputPos() const override { return input_pos_; }
   1513 
   1514   std::set<int> GetOutputPos() const override {
   1515     std::set<int> output_pos{};
   1516     for (const auto& input_pos : input_pos_) {
   1517       output_pos.insert(input_pos);
   1518     }
   1519     return output_pos;
   1520   }
   1521 
   1522  private:
   1523   std::vector<int> input_pos_;
   1524 };
   1525 
   1526 class ShapeProcessor : public IdentityNProcessor {
   1527  public:
   1528   explicit ShapeProcessor(const OptimizeContext& opt_cxt)
   1529       : IdentityNProcessor(opt_cxt) {}
   1530 
   1531  protected:
   1532   Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
   1533 
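          // After conversion the data input is in NCHW, so Shape produces an
          // NCHW-ordered vector; a DataFormatVecPermute on each output restores
          // NHWC order for downstream consumers.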
   1534   Status CustomizedProcessing() override {
   1535     return AddTransformToOutputs("DataFormatVecPermute");
   1536   }
   1537 };
   1538 
   1539 class MergeProcessor : public AgnosticNodeProcessor {
   1540  public:
   1541   explicit MergeProcessor(const OptimizeContext& opt_cxt)
   1542       : AgnosticNodeProcessor(opt_cxt) {}
   1543 
   1544  protected:
   1545   bool ShouldProcess() const override {
   1546     return !MustPreserve() && IsPortZeroDimsFour(*node_) && HasOutputs() &&
   1547            IsEveryInputAfterNCHWToNHWC() && IsOnGPU();
   1548   }
   1549 
   1550   std::vector<int> GetInputPos() const override {
   1551     std::vector<int> input_pos;
   1552     int n = node_->attr().at("N").i();
   1553     input_pos.reserve(n);
   1554     for (int i = 0; i < n; i++) {
   1555       input_pos.push_back(i);
   1556     }
   1557     return input_pos;
   1558   }
   1559 
   1560  private:
   1561   bool IsEveryInputAfterNCHWToNHWC() const {
   1562     std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
   1563     for (const auto& input : node_->input()) {
   1564       auto input_node = node_map_->GetNode(input);
   1565       int port;
   1566       ParseNodeName(input, &port);
   1567       bool is_agnostic = ops_format_agnostic.find(input_node->op()) !=
   1568                          ops_format_agnostic.end();
   1569       if (IsPortDimsFour(*input_node, port) &&
   1570           ((IsNodeAfterNCHWToNHWC(*input_node) && is_agnostic) ||
   1571            IsTransposeNCHWToNHWC(input_node->name()))) {
   1572         continue;
   1573       }
   1574       return false;
   1575     }
   1576     return true;
   1577   }
   1578 };
   1579 
   1580 class PadProcessor : public AgnosticNodeProcessor {
   1581  public:
   1582   explicit PadProcessor(const OptimizeContext& opt_cxt)
   1583       : AgnosticNodeProcessor(opt_cxt) {}
   1584 
   1585  protected:
   1586   Status CustomizedProcessing() override {
   1587     DataType dtype = node_->attr().at("Tpaddings").type();
   1588     return UpdateOrTransformParamInput(1, "DataFormatVecPermute", dtype);
   1589   }
   1590 };
   1591 
   1592 class ReverseProcessor : public AgnosticNodeProcessor {
   1593  public:
   1594   explicit ReverseProcessor(const OptimizeContext& opt_cxt)
   1595       : AgnosticNodeProcessor(opt_cxt) {}
   1596 
   1597  protected:
   1598   Status CustomizedProcessing() override {
   1599     DataType dtype = node_->attr().at("Tidx").type();
   1600     return UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype);
   1601   }
   1602 };
   1603 
   1604 class SplitProcessor : public AgnosticNodeProcessor {
   1605  public:
   1606   explicit SplitProcessor(const OptimizeContext& opt_cxt)
   1607       : AgnosticNodeProcessor(opt_cxt) {
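            // For Split, input 0 is the split axis and input 1 is the data.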
   1608     axis_node_pos_ = 0;
   1609   }
   1610 
   1611  protected:
   1612   std::vector<int> GetInputPos() const override { return {1}; }
   1613 
   1614   std::set<int> GetOutputPos() const override {
   1615     std::set<int> output_pos{0};
   1616     if (HasAttribute(*node_, "num_split").ok()) {
   1617       for (int i = 1; i < node_->attr().at("num_split").i(); i++) {
   1618         output_pos.insert(i);
   1619       }
   1620     }
   1621     return output_pos;
   1622   }
   1623 
   1624   Status CustomizedProcessing() override {
   1625     return UpdateOrTransformParamInput(axis_node_pos_, "DataFormatDimMap",
   1626                                        DT_INT32);
   1627   }
   1628 
   1629   int axis_node_pos_;
   1630 };
   1631 
   1632 class SplitVProcessor : public SplitProcessor {
   1633  public:
   1634   explicit SplitVProcessor(const OptimizeContext& opt_cxt)
   1635       : SplitProcessor(opt_cxt) {
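            // For SplitV, input 0 is the data and input 2 is the split axis;
            // input 1 holds the split sizes.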
   1636     axis_node_pos_ = 2;
   1637   }
   1638 
   1639  protected:
   1640   std::vector<int> GetInputPos() const override { return {0}; }
   1641 };
   1642 
   1643 class TernaryOpProcessor : public AgnosticNodeProcessor {
   1644  public:
   1645   explicit TernaryOpProcessor(const OptimizeContext& opt_cxt)
   1646       : AgnosticNodeProcessor(opt_cxt) {}
   1647 
   1648  protected:
   1649   std::vector<int> GetInputPos() const override { return {0, 1, 2}; }
   1650 };
   1651 
   1652 class SelectProcessor : public AgnosticNodeProcessor {
   1653  public:
   1654   explicit SelectProcessor(const OptimizeContext& opt_cxt)
   1655       : AgnosticNodeProcessor(opt_cxt) {}
   1656 
   1657  protected:
   1658   bool ShouldProcess() const override {
   1659     auto input0 = node_map_->GetNode(node_->input(0));
   1660     int input0_port;
   1661     ParseNodeName(node_->input(0), &input0_port);
   1662     bool is_input0_scalar_vector_4d = IsPortDimsN(*input0, input0_port, 0) ||
   1663                                       IsPortDimsN(*input0, input0_port, 1) ||
   1664                                       IsPortDimsN(*input0, input0_port, 4);
   1665     return AgnosticNodeProcessor::ShouldProcess() && is_input0_scalar_vector_4d;
   1666   }
   1667 
   1668   std::vector<int> GetInputPos() const override {
   1669     auto input0 = node_map_->GetNode(node_->input(0));
   1670     int input0_port;
   1671     ParseNodeName(node_->input(0), &input0_port);
   1672     // Input 0 can be a scalar, a vector whose size matches the first
   1673     // dimension of inputs 1 and 2, or a tensor with the same shape as them.
   1674     if (IsPortDimsFour(*input0, input0_port)) {
   1675       return {0, 1, 2};
   1676     } else {
   1677       return {1, 2};
   1678     }
   1679   }
   1680 };
   1681 
   1682 class UnaryGradProcessor : public AgnosticNodeProcessor {
   1683  public:
   1684   explicit UnaryGradProcessor(const OptimizeContext& opt_cxt)
   1685       : AgnosticNodeProcessor(opt_cxt) {}
   1686 
   1687  protected:
   1688   std::vector<int> GetInputPos() const override { return {0, 1}; }
   1689 };
   1690 
   1691 class SliceProcessor : public AgnosticNodeProcessor {
   1692  public:
   1693   explicit SliceProcessor(const OptimizeContext& opt_cxt)
   1694       : AgnosticNodeProcessor(opt_cxt) {
   1695     // Skip the first input, which is the data to be sliced.
   1696     start_ = 1;
   1697     // Note that we can't use node_->input_size() here because there
   1698     // could be control inputs.
   1699     end_ = 2;
   1700   }
   1701 
   1702  protected:
   1703   Status ProcessInputs() {
   1704     for (int i = start_; i <= end_; i++) {
   1705       DataType dtype = node_->attr().at("Index").type();
   1706       TF_RETURN_IF_ERROR(
   1707           UpdateOrTransformParamInput(i, "DataFormatVecPermute", dtype));
   1708     }
   1709     return Status::OK();
   1710   }
   1711 
   1712   Status CustomizedProcessing() override { return ProcessInputs(); }
   1713 
   1714   int start_;
   1715   int end_;
   1716 };
   1717 
   1718 class StridedSliceProcessor : public SliceProcessor {
   1719  public:
   1720   explicit StridedSliceProcessor(const OptimizeContext& opt_cxt)
   1721       : SliceProcessor(opt_cxt) {
   1722     start_ = 1;
   1723     end_ = 3;
   1724   }
   1725 
   1726  protected:
   1727   bool ShouldProcess() const override {
   1728     return AgnosticNodeProcessor::ShouldProcess() && IsOnlyBeginEndMask();
   1729   }
   1730 
   1731   Status CustomizedProcessing() override {
   1732     TF_RETURN_IF_ERROR(UpdateMask("begin_mask"));
   1733     TF_RETURN_IF_ERROR(UpdateMask("end_mask"));
   1734     TF_RETURN_IF_ERROR(ProcessInputs());
   1735     return Status::OK();
   1736   }
   1737 
   1738  private:
   1739   bool IsMaskZero(const string& mask) const {
   1740     return node_->attr().at(mask).i() == 0;
   1741   }
   1742 
   1743   bool IsOnlyBeginEndMask() const {
   1744     return IsMaskZero("ellipsis_mask") && IsMaskZero("new_axis_mask") &&
   1745            IsMaskZero("shrink_axis_mask");
   1746   }
   1747 
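          // begin_mask and end_mask are bitmasks indexed by dimension. Remap
          // each bit from its NHWC position to its NCHW position: N stays at
          // bit 0, C moves from bit 3 to bit 1, and H and W move from bits 1
          // and 2 to bits 2 and 3. The values 0, 1, 14 and 15 are unchanged by
          // this permutation and are returned early.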
   1748   Status UpdateMask(const string& mask) {
   1749     int i = node_->attr().at(mask).i();
   1750     if (i < 0 || i > 15) {
   1751       return errors::InvalidArgument("invalid mask value: ", i);
   1752     }
   1753     if (i == 0 || i == 1 || i == 14 || i == 15) return Status::OK();
   1754     switch (i) {
   1755       case 2:
   1756       case 3:
   1757         i += 2;
   1758         break;
   1759       case 4:
   1760       case 5:
   1761         i += 4;
   1762         break;
   1763       case 6:
   1764       case 7:
   1765         i += 6;
   1766         break;
   1767       case 8:
   1768       case 9:
   1769         i -= 6;
   1770         break;
   1771       case 10:
   1772       case 11:
   1773         i -= 4;
   1774         break;
   1775       case 12:
   1776       case 13:
   1777         i -= 2;
   1778         break;
   1779     }
   1780     node_->mutable_attr()->at(mask).set_i(i);
   1781     return Status::OK();
   1782   }
   1783 };
   1784 
   1785 class StridedSliceGradProcessor : public StridedSliceProcessor {
   1786  public:
   1787   explicit StridedSliceGradProcessor(const OptimizeContext& opt_cxt)
   1788       : StridedSliceProcessor(opt_cxt) {
   1789     start_ = 0;
   1790     end_ = 3;
   1791   }
   1792 
   1793  protected:
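          // Only input 4 (dy) is a 4-D tensor needing a Transpose; inputs 0-3
          // (shape, begin, end, strides) are vectors permuted by ProcessInputs.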
   1794   std::vector<int> GetInputPos() const override { return {4}; }
   1795 };
   1796 
   1797 class SqueezeProcessor : public AgnosticNodeProcessor {
   1798  public:
   1799   explicit SqueezeProcessor(const OptimizeContext& opt_cxt)
   1800       : AgnosticNodeProcessor(opt_cxt) {}
   1801 
   1802  protected:
   1803   bool ShouldProcess() const override {
   1804     bool is_dims_supported = (IsPortZeroDimsN(*node_, 2) && IsAlongHW()) ||
   1805                              (IsPortZeroDimsN(*node_, 1) && IsAlongNHW());
   1806     return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
   1807            IsInputConvertible() && is_dims_supported && IsOnGPU();
   1808   }
   1809 
   1810   Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
   1811 
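          // Remap squeeze_dims from NHWC to NCHW: {1, 2} (H, W) becomes {2, 3},
          // and {0, 1, 2} (N, H, W) becomes {0, 2, 3}.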
   1812   Status CustomizedProcessing() override {
   1813     TF_RETURN_IF_ERROR(HasAttribute(*node_, "squeeze_dims"));
   1814     auto list = node_->mutable_attr()->at("squeeze_dims").mutable_list();
   1815     if (list->i_size() == 2) {
   1816       list->set_i(0, 2);
   1817       list->set_i(1, 3);
   1818     } else if (list->i_size() == 3) {
   1819       list->set_i(1, 2);
   1820       list->set_i(2, 3);
   1821     }
   1822     return Status::OK();
   1823   }
   1824 
   1825  private:
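          // The input is convertible if it is 4-D and its H and W dimensions
          // (and possibly also N) have size 1, i.e. the squeeze removes only
          // spatial (and batch) dimensions.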
   1826   bool IsInputConvertible() const {
   1827     int input_port;
   1828     auto input = node_map_->GetNode(node_->input(0));
   1829     ParseNodeName(node_->input(0), &input_port);
   1830     if (input->attr().find("_output_shapes") != input->attr().end()) {
   1831       auto shape = input->attr().at("_output_shapes").list().shape(input_port);
   1832       if (shape.dim_size() != 4) {
   1833         return false;
   1834       }
   1835       if (shape.dim(1).size() == 1 && shape.dim(2).size() == 1) {
   1836         return true;
   1837       }
   1838       if (shape.dim(0).size() == 1 && shape.dim(1).size() == 1 &&
   1839           shape.dim(2).size() == 1) {
   1840         return true;
   1841       }
   1842     }
   1843     return false;
   1844   }
   1845 
   1846   bool IsAlongAxis(const std::vector<int>& axis) const {
   1847     if (node_->attr().find("squeeze_dims") != node_->attr().end()) {
   1848       auto list = node_->attr().at("squeeze_dims").list();
   1849       // If list is empty, Squeeze op will squeeze all dimensions of size 1.
   1850       if (list.i_size() == 0) return true;
   1851       if (list.i_size() == axis.size()) {
   1852         bool along_axis = true;
   1853         for (int i = 0; i < axis.size(); i++) {
   1854           along_axis = along_axis && (list.i(i) == axis[i]);
   1855         }
   1856         if (along_axis) return true;
   1857       }
   1858     }
   1859     return false;
   1860   }
   1861   bool IsAlongHW() const { return IsAlongAxis({1, 2}); }
   1862   bool IsAlongNHW() const { return IsAlongAxis({0, 1, 2}); }
   1863 };
   1864 
   1865 class ReduceProcessor : public AgnosticNodeProcessor {
   1866  public:
   1867   explicit ReduceProcessor(const OptimizeContext& opt_cxt)
   1868       : AgnosticNodeProcessor(opt_cxt) {}
   1869 
   1870  protected:
   1871   bool ShouldProcess() const override {
   1872     auto input0 = node_map_->GetNode(node_->input(0));
   1873     int port;
   1874     ParseNodeName(node_->input(0), &port);
   1875     return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
   1876            IsPortDimsFour(*input0, port) && IsReduceAxisSupported() &&
   1877            IsOnGPU();
   1878   }
   1879 
   1880   Status CustomizedProcessing() override {
   1881     if (IsReduceAxisSupported()) {
   1882       DataType dtype = node_->attr().at("Tidx").type();
   1883       TF_RETURN_IF_ERROR(
   1884           UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype));
   1885     }
   1886     return Status::OK();
   1887   }
   1888 
   1889   Status AddLayoutTransposeToOutputs() override {
   1890     if (KeepDims()) {
   1891       return AddTransformToOutputs("Transpose");
   1892     }
   1893     return Status::OK();
   1894   }
   1895 
   1896  private:
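          // With keep_dims the output stays 4-D and can simply be transposed
          // back to NHWC; without keep_dims only the listed reduction-axis
          // patterns are supported.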
   1897   bool IsReduceAxisSupported() const {
   1898     return KeepDims() || ((IsAlongAllFourDims() || IsAlongHWC() ||
   1899                            IsAlongNHW() || IsAlongHW() || IsAlongC()) &&
   1900                           !KeepDims());
   1901   }
   1902 
   1903   bool IsAlongAxis(const std::vector<int>& axis) const {
   1904     auto axis_node = node_map_->GetNode(node_->input(1));
   1905     if (!IsConstant(*axis_node)) {
   1906       return false;
   1907     }
   1908     if (HasAttribute(*axis_node, "value").ok()) {
   1909       Tensor tensor;
   1910       auto success = tensor.FromProto(axis_node->attr().at({"value"}).tensor());
   1911       if (!success) {
   1912         LOG(ERROR) << "Failed to parse TensorProto.";
   1913       }
   1914       if (tensor.dims() == 1 && tensor.dim_size(0) == axis.size()) {
   1915         bool along_axis = true;
   1916         for (int i = 0; i < axis.size(); i++) {
   1917           along_axis = along_axis && (tensor.flat<int>()(i) == axis[i]);
   1918         }
   1919         if (along_axis) return true;
   1920       }
   1921     }
   1922     return false;
   1923   }
   1924 
   1925   bool IsAlongAllFourDims() const { return IsAlongAxis({0, 1, 2, 3}); }
   1926 
   1927   bool IsAlongHWC() const { return IsAlongAxis({1, 2, 3}); }
   1928 
   1929   bool IsAlongNHW() const { return IsAlongAxis({0, 1, 2}); }
   1930 
   1931   bool IsAlongHW() const { return IsAlongAxis({1, 2}); }
   1932 
   1933   bool IsAlongC() const { return IsAlongAxis({3}); }
   1934 
   1935   bool KeepDims() const { return node_->attr().at("keep_dims").b(); }
   1936 };
   1937 
   1938 class SwitchProcessor : public AgnosticNodeProcessor {
   1939  public:
   1940   explicit SwitchProcessor(const OptimizeContext& opt_cxt)
   1941       : AgnosticNodeProcessor(opt_cxt) {}
   1942 
   1943  protected:
   1944   std::set<int> GetOutputPos() const override { return {0, 1}; }
   1945 };
   1946 
   1947 class TileProcessor : public AgnosticNodeProcessor {
   1948  public:
   1949   explicit TileProcessor(const OptimizeContext& opt_cxt)
   1950       : AgnosticNodeProcessor(opt_cxt) {}
   1951 
   1952  protected:
   1953   Status CustomizedProcessing() override {
   1954     DataType dtype = node_->attr().at("Tmultiples").type();
   1955     return UpdateOrTransformParamInput(1, "DataFormatVecPermute", dtype);
   1956   }
   1957 };
   1958 
   1959 class DataLayoutOptimizer : GraphProcessor {
   1960  public:
   1961   explicit DataLayoutOptimizer(
   1962       const GraphProperties& graph_properties,
   1963       const VirtualPlacer& virtual_placer,
   1964       const LayoutOptimizer::TuningConfig& config,
   1965       const std::unordered_set<string>& nodes_to_preserve, GraphDef* graph,
   1966       NodeMap* node_map)
   1967       : GraphProcessor(graph_properties, virtual_placer, nodes_to_preserve,
   1968                        graph, node_map),
   1969         config_(config) {}
   1970 
   1971   Status Optimize() {
   1972     VLOG(1) << "Number of nodes for original graph: " << graph_->node_size();
   1973     TF_RETURN_IF_ERROR(Expand());
   1974     VLOG(1) << "Number of nodes after Expand: " << graph_->node_size();
   1975     TF_RETURN_IF_ERROR(Collapse());
   1976     VLOG(1) << "Number of nodes after Collapse: " << graph_->node_size();
   1977     return Status::OK();
   1978   }
   1979 
   1980  private:
   1981   NodeDef* AddNodePermNHWCToNCHW() {
   1982     return AddNodePermConst(LayoutOptimizerNode(kPermNHWCToNCHW), "",
   1983                             {0, 3, 1, 2});
   1984   }
   1985 
   1986   NodeDef* AddNodePermNCHWToNHWC() {
   1987     return AddNodePermConst(LayoutOptimizerNode(kPermNCHWToNHWC), "",
   1988                             {0, 2, 3, 1});
   1989   }
   1990 
   1991   // Expand all nodes that are in NHWC but support NCHW or are layout-agnostic.
   1992   Status Expand() {
   1993     int node_size_original = graph_->node_size();
   1994 
   1995     FrameView frame_view;
   1996     TF_RETURN_IF_ERROR(frame_view.InferFromGraph(*graph_));
   1997 
   1998     // This is the first pass where we expand the nodes which support NCHW.
   1999     std::set<string> ops_format_supported = GetOpsFormatSupported();
   2000     for (int i = 0; i < node_size_original; i++) {
   2001       if (IsNodeByLayoutOptimizer(graph_->node(i).name())) {
   2002         return Status(error::INVALID_ARGUMENT,
   2003                       "The graph is already optimized by layout optimizer.");
   2004       }
   2005       if (ops_format_supported.find(graph_->node(i).op()) !=
   2006           ops_format_supported.end()) {
   2007         auto node = graph_->mutable_node(i);
   2008         bool is_in_frame = frame_view.IsInFrame(*node);
   2009         OptimizeContext opt_cxt(graph_, node, node_map_, graph_properties_,
   2010                                 virtual_placer_, nodes_to_preserve_,
   2011                                 is_in_frame);
   2012         std::unique_ptr<NodeProcessor> node_processor;
   2013         if (IsAvgPoolGrad(*node)) {
   2014           node_processor.reset(new AvgPoolGradProcessor(opt_cxt));
   2015         } else if (IsBiasAddGrad(*node)) {
   2016           node_processor.reset(new BiasAddGradProcessor(opt_cxt));
   2017         } else if (IsConv2D(*node)) {
   2018           node_processor.reset(new Conv2DProcessor(opt_cxt, config_.no_gemm));
   2019         } else if (IsConv2DBackpropFilter(*node)) {
   2020           node_processor.reset(
   2021               new Conv2DBackpropFilterProcessor(opt_cxt, config_.no_gemm));
   2022         } else if (IsConv2DBackpropInput(*node)) {
   2023           node_processor.reset(
   2024               new Conv2DBackpropInputProcessor(opt_cxt, config_.no_gemm));
   2025         } else if (IsDepthwiseConv2dNative(*node)) {
   2026           node_processor.reset(new Conv2DProcessor(opt_cxt, true));
   2027         } else if (IsDepthwiseConv2dNativeBackpropFilter(*node)) {
   2028           node_processor.reset(
   2029               new Conv2DBackpropFilterProcessor(opt_cxt, true));
   2030         } else if (IsDepthwiseConv2dNativeBackpropInput(*node)) {
   2031           node_processor.reset(new Conv2DBackpropInputProcessor(opt_cxt, true));
   2032         } else if (IsFusedBatchNormGrad(*node)) {
   2033           node_processor.reset(new FusedBatchNormGradProcessor(opt_cxt));
   2034         } else if (IsMaxPoolV2(*node)) {
   2035           node_processor.reset(new MaxPoolV2Processor(opt_cxt));
   2036         } else if (IsMaxPoolGradV1(*node) || IsMaxPoolGradGradV1(*node)) {
   2037           node_processor.reset(new MaxPoolGradProcessor(opt_cxt));
   2038         } else if (IsMaxPoolGradV2(*node) || IsMaxPoolGradGradV2(*node)) {
   2039           node_processor.reset(new MaxPoolGradV2Processor(opt_cxt));
   2040         } else {
   2041           node_processor.reset(new NodeProcessor(opt_cxt));
   2042         }
   2043         TF_RETURN_IF_ERROR(node_processor->ConvertNode());
   2044       }
   2045     }
   2046 
   2047     // This is the second pass, where we expand layout-agnostic nodes. This
   2048     // pass is only needed if at least one node was expanded in the first
   2049     // pass.
   2050     if (graph_->node_size() > node_size_original) {
   2051       // Create the shared Const nodes holding the permutations used by the
   2052       // Transpose nodes added for nodes that are not in a frame.
   2053       AddNodePermNHWCToNCHW();
   2054       AddNodePermNCHWToNHWC();
   2055       std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
   2056       for (int i = 0; i < graph_->node_size(); i++) {
   2057         if (ops_format_agnostic.find(graph_->node(i).op()) !=
   2058             ops_format_agnostic.end()) {
   2059           auto node = graph_->mutable_node(i);
   2060           bool is_in_frame = frame_view.IsInFrame(*node);
   2061           OptimizeContext opt_cxt(graph_, node, node_map_, graph_properties_,
   2062                                   virtual_placer_, nodes_to_preserve_,
   2063                                   is_in_frame);
   2064           std::unique_ptr<NodeProcessor> node_processor;
   2065           if (IsAddN(*node)) {
   2066             node_processor.reset(new AddNProcessor(opt_cxt));
   2067           } else if (IsBetainc(*node)) {
   2068             node_processor.reset(new TernaryOpProcessor(opt_cxt));
   2069           } else if (IsBinaryOp(*node)) {
   2070             node_processor.reset(new BinaryOpProcessor(opt_cxt));
   2071           } else if (IsConcat(*node)) {
   2072             node_processor.reset(new ConcatProcessor(opt_cxt));
   2073           } else if (IsFill(*node)) {
   2074             node_processor.reset(new FillProcessor(opt_cxt));
   2075           } else if (IsHistogramSummary(*node)) {
   2076             node_processor.reset(new HistogramSummaryProcessor(opt_cxt));
   2077           } else if (IsIdentityN(*node)) {
   2078             node_processor.reset(new IdentityNProcessor(opt_cxt));
   2079           } else if (IsMerge(*node)) {
   2080             node_processor.reset(new MergeProcessor(opt_cxt));
   2081           } else if (IsPad(*node) || IsMirrorPad(*node) ||
   2082                      IsMirrorPadGrad(*node)) {
   2083             node_processor.reset(new PadProcessor(opt_cxt));
   2084           } else if (IsReduceOp(*node)) {
   2085             node_processor.reset(new ReduceProcessor(opt_cxt));
   2086           } else if (IsReverseV2(*node)) {
   2087             node_processor.reset(new ReverseProcessor(opt_cxt));
   2088           } else if (IsSelect(*node)) {
   2089             node_processor.reset(new SelectProcessor(opt_cxt));
   2090           } else if (IsSlice(*node)) {
   2091             node_processor.reset(new SliceProcessor(opt_cxt));
   2092           } else if (IsStridedSlice(*node)) {
   2093             node_processor.reset(new StridedSliceProcessor(opt_cxt));
   2094           } else if (IsShape(*node) || IsShapeN(*node)) {
   2095             node_processor.reset(new ShapeProcessor(opt_cxt));
   2096           } else if (IsSplit(*node)) {
   2097             node_processor.reset(new SplitProcessor(opt_cxt));
   2098           } else if (IsSplitV(*node)) {
   2099             node_processor.reset(new SplitVProcessor(opt_cxt));
   2100           } else if (IsSqueeze(*node)) {
   2101             node_processor.reset(new SqueezeProcessor(opt_cxt));
   2102           } else if (IsStridedSliceGrad(*node)) {
   2103             node_processor.reset(new StridedSliceGradProcessor(opt_cxt));
   2104           } else if (IsSwitch(*node)) {
   2105             node_processor.reset(new SwitchProcessor(opt_cxt));
   2106           } else if (IsTile(*node)) {
   2107             node_processor.reset(new TileProcessor(opt_cxt));
   2108           } else if (IsUnaryGrad(*node)) {
   2109             node_processor.reset(new UnaryGradProcessor(opt_cxt));
   2110           } else {
   2111             node_processor.reset(new AgnosticNodeProcessor(opt_cxt));
   2112           }
   2113           TF_RETURN_IF_ERROR(node_processor->ConvertNode());
   2114         }
   2115       }
   2116     }
   2117     return Status::OK();
   2118   }
   2119 
   2120   // Remove all node pairs where an NCHW-to-NHWC node is followed by
   2121   // an NHWC-to-NCHW node.
   2122   Status Collapse() {
   2123     std::unordered_set<string> nodes_removable;
   2124     for (int i = 0; i < graph_->node_size(); i++) {
   2125       auto node = graph_->mutable_node(i);
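              // Remove the shape annotations added at the start of Tune(); they
              // are no longer valid after the layout conversion.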
   2126       node->mutable_attr()->erase("_output_shapes");
   2127       if (IsTransposeNHWCToNCHW(node->name()) ||
   2128           IsDimMapNHWCToNCHW(node->name()) ||
   2129           IsVecPermuteNHWCToNCHW(node->name())) {
   2130         bool transpose_pair = IsTransposeNHWCToNCHW(node->name()) &&
   2131                               IsTransposeNCHWToNHWC(node->input(0));
   2132         bool dim_map_pair = IsDimMapNHWCToNCHW(node->name()) &&
   2133                             IsDimMapNCHWToNHWC(node->input(0));
   2134         bool vec_permute_pair = IsVecPermuteNHWCToNCHW(node->name()) &&
   2135                                 IsVecPermuteNCHWToNHWC(node->input(0));
   2136         if (transpose_pair || dim_map_pair || vec_permute_pair) {
   2137           const string& trans_first = node->input(0);
   2138           const string& trans_second = node->name();
   2139           auto outputs = node_map_->GetOutputs(trans_second);
   2140           CHECK(outputs.size() == 1)
   2141               << "There is always only a single output for a Transpose node, "
   2142               << "due to the way it is added by NodeProcessor.";
   2143           NodeDef* output = *outputs.begin();
   2144           string input = node_map_->GetNode(trans_first)->input(0);
   2145           for (int i = 0; i < output->input_size(); i++) {
   2146             if (output->input(i).compare(trans_second) == 0) {
   2147               *output->mutable_input(i) = input;
   2148               break;
   2149             }
   2150           }
   2151           nodes_removable.insert(trans_first);
   2152           nodes_removable.insert(trans_second);
   2153         }
   2154       }
   2155     }
   2156     graph_->mutable_node()->erase(
   2157         std::remove_if(
   2158             graph_->mutable_node()->begin(), graph_->mutable_node()->end(),
   2159             [nodes_removable](const NodeDef& node) {
   2160               return nodes_removable.find(node.name()) != nodes_removable.end();
   2161             }),
   2162         graph_->mutable_node()->end());
   2163     return Status::OK();
   2164   }
   2165 
   2166   const LayoutOptimizer::TuningConfig& config_;
   2167 };
   2168 
   2169 int GetNumGPUs(const Cluster& cluster) {
   2170   auto devices = cluster.GetDevices();
   2171   int num_gpus = 0;
   2172   for (const auto& device : devices) {
   2173     if (device.second.type() == "GPU") {
   2174       num_gpus++;
   2175     }
   2176   }
   2177   return num_gpus;
   2178 }
   2179 }  // namespace
   2180 
   2181 Status LayoutOptimizer::Tune(const GrapplerItem& item,
   2182                              const GraphProperties& graph_properties,
   2183                              const TuningConfig& config, GraphDef* output) {
   2184   auto status = graph_properties.AnnotateOutputShapes(output);
   2185   if (!status.ok()) {
   2186     VLOG(1) << "Annotate shape return status: " << status.ToString();
   2187     *output = item.graph;
   2188     return status;
   2189   }
   2190   NodeMap node_map(output);
   2191   DataLayoutOptimizer layout_optimizer(graph_properties, *virtual_placer_,
   2192                                        config, nodes_to_preserve_, output,
   2193                                        &node_map);
   2194   status = layout_optimizer.Optimize();
   2195   return status;
   2196 }
   2197 
   2198 Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
   2199                                  GraphDef* output) {
   2200   if (cluster == nullptr) {
   2201     return errors::InvalidArgument("cluster == nullptr");
   2202   }
   2203 
   2204   if (GetNumGPUs(*cluster) < 1) {
   2205     // LayoutOptimizer is currently only tuned for GPU.
   2206     *output = item.graph;
   2207     return Status::OK();
   2208   }
   2209 
   2210   virtual_placer_.reset(new VirtualPlacer(cluster));
   2211   nodes_to_preserve_ = item.NodesToPreserve();
   2212   GraphProperties graph_properties(item);
   2213   auto status = graph_properties.InferStatically(false);
   2214   if (!status.ok()) {
   2215     VLOG(1) << "Infer shape return status: " << status.ToString();
   2216     *output = item.graph;
   2217     return status;
   2218   }
   2219   GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
   2220 
   2221   TuningConfig config;
   2222   config.no_gemm = true;
   2223   // TODO(yaozhang): Enable tuning with various TuningConfig choices with
   2224   // the measurement-based estimator.
   2225   status = Tune(item, graph_properties, config, output);
   2226   if (!status.ok()) {
   2227     *output = item.graph;
   2228   }
   2229   return status;
   2230 }
   2231 
   2232 void LayoutOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
   2233                                const GraphDef& optimize_output, double result) {
   2234   // Nothing to do for LayoutOptimizer.
   2235 }
   2236 
   2237 }  // end namespace grappler
   2238 }  // end namespace tensorflow
   2239