/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/op_types.h"
#include <algorithm>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
namespace grappler {

bool IsAdd(const NodeDef& node) {
  if (node.op() == "AddV2" || node.op() == "Add") {
    DataType type = node.attr().at("T").type();
    return type != DT_STRING;
  }
  return false;
}

bool IsAddN(const NodeDef& node) { return node.op() == "AddN"; }

bool IsAll(const NodeDef& node) { return node.op() == "All"; }

bool IsAngle(const NodeDef& node) { return node.op() == "Angle"; }

bool IsAny(const NodeDef& node) { return node.op() == "Any"; }

bool IsAnyDiv(const NodeDef& node) {
  return node.op() == "RealDiv" || node.op() == "Div" ||
         node.op() == "FloorDiv" || node.op() == "TruncateDiv";
}

bool IsAnyMax(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Max" || op == "SegmentMax" || op == "UnsortedSegmentMax";
}

bool IsAnyMaxPool(const NodeDef& node) {
  const auto& op = node.op();
  return op == "MaxPool" || op == "MaxPoolV2" || op == "MaxPool3D" ||
         op == "MaxPoolWithArgmax" || op == "FractionalMaxPool";
}

bool IsAnyMin(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Min" || op == "SegmentMin" || op == "UnsortedSegmentMin";
}

bool IsApproximateEqual(const NodeDef& node) {
  return node.op() == "ApproximateEqual";
}

bool IsArgMax(const NodeDef& node) { return node.op() == "ArgMax"; }

bool IsArgMin(const NodeDef& node) { return node.op() == "ArgMin"; }

bool IsAvgPoolGrad(const NodeDef& node) { return node.op() == "AvgPoolGrad"; }

bool IsAssign(const NodeDef& node) {
  return node.op() == "Assign" || node.op() == "AssignVariableOp";
}

bool IsAssert(const NodeDef& node) { return node.op() == "Assert"; }

bool IsAtan2(const NodeDef& node) { return node.op() == "Atan2"; }

bool IsBetainc(const NodeDef& node) { return node.op() == "Betainc"; }

bool IsBiasAdd(const NodeDef& node) {
  return node.op() == "BiasAdd" || node.op() == "BiasAddV1";
}

bool IsBiasAddGrad(const NodeDef& node) { return node.op() == "BiasAddGrad"; }

bool IsBitcast(const NodeDef& node) { return node.op() == "Bitcast"; }

bool IsCast(const NodeDef& node) { return node.op() == "Cast"; }

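// Returns true for ops that, like Cast, may produce an output whose data type
// differs from the data type of their input.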
bool IsCastLike(const NodeDef& node) {
  static const gtl::FlatSet<string>* const kCastLikeOps =
      CHECK_NOTNULL((new gtl::FlatSet<string>{
          "Angle", "Bucketize", "Cast", "CompareAndBitpack", "Dequantize",
          "HistogramFixedWidth", "Imag", "IsFinite", "IsInf", "IsNan",
          "Quantize", "QuantizeDownAndShrinkRange", "QuantizeV2",
          "QuantizedInstanceNorm", "QuantizedRelu", "QuantizedRelu6",
          "QuantizedReluX", "Real", "Requantize"}));
  return kCastLikeOps->count(node.op()) > 0;
}

bool IsCheckNumerics(const NodeDef& node) {
  return node.op() == "CheckNumerics";
}

bool IsCollective(const NodeDef& node) {
  return node.op() == "CollectiveReduce" ||
         node.op() == "CollectiveBcastSend" ||
         node.op() == "CollectiveBcastRecv";
}

bool IsComplex(const NodeDef& node) { return node.op() == "Complex"; }

bool IsComplexAbs(const NodeDef& node) { return node.op() == "ComplexAbs"; }

bool IsConcat(const NodeDef& node) {
  return node.op() == "Concat" || node.op() == "ConcatV2";
}

bool IsConcatOffset(const NodeDef& node) { return node.op() == "ConcatOffset"; }

bool IsConstant(const NodeDef& node) { return node.op() == "Const"; }

bool IsConj(const NodeDef& node) { return node.op() == "Conj"; }

bool IsConjugateTranspose(const NodeDef& node) {
  return node.op() == "ConjugateTranspose";
}

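// Returns true for the primitive control-flow ops (ControlTrigger, Enter,
// Exit, LoopCond, Merge, NextIteration, Switch).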
bool IsControlFlow(const NodeDef& node) {
  // clang-format off
  return node.op() == "ControlTrigger" ||
         node.op() == "Enter" ||
         node.op() == "Exit" ||
         node.op() == "LoopCond" ||
         node.op() == "Merge" ||
         node.op() == "NextIteration" ||
         node.op() == "Switch";
  // clang-format on
}

bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; }

bool IsConv2DBackpropFilter(const NodeDef& node) {
  return node.op() == "Conv2DBackpropFilter";
}

bool IsConv2DBackpropInput(const NodeDef& node) {
  return node.op() == "Conv2DBackpropInput";
}

bool IsConv3D(const NodeDef& node) { return node.op() == "Conv3D"; }

bool IsDepthwiseConv2dNative(const NodeDef& node) {
  return node.op() == "DepthwiseConv2dNative";
}

bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node) {
  return node.op() == "DepthwiseConv2dNativeBackpropFilter";
}

bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node) {
  return node.op() == "DepthwiseConv2dNativeBackpropInput";
}

bool IsDequeueOp(const NodeDef& node) {
  const auto& op = node.op();
  return op == "QueueDequeueManyV2" || op == "QueueDequeueMany" ||
         op == "QueueDequeueV2" || op == "QueueDequeue" ||
         op == "QueueDequeueUpToV2" || op == "QueueDequeueUpTo";
}

bool IsDiv(const NodeDef& node) { return node.op() == "Div"; }

// Returns true if node represents a unary elementwise function that is
// monotonic. If the function is non-decreasing (e.g. sqrt, exp),
// *is_non_decreasing is set to true; if it is non-increasing (e.g. neg,
// rsqrt), it is set to false.
bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) {
  static const gtl::FlatSet<string>* const kMonotonicNonDecreasingOps =
      CHECK_NOTNULL((new gtl::FlatSet<string>{
          "Acosh", "Asin", "Asinh",    "Atan",     "Atanh", "Ceil",
          "Elu",   "Erf",  "Exp",      "Expm1",    "Floor", "Log",
          "Log1p", "Relu", "Relu6",    "Rint",     "Selu",  "Sigmoid",
          "Sign",  "Sinh", "Softsign", "Softplus", "Sqrt",  "Tanh",
      }));
  static const gtl::FlatSet<string>* const kMonotonicNonIncreasingOps =
      CHECK_NOTNULL((new gtl::FlatSet<string>{"Acos", "Erfc", "Neg", "Rsqrt"}));
  if (kMonotonicNonDecreasingOps->count(node.op()) > 0) {
    if (is_non_decreasing) {
      *is_non_decreasing = true;
    }
    return true;
  } else if (kMonotonicNonIncreasingOps->count(node.op()) > 0) {
    if (is_non_decreasing) {
      *is_non_decreasing = false;
    }
    return true;
  }
  return false;
}

bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad"; }

bool IsEnter(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Enter" || op == "RefEnter";
}

bool IsEqual(const NodeDef& node) { return node.op() == "Equal"; }

bool IsExit(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Exit" || op == "RefExit";
}

bool IsExp(const NodeDef& node) { return node.op() == "Exp"; }

bool IsFakeParam(const NodeDef& node) { return node.op() == "FakeParam"; }

bool IsFill(const NodeDef& node) { return node.op() == "Fill"; }

bool IsFloorDiv(const NodeDef& node) { return node.op() == "FloorDiv"; }

bool IsFloorMod(const NodeDef& node) { return node.op() == "FloorMod"; }

bool IsFusedBatchNorm(const NodeDef& node) {
  const auto& op = node.op();
  return op == "FusedBatchNorm" || op == "FusedBatchNormV2";
}

bool IsFusedBatchNormGrad(const NodeDef& node) {
  const auto& op = node.op();
  return op == "FusedBatchNormGrad" || op == "FusedBatchNormGradV2";
}

bool IsGreater(const NodeDef& node) { return node.op() == "Greater"; }

bool IsGreaterEqual(const NodeDef& node) { return node.op() == "GreaterEqual"; }

bool IsHostConstant(const NodeDef& node) { return node.op() == "HostConst"; }

bool IsHistogramSummary(const NodeDef& node) {
  return node.op() == "HistogramSummary";
}

bool IsIdentity(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Identity" || op == "RefIdentity";
}

bool IsIdentityN(const NodeDef& node) {
  const auto& op = node.op();
  return op == "IdentityN";
}

bool IsIdentityNSingleInput(const NodeDef& node) {
  return IsIdentityN(node) && node.attr().count("T") != 0 &&
         node.attr().at("T").list().type_size() == 1;
}

bool IsIf(const NodeDef& node) {
  const auto& op = node.op();
  return op == "If" || op == "StatelessIf";
}

bool IsIgamma(const NodeDef& node) { return node.op() == "Igamma"; }

bool IsIgammac(const NodeDef& node) { return node.op() == "Igammac"; }

bool IsImag(const NodeDef& node) { return node.op() == "Imag"; }

bool IsImmutableConst(const NodeDef& node) {
  return node.op() == "ImmutableConst";
}

bool IsInvGrad(const NodeDef& node) { return node.op() == "InvGrad"; }

bool IsLess(const NodeDef& node) { return node.op() == "Less"; }

bool IsLessEqual(const NodeDef& node) { return node.op() == "LessEqual"; }

bool IsLog(const NodeDef& node) { return node.op() == "Log"; }

bool IsLogicalAnd(const NodeDef& node) { return node.op() == "LogicalAnd"; }

bool IsLogicalNot(const NodeDef& node) { return node.op() == "LogicalNot"; }

bool IsLogicalOr(const NodeDef& node) { return node.op() == "LogicalOr"; }

bool IsMatMul(const NodeDef& node) {
  const auto& op = node.op();
  return op == "MatMul" || op == "BatchMatMul" || op == "SparseMatMul" ||
         IsQuantizedMatMul(node);
}

bool IsMax(const NodeDef& node) { return node.op() == "Max"; }

bool IsMaximum(const NodeDef& node) { return node.op() == "Maximum"; }

bool IsMaxPoolGrad(const NodeDef& node) { return node.op() == "MaxPoolGrad"; }

bool IsMean(const NodeDef& node) { return node.op() == "Mean"; }

bool IsMerge(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Merge" || op == "RefMerge";
}

bool IsMin(const NodeDef& node) { return node.op() == "Min"; }

bool IsMinimum(const NodeDef& node) { return node.op() == "Minimum"; }

bool IsMirrorPad(const NodeDef& node) { return node.op() == "MirrorPad"; }

bool IsMirrorPadGrad(const NodeDef& node) {
  return node.op() == "MirrorPadGrad";
}

bool IsMod(const NodeDef& node) { return node.op() == "Mod"; }

bool IsMul(const NodeDef& node) { return node.op() == "Mul"; }

bool IsNeg(const NodeDef& node) { return node.op() == "Neg"; }

bool IsNoOp(const NodeDef& node) { return node.op() == "NoOp"; }

bool IsNotEqual(const NodeDef& node) { return node.op() == "NotEqual"; }

bool IsNextIteration(const NodeDef& node) {
  const auto& op = node.op();
  return op == "NextIteration" || op == "RefNextIteration";
}

bool IsOnesLike(const NodeDef& node) { return node.op() == "OnesLike"; }

bool IsPack(const NodeDef& node) { return node.op() == "Pack"; }

bool IsPad(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Pad" || op == "PadV2";
}

bool IsPartitionedCall(const NodeDef& node) {
  return node.op() == "PartitionedCall";
}

bool IsPlaceholder(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Placeholder" || op == "PlaceholderV2" ||
         op == "PlaceholderWithDefault";
}

bool IsPolygamma(const NodeDef& node) { return node.op() == "Polygamma"; }

bool IsPow(const NodeDef& node) { return node.op() == "Pow"; }

bool IsPrint(const NodeDef& node) {
  return node.op() == "Print" || node.op() == "PrintV2";
}

bool IsProd(const NodeDef& node) { return node.op() == "Prod"; }

bool IsQuantizedMatMul(const NodeDef& node) {
  return node.op() == "QuantizedMatMul" || node.op() == "QuantizedMatMulV2";
}

bool IsQueue(const NodeDef& node) {
  return str_util::EndsWith(node.op(), "QueueV2");
}

bool IsRandomShuffle(const NodeDef& node) {
  return node.op() == "RandomShuffle";
}

bool IsRank(const NodeDef& node) { return node.op() == "Rank"; }

bool IsReadVariableOp(const NodeDef& node) {
  return node.op() == "ReadVariableOp";
}

bool IsReal(const NodeDef& node) { return node.op() == "Real"; }

bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; }

bool IsReciprocalGrad(const NodeDef& node) {
  return node.op() == "ReciprocalGrad";
}

bool IsRecv(const NodeDef& node) {
  return node.op() == "_Recv" || node.op() == "_HostRecv";
}

bool IsReduction(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Sum" || op == "Prod" || op == "Min" || op == "Max" ||
         op == "Mean" || op == "Any" || op == "All";
}

bool IsRelu(const NodeDef& node) { return node.op() == "Relu"; }

bool IsReluGrad(const NodeDef& node) { return node.op() == "ReluGrad"; }

bool IsRelu6Grad(const NodeDef& node) { return node.op() == "Relu6Grad"; }

bool IsReshape(const NodeDef& node) { return (node.op() == "Reshape"); }

bool IsRestore(const NodeDef& node) {
  return (node.op() == "Restore" || node.op() == "RestoreV2" ||
          node.op() == "RestoreSlice");
}

bool IsReverse(const NodeDef& node) {
  return node.op() == "Reverse" || node.op() == "ReverseV2";
}

bool IsReverseV2(const NodeDef& node) { return node.op() == "ReverseV2"; }

bool IsRsqrt(const NodeDef& node) { return node.op() == "Rsqrt"; }

bool IsRsqrtGrad(const NodeDef& node) { return node.op() == "RsqrtGrad"; }

bool IsSelect(const NodeDef& node) { return node.op() == "Select"; }

bool IsSeluGrad(const NodeDef& node) { return node.op() == "SeluGrad"; }

bool IsSend(const NodeDef& node) {
  return node.op() == "_Send" || node.op() == "_HostSend";
}

bool IsShape(const NodeDef& node) { return node.op() == "Shape"; }

bool IsShapeN(const NodeDef& node) { return node.op() == "ShapeN"; }

bool IsShuffle(const NodeDef& node) { return node.op() == "Shuffle"; }

bool IsSigmoidGrad(const NodeDef& node) { return node.op() == "SigmoidGrad"; }

bool IsSize(const NodeDef& node) { return node.op() == "Size"; }

bool IsSlice(const NodeDef& node) { return node.op() == "Slice"; }

bool IsSnapshot(const NodeDef& node) { return node.op() == "Snapshot"; }

bool IsSoftmax(const NodeDef& node) { return node.op() == "Softmax"; }

bool IsSoftplusGrad(const NodeDef& node) { return node.op() == "SoftplusGrad"; }

bool IsSoftsignGrad(const NodeDef& node) { return node.op() == "SoftsignGrad"; }

bool IsSplit(const NodeDef& node) { return node.op() == "Split"; }

bool IsSplitV(const NodeDef& node) { return node.op() == "SplitV"; }

bool IsSqrt(const NodeDef& node) { return node.op() == "Sqrt"; }

bool IsSqrtGrad(const NodeDef& node) { return node.op() == "SqrtGrad"; }

bool IsSquare(const NodeDef& node) { return node.op() == "Square"; }

bool IsSquaredDifference(const NodeDef& node) {
  return node.op() == "SquaredDifference";
}

bool IsSqueeze(const NodeDef& node) { return node.op() == "Squeeze"; }

bool IsStackOp(const NodeDef& node) {
  return node.op() == "Stack" || node.op() == "StackV2";
}
bool IsStackCloseOp(const NodeDef& node) {
  return node.op() == "StackClose" || node.op() == "StackCloseV2";
}
bool IsStackPushOp(const NodeDef& node) {
  return node.op() == "StackPush" || node.op() == "StackPushV2";
}
bool IsStackPopOp(const NodeDef& node) {
  return node.op() == "StackPop" || node.op() == "StackPopV2";
}

bool IsStatefulPartitionedCall(const NodeDef& node) {
  return node.op() == "StatefulPartitionedCall";
}

bool IsStopGradient(const NodeDef& node) {
  const auto& op = node.op();
  return op == "StopGradient" || op == "PreventGradient";
}

bool IsStridedSlice(const NodeDef& node) { return node.op() == "StridedSlice"; }

bool IsStridedSliceGrad(const NodeDef& node) {
  return node.op() == "StridedSliceGrad";
}

bool IsSub(const NodeDef& node) { return node.op() == "Sub"; }

bool IsSum(const NodeDef& node) { return node.op() == "Sum"; }

bool IsSwitch(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Switch" || op == "RefSwitch";
}

bool IsSymbolicGradient(const NodeDef& node) {
  return node.op() == "SymbolicGradient";
}

bool IsTanhGrad(const NodeDef& node) { return node.op() == "TanhGrad"; }

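// Returns true for any op in the TensorArray family (creation, grad,
// read/write, concat/split, size, and close variants).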
bool IsTensorArray(const NodeDef& node) {
  static const gtl::FlatSet<string>* const kTensorArrayOps =
      CHECK_NOTNULL((new gtl::FlatSet<string>{
          "TensorArray",
          "TensorArrayV2",
          "TensorArrayV3",
          "TensorArrayGrad",
          "TensorArrayGradV2",
          "TensorArrayGradV3",
          "TensorArrayGradWithShape",
          "TensorArrayWrite",
          "TensorArrayWriteV2",
          "TensorArrayWriteV3",
          "TensorArrayRead",
          "TensorArrayReadV2",
          "TensorArrayReadV3",
          "TensorArrayConcat",
          "TensorArrayConcatV2",
          "TensorArrayConcatV3",
          "TensorArraySplit",
          "TensorArraySplitV2",
          "TensorArraySplitV3",
          "TensorArraySize",
          "TensorArraySizeV2",
          "TensorArraySizeV3",
          "TensorArrayClose",
          "TensorArrayCloseV2",
          "TensorArrayCloseV3",
      }));
  return kTensorArrayOps->count(node.op()) > 0;
}

bool IsTile(const NodeDef& node) { return node.op() == "Tile"; }

bool IsTranspose(const NodeDef& node) { return node.op() == "Transpose"; }

bool IsTruncateDiv(const NodeDef& node) { return node.op() == "TruncateDiv"; }

bool IsTruncateMod(const NodeDef& node) { return node.op() == "TruncateMod"; }

bool IsUnpack(const NodeDef& node) { return node.op() == "Unpack"; }

bool IsVariable(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Variable" || op == "VariableV2" || op == "AutoReloadVariable" ||
         op == "VarHandleOp" || op == "ReadVariableOp";
}

bool IsWhile(const NodeDef& node) {
  const auto& op = node.op();
  return op == "While" || op == "StatelessWhile";
}

bool IsZerosLike(const NodeDef& node) { return node.op() == "ZerosLike"; }

bool IsZeta(const NodeDef& node) { return node.op() == "Zeta"; }

namespace {
bool GetBoolAttr(const NodeDef& node, const string& name) {
  return node.attr().count(name) > 0 && node.attr().at(name).b();
}
}  // namespace

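// Returns true for nodes treated as persistent: constants (including host
// constants) and variables.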
bool IsPersistent(const NodeDef& node) {
  return IsConstant(node) || IsVariable(node) || IsHostConstant(node);
}

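// Returns true if the node may have a ref-typed input that it could modify
// (e.g. Assign, AssignAdd). Conservatively returns true when the OpDef cannot
// be looked up.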
bool MaybeHasRefInput(const NodeDef& node) {
  const OpDef* op_def;
  Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
  if (!status.ok()) {
    return true;
  }
  // Nodes such as Assign or AssignAdd modify one of their inputs.
  for (const auto& input : op_def->input_arg()) {
    if (input.is_ref()) {
      return true;
    }
  }
  return false;
}

bool IsDataset(const NodeDef& node) {
  const string& op = node.op();
  // See `GetNodeClassForOp` in core/graph/graph.cc.
  return op == "IteratorGetNext" || op == "IteratorGetNextSync" ||
         op == "DatasetToSingleElement" || op == "ReduceDataset";
}

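// Returns true if the node's OpDef is marked stateful. If the OpDef cannot be
// looked up, logs a warning and returns false.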
bool IsStateful(const NodeDef node, const OpRegistryInterface* op_registry) {
  const OpDef* op_def = nullptr;
  const string& op_name = node.op();
  Status status = op_registry->LookUpOpDef(op_name, &op_def);
  if (!status.ok()) {
    LOG(WARNING) << "Failed to lookup OpDef for " << op_name
                 << ". Error: " << status.error_message();
    return false;
  }
  return op_def->is_stateful();
}

bool IsStateful(const NodeDef node) {
  return IsStateful(node, OpRegistry::Global());
}

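// Conservatively returns true only if evaluating the node has no observable
// side effects: it is not a Placeholder, not stateful, takes no ref inputs,
// does not touch a queue, does not send tensors, and does not modify its
// inputs in place.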
bool IsFreeOfSideEffect(const NodeDef& node,
                        const OpRegistryInterface* op_registry) {
  // Placeholders must be preserved to keep the graph feedable.
  if (IsPlaceholder(node)) {
    return false;
  }
  const OpDef* op_def = nullptr;
  const string& op_name = node.op();
  Status status = op_registry->LookUpOpDef(op_name, &op_def);
  if (!status.ok()) {
    return false;
  }
  if (op_def->is_stateful()) {
    return false;
  }
  // Nodes such as Assign or AssignAdd modify one of their inputs.
  for (const auto& input : op_def->input_arg()) {
    if (input.is_ref()) {
      return false;
    }
  }
  // Queue ops modify the queue which is a side effect.
  if (node.op().find("Queue") != string::npos) {
    return false;
  }
  // Sending a tensor via a network is a side effect.
  if (IsSend(node)) {
    return false;
  }
  return !ModifiesInputsInPlace(node);
}

bool IsFreeOfSideEffect(const NodeDef& node) {
  return IsFreeOfSideEffect(node, OpRegistry::Global());
}

bool ModifiesInputsInPlace(const NodeDef& node) {
  // Some nodes do in-place updates on regular tensor inputs.
  string op_name = node.op();

  // Ops that modify resource variables effectively modify one of their inputs.
  if (op_name == "AssignVariableOp" || op_name == "AssignAddVariableOp" ||
      op_name == "AssignSubVariableOp" || op_name == "ResourceScatterUpdate" ||
      op_name == "ResourceScatterAdd" || op_name == "ResourceScatterSub" ||
      op_name == "ResourceScatterMul" || op_name == "ResourceScatterDiv" ||
      op_name == "ResourceScatterMin" || op_name == "ResourceScatterMax") {
    return false;
  }

  std::transform(op_name.begin(), op_name.end(), op_name.begin(), ::tolower);
  if (str_util::StrContains(op_name, "inplace")) {
    return true;
  }
  return GetBoolAttr(node, "in_place") || GetBoolAttr(node, "inplace");
}

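// Returns true for ops that create, leave, or advance an execution frame:
// Enter, Exit, and NextIteration (including their Ref variants).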
bool ModifiesFrameInfo(const NodeDef& node) {
  return IsEnter(node) || IsExit(node) || IsNextIteration(node);
}

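// Defines an Is<PropertyCap>(node) predicate that returns true when the
// node's OpDef has the corresponding is_<property>() bit set, with a special
// case for "Add", which is not marked aggregate/commutative in its OpDef.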
#define OPDEF_PROPERTY_HELPER(PROPERTY_CAP, PROPERTY)                      \
  bool Is##PROPERTY_CAP(const NodeDef& node) {                             \
    if (node.op() == "Add") {                                              \
      /* Workaround for "Add" not being marked is_commutative and */       \
      /* is_aggregate. (See cl/173915048). */                              \
      const auto type = GetDataTypeFromAttr(node, "T");                    \
      return type != DT_INVALID && type != DT_STRING;                      \
    }                                                                      \
    const OpDef* op_def = nullptr;                                         \
    Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); \
    return status.ok() && op_def->is_##PROPERTY();                         \
  }

OPDEF_PROPERTY_HELPER(Aggregate, aggregate)
OPDEF_PROPERTY_HELPER(Commutative, commutative)

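// An involution is an op that is its own inverse: applying it twice returns
// the original input.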
bool IsInvolution(const NodeDef& node) {
  static const gtl::FlatSet<string>* const kInvolutionOps =
      CHECK_NOTNULL((new gtl::FlatSet<string>{"Conj", "Reciprocal", "Invert",
                                              "Neg", "LogicalNot"}));
  return kInvolutionOps->count(node.op()) > 0;
}

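// Returns true for ops that pass their input values through unchanged,
// preserving both element order and shape (e.g. Identity, Snapshot,
// CheckNumerics). Single-input aggregation nodes also qualify.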
bool IsValueAndOrderAndShapePreserving(const NodeDef& node) {
  if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
    return true;
  }
  static const gtl::FlatSet<string>* const kValueAndOrderAndShapePreservingOps =
      CHECK_NOTNULL((new const gtl::FlatSet<string>{
          "CheckNumerics",
          "DebugGradientIdentity",
          "DeepCopy",
          "Enter",
          "Exit",
          "PreventGradient",
          "Print",
          "Snapshot",
          "StopGradient",
      }));
  return kValueAndOrderAndShapePreservingOps->count(node.op()) > 0 ||
         IsIdentity(node);
}

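// Like IsValueAndOrderAndShapePreserving, but the output shape may differ
// from the input shape (ExpandDims, Reshape, Squeeze).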
bool IsValueAndOrderPreserving(const NodeDef& node) {
  if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
    return true;
  }
  static const gtl::FlatSet<string>* const kValueAndOrderPreservingOps =
      CHECK_NOTNULL((new const gtl::FlatSet<string>{
          "ExpandDims",
          "Reshape",
          "Squeeze",
      }));
  return kValueAndOrderPreservingOps->count(node.op()) > 0 ||
         IsValueAndOrderAndShapePreserving(node);
}

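// Returns true for ops that preserve the set of input values but may
// rearrange them (e.g. Transpose, Reverse, space/batch/depth reshuffles).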
bool IsValuePreserving(const NodeDef& node) {
  static const gtl::FlatSet<string>* const kValuePreservingOps =
      CHECK_NOTNULL((new gtl::FlatSet<string>{
          "InvertPermutation",
          "Reverse",
          "ReverseV2",
          "Roll",
          "Transpose",
          "DepthToSpace",
          "SpaceToDepth",
          "BatchToSpace",
          "BatchToSpaceND",
          "SpaceToBatch",
          "SpaceToBatchND",
      }));
  return IsValueAndOrderPreserving(node) ||
         kValuePreservingOps->count(node.op()) > 0;
}

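// Returns true for unary ops applied element-wise to their input, as well as
// ops that preserve values, order, and shape.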
bool IsUnaryElementWise(const NodeDef& node) {
  static const gtl::FlatSet<string>* const kElementWiseOps =
      CHECK_NOTNULL((new gtl::FlatSet<string>{
          "Abs",
          "Acos",
          "Acosh",
          "Asin",
          "Asinh",
          "Atan",
          "Atanh",
          "Ceil",
          "ComplexAbs",
          "Conj",
          "Cos",
          "Cosh",
          "Digamma",
          "Elu",
          "Erf",
          "Erfc",
          "Exp",
          "Expm1",
          "Floor",
          "Inv",
          "Invert",
          "IsInf",
          "IsNan",
          "IsFinite",
          "Lgamma",
          "Log",
          "Log1p",
          "LogicalNot",
          "Neg",
          "Reciprocal",
          "Relu",
          "Relu6",
          "Rint",
          "Round",
          "Selu",
          "Rsqrt",
          "Sigmoid",
          "Sign",
          "Sin",
          "Sinh",
          "Softplus",
          "Softsign",
          "Sqrt",
          "Square",
          "Tan",
          "Tanh",
      }));
  return kElementWiseOps->count(node.op()) > 0 ||
         IsValueAndOrderAndShapePreserving(node);
}

bool HasOpDef(const NodeDef& node) {
  const OpDef* op_def = nullptr;
  return OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok();
}

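// A node is treated as idempotent if re-running it changes nothing: it must
// preserve values, order, and shape, be free of side effects, and not modify
// execution-frame info.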
bool IsIdempotent(const NodeDef& node) {
  return IsValueAndOrderAndShapePreserving(node) && IsFreeOfSideEffect(node) &&
         !ModifiesFrameInfo(node);
}

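// Returns true for ops known never to forward an input buffer to their
// output, i.e. they always allocate fresh output tensors. Ops whose name
// contains "Segment" or starts with "Quantize" are also treated as
// non-forwarding.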
bool NeverForwardsInputs(const NodeDef& node) {
  static const gtl::FlatSet<string>* const kNonForwardingOps = CHECK_NOTNULL(
      (new gtl::FlatSet<string>{"ArgMax",
                                "ArgMin",
                                "AudioSpectrogram",
                                "BatchMatMul",
                                "BatchToSpace",
                                "BatchToSpaceND",
                                "Bincount",
                                "BroadcastArgs",
                                "BroadcastGradientArgs",
                                "CTCBeamSearchDecoder",
                                "CTCGreedyDecoder",
                                "CTCLoss",
                                "ComplexAbs",
                                "Concat",
                                "ConcatOffset",
                                "ConcatV2",
                                "Copy",
                                "CopyHost",
                                "Cross",
                                "CudnnRNN",
                                "CudnnRNNBackprop",
                                "CudnnRNNBackpropV2",
                                "CudnnRNNBackpropV3",
                                "CudnnRNNCanonicalToParams",
                                "CudnnRNNParamsSize",
                                "CudnnRNNParamsToCanonical",
                                "CudnnRNNV2",
                                "CudnnRNNV3",
                                "Cumsum",
                                "Cumprod",
                                "DebugNanCount",
                                "DebugNumericSummary",
                                "DecodeProtoV2",
                                "DecodeWav",
                                "DeepCopy",
                                "DepthToSpace",
                                "Dequantize",
                                "Diag",
                                "DiagPart",
                                "EditDistance",
                                "Empty",
                                "EncodeProtoV2",
                                "EncodeWav",
                                "ExtractImagePatches",
                                "ExtractVolumePatches",
                                "Fill",
                                "Gather",
                                "GatherNd",
                                "GatherV2",
                                "HistogramFixedWidth",
                                "InvertPermutation",
                                "IsInf",
                                "IsNan",
                                "IsFinite",
                                "LinSpace",
                                "LowerBound",
                                "MatMul",
                                "MatrixDiag",
                                "MatrixDiagPart",
                                "Mfcc",
                                "OneHot",
                                "Pack",
                                "PopulationCount",
                                "Range",
                                "Rank",
                                "ReverseSequence",
                                "Shape",
                                "ShapeN",
                                "Size",
                                "SpaceToBatch",
                                "SpaceToBatchND",
                                "SpaceToDepth",
                                "SparseMatMul",
                                "Split",
                                "SplitV",
                                "Unique",
                                "UniqueV2",
                                "UniqueWithCounts",
                                "UniqueWithCountsV2",
                                "Unpack",
                                "UnravelIndex",
                                "UpperBound",
                                "Where",
                                "CompareAndBitpack",
                                "Requantize",
                                "RequantizationRange",
                                "Bucketize",
                                "AvgPool",
                                "BatchNormWithGlobalNormalization",
                                "FusedBatchNorm",
                                "FusedBatchNormV2",
                                "Conv2D",
                                "RandomUniform",
                                "RandomUniformInt",
                                "RandomStandardNormal",
                                "ParameterizedTruncatedNormal",
                                "TruncatedNormal",
                                "Multinomial",
                                "RandomGamma",
                                "RandomPoisson",
                                "RandomPoissonV2"}));
  const string& op_name = node.op();
  return kNonForwardingOps->count(op_name) > 0 ||
         str_util::StrContains(op_name, "Segment") ||
         str_util::StartsWith(op_name, "Quantize");
}

}  // namespace grappler
}  // namespace tensorflow