1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/grappler/op_types.h" 17 #include "tensorflow/core/framework/attr_value.pb.h" 18 #include "tensorflow/core/framework/op.h" 19 #include "tensorflow/core/framework/types.h" 20 #include "tensorflow/core/grappler/utils.h" 21 #include "tensorflow/core/lib/core/status.h" 22 #include "tensorflow/core/lib/gtl/flatset.h" 23 #include "tensorflow/core/lib/strings/str_util.h" 24 #include "tensorflow/core/platform/logging.h" 25 26 namespace tensorflow { 27 namespace grappler { 28 29 bool IsAdd(const NodeDef& node) { 30 if (node.op() == "AddV2" || node.op() == "Add") { 31 DataType type = node.attr().at("T").type(); 32 return type != DT_STRING; 33 } 34 return false; 35 } 36 37 bool IsAddN(const NodeDef& node) { return node.op() == "AddN"; } 38 39 bool IsAll(const NodeDef& node) { return node.op() == "All"; } 40 41 bool IsAngle(const NodeDef& node) { return node.op() == "Angle"; } 42 43 bool IsAny(const NodeDef& node) { return node.op() == "Any"; } 44 45 bool IsAnyDiv(const NodeDef& node) { 46 return node.op() == "RealDiv" || node.op() == "Div" || 47 node.op() == "FloorDiv" || node.op() == "TruncateDiv"; 48 } 49 50 bool IsAnyMax(const NodeDef& node) { 51 const auto& op = node.op(); 52 return op == "Max" || op == "SegmentMax" || op == "UnsortedSegmentMax"; 53 } 54 55 bool IsAnyMaxPool(const NodeDef& node) { 56 const auto& op = node.op(); 57 return op == "MaxPool" || op == "MaxPoolV2" || op == "MaxPool3D" || 58 op == "MaxPoolWithArgmax" || op == "FractionalMaxPool"; 59 } 60 61 bool IsAnyMin(const NodeDef& node) { 62 const auto& op = node.op(); 63 return op == "Min" || op == "SegmentMin" || op == "UnsortedSegmentMin"; 64 } 65 66 bool IsApproximateEqual(const NodeDef& node) { 67 return node.op() == "ApproximateEqual"; 68 } 69 70 bool IsArgMax(const NodeDef& node) { return node.op() == "ArgMax"; } 71 72 bool IsArgMin(const NodeDef& node) { return node.op() == "ArgMin"; } 73 74 bool IsAvgPoolGrad(const NodeDef& node) { return node.op() == "AvgPoolGrad"; } 75 76 bool IsAssign(const NodeDef& node) { 77 return node.op() == "Assign" || node.op() == "AssignVariableOp"; 78 } 79 80 bool IsAssert(const NodeDef& node) { return node.op() == "Assert"; } 81 82 bool IsAtan2(const NodeDef& node) { return node.op() == "Atan2"; } 83 84 bool IsBetainc(const NodeDef& node) { return node.op() == "Betainc"; } 85 86 bool IsBiasAdd(const NodeDef& node) { 87 return node.op() == "BiasAdd" || node.op() == "BiasAddV1"; 88 } 89 90 bool IsBiasAddGrad(const NodeDef& node) { return node.op() == "BiasAddGrad"; } 91 92 bool IsBitcast(const NodeDef& node) { return node.op() == "Bitcast"; } 93 94 bool IsCast(const NodeDef& node) { return node.op() == "Cast"; } 95 96 bool IsCastLike(const NodeDef& node) { 97 static const gtl::FlatSet<string>* const kCastLikeOps = 98 CHECK_NOTNULL((new gtl::FlatSet<string>{ 99 "Angle", "Bucketize", "Cast", "CompareAndBitpack", "Dequantize", 100 "HistogramFixedWidth", "Imag", "IsFinite", "IsInf", "IsNan", 101 "Quantize", "QuantizeDownAndShrinkRange", "QuantizeV2", 102 "QuantizedInstanceNorm", "QuantizedRelu", "QuantizedRelu6", 103 "QuantizedReluX", "Real", "Requantize"})); 104 return kCastLikeOps->count(node.op()) > 0; 105 } 106 107 bool IsCheckNumerics(const NodeDef& node) { 108 return node.op() == "CheckNumerics"; 109 } 110 111 bool IsCollective(const NodeDef& node) { 112 return node.op() == "CollectiveReduce" || 113 node.op() == "CollectiveBcastSend" || 114 node.op() == "CollectiveBcastRecv"; 115 } 116 117 bool IsComplex(const NodeDef& node) { return node.op() == "Complex"; } 118 119 bool IsComplexAbs(const NodeDef& node) { return node.op() == "ComplexAbs"; } 120 121 bool IsConcat(const NodeDef& node) { 122 return node.op() == "Concat" || node.op() == "ConcatV2"; 123 } 124 125 bool IsConcatOffset(const NodeDef& node) { return node.op() == "ConcatOffset"; } 126 127 bool IsConstant(const NodeDef& node) { return node.op() == "Const"; } 128 129 bool IsConj(const NodeDef& node) { return node.op() == "Conj"; } 130 131 bool IsConjugateTranspose(const NodeDef& node) { 132 return node.op() == "ConjugateTranspose"; 133 } 134 135 bool IsControlFlow(const NodeDef& node) { 136 // clang-format off 137 return node.op() == "ControlTrigger" || 138 node.op() == "Enter" || 139 node.op() == "Exit" || 140 node.op() == "LoopCond" || 141 node.op() == "Merge" || 142 node.op() == "NextIteration" || 143 node.op() == "Switch"; 144 // clang-format on 145 } 146 147 bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; } 148 149 bool IsConv2DBackpropFilter(const NodeDef& node) { 150 return node.op() == "Conv2DBackpropFilter"; 151 } 152 153 bool IsConv2DBackpropInput(const NodeDef& node) { 154 return node.op() == "Conv2DBackpropInput"; 155 } 156 157 bool IsConv3D(const NodeDef& node) { return node.op() == "Conv3D"; } 158 159 bool IsDepthwiseConv2dNative(const NodeDef& node) { 160 return node.op() == "DepthwiseConv2dNative"; 161 } 162 163 bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node) { 164 return node.op() == "DepthwiseConv2dNativeBackpropFilter"; 165 } 166 167 bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node) { 168 return node.op() == "DepthwiseConv2dNativeBackpropInput"; 169 } 170 171 bool IsDequeueOp(const NodeDef& node) { 172 const auto& op = node.op(); 173 return op == "QueueDequeueManyV2" || op == "QueueDequeueMany" || 174 op == "QueueDequeueV2" || op == "QueueDequeue" || 175 op == "QueueDequeueUpToV2" || op == "QueueDequeueUpTo"; 176 } 177 178 bool IsDiv(const NodeDef& node) { return node.op() == "Div"; } 179 180 // Returns true if node represents a unary elementwise function that is 181 // monotonic. If *is_non_decreasing is true, the function is non-decreasing, 182 // e.g. sqrt, exp. *is_non_decreasing is false, the function is non-increasing, 183 // e.g. inv. 184 bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) { 185 static const gtl::FlatSet<string>* const kMonotonicNonDecreasingOps = 186 CHECK_NOTNULL((new gtl::FlatSet<string>{ 187 "Acosh", "Asin", "Asinh", "Atan", "Atanh", "Ceil", 188 "Elu", "Erf", "Exp", "Expm1", "Floor", "Log", 189 "Log1p", "Relu", "Relu6", "Rint", "Selu", "Sigmoid", 190 "Sign", "Sinh", "Softsign", "Softplus", "Sqrt", "Tanh", 191 })); 192 static const gtl::FlatSet<string>* const kMonotonicNonIncreasingOps = 193 CHECK_NOTNULL((new gtl::FlatSet<string>{"Acos", "Erfc", "Neg", "Rsqrt"})); 194 if (kMonotonicNonDecreasingOps->count(node.op()) > 0) { 195 if (is_non_decreasing) { 196 *is_non_decreasing = true; 197 } 198 return true; 199 } else if (kMonotonicNonIncreasingOps->count(node.op()) > 0) { 200 if (is_non_decreasing) { 201 *is_non_decreasing = false; 202 } 203 return true; 204 } 205 return false; 206 } 207 208 bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad"; } 209 210 bool IsEnter(const NodeDef& node) { 211 const auto& op = node.op(); 212 return op == "Enter" || op == "RefEnter"; 213 } 214 215 bool IsEqual(const NodeDef& node) { return node.op() == "Equal"; } 216 217 bool IsExit(const NodeDef& node) { 218 const auto& op = node.op(); 219 return op == "Exit" || op == "RefExit"; 220 } 221 222 bool IsExp(const NodeDef& node) { return node.op() == "Exp"; } 223 224 bool IsFakeParam(const NodeDef& node) { return node.op() == "FakeParam"; } 225 226 bool IsFill(const NodeDef& node) { return node.op() == "Fill"; } 227 228 bool IsFloorDiv(const NodeDef& node) { return node.op() == "FloorDiv"; } 229 230 bool IsFloorMod(const NodeDef& node) { return node.op() == "FloorMod"; } 231 232 bool IsFusedBatchNorm(const NodeDef& node) { 233 const auto& op = node.op(); 234 return op == "FusedBatchNorm" || op == "FusedBatchNormV2"; 235 } 236 237 bool IsFusedBatchNormGrad(const NodeDef& node) { 238 const auto& op = node.op(); 239 return op == "FusedBatchNormGrad" || op == "FusedBatchNormGradV2"; 240 } 241 242 bool IsGreater(const NodeDef& node) { return node.op() == "Greater"; } 243 244 bool IsGreaterEqual(const NodeDef& node) { return node.op() == "GreaterEqual"; } 245 246 bool IsHostConstant(const NodeDef& node) { return node.op() == "HostConst"; } 247 248 bool IsHistogramSummary(const NodeDef& node) { 249 return node.op() == "HistogramSummary"; 250 } 251 252 bool IsIdentity(const NodeDef& node) { 253 const auto& op = node.op(); 254 return op == "Identity" || op == "RefIdentity"; 255 } 256 257 bool IsIdentityN(const NodeDef& node) { 258 const auto& op = node.op(); 259 return op == "IdentityN"; 260 } 261 262 bool IsIdentityNSingleInput(const NodeDef& node) { 263 return IsIdentityN(node) && node.attr().count("T") != 0 && 264 node.attr().at("T").list().type_size() == 1; 265 } 266 267 bool IsIf(const NodeDef& node) { 268 const auto& op = node.op(); 269 return op == "If" || op == "StatelessIf"; 270 } 271 272 bool IsIgamma(const NodeDef& node) { return node.op() == "Igamma"; } 273 274 bool IsIgammac(const NodeDef& node) { return node.op() == "Igammac"; } 275 276 bool IsImag(const NodeDef& node) { return node.op() == "Imag"; } 277 278 bool IsImmutableConst(const NodeDef& node) { 279 return node.op() == "ImmutableConst"; 280 } 281 282 bool IsInvGrad(const NodeDef& node) { return node.op() == "InvGrad"; } 283 284 bool IsLess(const NodeDef& node) { return node.op() == "Less"; } 285 286 bool IsLessEqual(const NodeDef& node) { return node.op() == "LessEqual"; } 287 288 bool IsLog(const NodeDef& node) { return node.op() == "Log"; } 289 290 bool IsLogicalAnd(const NodeDef& node) { return node.op() == "LogicalAnd"; } 291 292 bool IsLogicalNot(const NodeDef& node) { return node.op() == "LogicalNot"; } 293 294 bool IsLogicalOr(const NodeDef& node) { return node.op() == "LogicalOr"; } 295 296 bool IsMatMul(const NodeDef& node) { 297 const auto& op = node.op(); 298 return op == "MatMul" || op == "BatchMatMul" || op == "SparseMatMul" || 299 IsQuantizedMatMul(node); 300 } 301 302 bool IsMax(const NodeDef& node) { return node.op() == "Max"; } 303 304 bool IsMaximum(const NodeDef& node) { return node.op() == "Maximum"; } 305 306 bool IsMaxPoolGrad(const NodeDef& node) { return node.op() == "MaxPoolGrad"; } 307 308 bool IsMean(const NodeDef& node) { return node.op() == "Mean"; } 309 310 bool IsMerge(const NodeDef& node) { 311 const auto& op = node.op(); 312 return op == "Merge" || op == "RefMerge"; 313 } 314 315 bool IsMin(const NodeDef& node) { return node.op() == "Min"; } 316 317 bool IsMinimum(const NodeDef& node) { return node.op() == "Minimum"; } 318 319 bool IsMirrorPad(const NodeDef& node) { return node.op() == "MirrorPad"; } 320 321 bool IsMirrorPadGrad(const NodeDef& node) { 322 return node.op() == "MirrorPadGrad"; 323 } 324 325 bool IsMod(const NodeDef& node) { return node.op() == "Mod"; } 326 327 bool IsMul(const NodeDef& node) { return node.op() == "Mul"; } 328 329 bool IsNeg(const NodeDef& node) { return node.op() == "Neg"; } 330 331 bool IsNoOp(const NodeDef& node) { return node.op() == "NoOp"; } 332 333 bool IsNotEqual(const NodeDef& node) { return node.op() == "NotEqual"; } 334 335 bool IsNextIteration(const NodeDef& node) { 336 const auto& op = node.op(); 337 return op == "NextIteration" || op == "RefNextIteration"; 338 } 339 340 bool IsOnesLike(const NodeDef& node) { return node.op() == "OnesLike"; } 341 342 bool IsPack(const NodeDef& node) { return node.op() == "Pack"; } 343 344 bool IsPad(const NodeDef& node) { 345 const auto& op = node.op(); 346 return op == "Pad" || op == "PadV2"; 347 } 348 349 bool IsPartitionedCall(const NodeDef& node) { 350 return node.op() == "PartitionedCall"; 351 } 352 353 bool IsPlaceholder(const NodeDef& node) { 354 const auto& op = node.op(); 355 return op == "Placeholder" || op == "PlaceholderV2" || 356 op == "PlaceholderWithDefault"; 357 } 358 359 bool IsPolygamma(const NodeDef& node) { return node.op() == "Polygamma"; } 360 361 bool IsPow(const NodeDef& node) { return node.op() == "Pow"; } 362 363 bool IsPrint(const NodeDef& node) { 364 return node.op() == "Print" || node.op() == "PrintV2"; 365 } 366 367 bool IsProd(const NodeDef& node) { return node.op() == "Prod"; } 368 369 bool IsQuantizedMatMul(const NodeDef& node) { 370 return node.op() == "QuantizedMatMul" || node.op() == "QuantizedMatMulV2"; 371 } 372 373 bool IsQueue(const NodeDef& node) { 374 return str_util::EndsWith(node.op(), "QueueV2"); 375 } 376 377 bool IsRandomShuffle(const NodeDef& node) { 378 return node.op() == "RandomShuffle"; 379 } 380 381 bool IsRank(const NodeDef& node) { return node.op() == "Rank"; } 382 383 bool IsReadVariableOp(const NodeDef& node) { 384 return node.op() == "ReadVariableOp"; 385 } 386 387 bool IsReal(const NodeDef& node) { return node.op() == "Real"; } 388 389 bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; } 390 391 bool IsReciprocalGrad(const NodeDef& node) { 392 return node.op() == "ReciprocalGrad"; 393 } 394 395 bool IsRecv(const NodeDef& node) { 396 return node.op() == "_Recv" || node.op() == "_HostRecv"; 397 } 398 399 bool IsReduction(const NodeDef& node) { 400 const auto& op = node.op(); 401 return op == "Sum" || op == "Prod" || op == "Min" || op == "Max" || 402 op == "Mean" || op == "Any" || op == "All"; 403 } 404 405 bool IsRelu(const NodeDef& node) { return node.op() == "Relu"; } 406 407 bool IsReluGrad(const NodeDef& node) { return node.op() == "ReluGrad"; } 408 409 bool IsRelu6Grad(const NodeDef& node) { return node.op() == "Relu6Grad"; } 410 411 bool IsReshape(const NodeDef& node) { return (node.op() == "Reshape"); } 412 413 bool IsRestore(const NodeDef& node) { 414 return (node.op() == "Restore" || node.op() == "RestoreV2" || 415 node.op() == "RestoreSlice"); 416 } 417 418 bool IsReverse(const NodeDef& node) { 419 return node.op() == "Reverse" || node.op() == "ReverseV2"; 420 } 421 422 bool IsReverseV2(const NodeDef& node) { return node.op() == "ReverseV2"; } 423 424 bool IsRsqrt(const NodeDef& node) { return node.op() == "Rsqrt"; } 425 426 bool IsRsqrtGrad(const NodeDef& node) { return node.op() == "RsqrtGrad"; } 427 428 bool IsSelect(const NodeDef& node) { return node.op() == "Select"; } 429 430 bool IsSeluGrad(const NodeDef& node) { return node.op() == "SeluGrad"; } 431 432 bool IsSend(const NodeDef& node) { 433 return node.op() == "_Send" || node.op() == "_HostSend"; 434 } 435 436 bool IsShape(const NodeDef& node) { return node.op() == "Shape"; } 437 438 bool IsShapeN(const NodeDef& node) { return node.op() == "ShapeN"; } 439 440 bool IsShuffle(const NodeDef& node) { return node.op() == "Shuffle"; } 441 442 bool IsSigmoidGrad(const NodeDef& node) { return node.op() == "SigmoidGrad"; } 443 444 bool IsSize(const NodeDef& node) { return node.op() == "Size"; } 445 446 bool IsSlice(const NodeDef& node) { return node.op() == "Slice"; } 447 448 bool IsSnapshot(const NodeDef& node) { return node.op() == "Snapshot"; } 449 450 bool IsSoftmax(const NodeDef& node) { return node.op() == "Softmax"; } 451 452 bool IsSoftplusGrad(const NodeDef& node) { return node.op() == "SoftplusGrad"; } 453 454 bool IsSoftsignGrad(const NodeDef& node) { return node.op() == "SoftsignGrad"; } 455 456 bool IsSplit(const NodeDef& node) { return node.op() == "Split"; } 457 458 bool IsSplitV(const NodeDef& node) { return node.op() == "SplitV"; } 459 460 bool IsSqrt(const NodeDef& node) { return node.op() == "Sqrt"; } 461 462 bool IsSqrtGrad(const NodeDef& node) { return node.op() == "SqrtGrad"; } 463 464 bool IsSquare(const NodeDef& node) { return node.op() == "Square"; } 465 466 bool IsSquaredDifference(const NodeDef& node) { 467 return node.op() == "SquaredDifference"; 468 } 469 470 bool IsSqueeze(const NodeDef& node) { return node.op() == "Squeeze"; } 471 472 bool IsStackOp(const NodeDef& node) { 473 return node.op() == "Stack" || node.op() == "StackV2"; 474 } 475 bool IsStackCloseOp(const NodeDef& node) { 476 return node.op() == "StackClose" || node.op() == "StackCloseV2"; 477 } 478 bool IsStackPushOp(const NodeDef& node) { 479 return node.op() == "StackPush" || node.op() == "StackPushV2"; 480 } 481 bool IsStackPopOp(const NodeDef& node) { 482 return node.op() == "StackPop" || node.op() == "StackPopV2"; 483 } 484 485 bool IsStatefulPartitionedCall(const NodeDef& node) { 486 return node.op() == "StatefulPartitionedCall"; 487 } 488 489 bool IsStopGradient(const NodeDef& node) { 490 const auto& op = node.op(); 491 return op == "StopGradient" || op == "PreventGradient"; 492 } 493 494 bool IsStridedSlice(const NodeDef& node) { return node.op() == "StridedSlice"; } 495 496 bool IsStridedSliceGrad(const NodeDef& node) { 497 return node.op() == "StridedSliceGrad"; 498 } 499 500 bool IsSub(const NodeDef& node) { return node.op() == "Sub"; } 501 502 bool IsSum(const NodeDef& node) { return node.op() == "Sum"; } 503 504 bool IsSwitch(const NodeDef& node) { 505 const auto& op = node.op(); 506 return op == "Switch" || op == "RefSwitch"; 507 } 508 509 bool IsSymbolicGradient(const NodeDef& node) { 510 return node.op() == "SymbolicGradient"; 511 } 512 513 bool IsTanhGrad(const NodeDef& node) { return node.op() == "TanhGrad"; } 514 515 bool IsTensorArray(const NodeDef& node) { 516 static const gtl::FlatSet<string>* const kTensorArrayOps = 517 CHECK_NOTNULL((new gtl::FlatSet<string>{ 518 "TensorArray", 519 "TensorArrayV2", 520 "TensorArrayV3", 521 "TensorArrayGrad", 522 "TensorArrayGradV2", 523 "TensorArrayGradV3", 524 "TensorArrayGradWithShape", 525 "TensorArrayWrite", 526 "TensorArrayWriteV2", 527 "TensorArrayWriteV3", 528 "TensorArrayRead", 529 "TensorArrayReadV2", 530 "TensorArrayReadV3", 531 "TensorArrayConcat", 532 "TensorArrayConcatV2", 533 "TensorArrayConcatV3", 534 "TensorArraySplit", 535 "TensorArraySplitV2", 536 "TensorArraySplitV3", 537 "TensorArraySize", 538 "TensorArraySizeV2", 539 "TensorArraySizeV3", 540 "TensorArrayClose", 541 "TensorArrayCloseV2", 542 "TensorArrayCloseV3", 543 })); 544 return kTensorArrayOps->count(node.op()) > 0; 545 } 546 547 bool IsTile(const NodeDef& node) { return node.op() == "Tile"; } 548 549 bool IsTranspose(const NodeDef& node) { return node.op() == "Transpose"; } 550 551 bool IsTruncateDiv(const NodeDef& node) { return node.op() == "TruncateDiv"; } 552 553 bool IsTruncateMod(const NodeDef& node) { return node.op() == "TruncateMod"; } 554 555 bool IsUnpack(const NodeDef& node) { return node.op() == "Unpack"; } 556 557 bool IsVariable(const NodeDef& node) { 558 const auto& op = node.op(); 559 return op == "Variable" || op == "VariableV2" || op == "AutoReloadVariable" || 560 op == "VarHandleOp" || op == "ReadVariableOp"; 561 } 562 563 bool IsWhile(const NodeDef& node) { 564 const auto& op = node.op(); 565 return op == "While" || op == "StatelessWhile"; 566 } 567 568 bool IsZerosLike(const NodeDef& node) { return node.op() == "ZerosLike"; } 569 570 bool IsZeta(const NodeDef& node) { return node.op() == "Zeta"; } 571 572 namespace { 573 bool GetBoolAttr(const NodeDef& node, const string& name) { 574 return node.attr().count(name) > 0 && node.attr().at(name).b(); 575 } 576 } // namespace 577 578 bool IsPersistent(const NodeDef& node) { 579 return IsConstant(node) || IsVariable(node) || IsHostConstant(node); 580 } 581 582 bool MaybeHasRefInput(const NodeDef& node) { 583 const OpDef* op_def; 584 Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); 585 if (!status.ok()) { 586 return true; 587 } 588 // Nodes such as Assign or AssignAdd modify one of their inputs. 589 for (const auto& input : op_def->input_arg()) { 590 if (input.is_ref()) { 591 return true; 592 } 593 } 594 return false; 595 } 596 597 bool IsDataset(const NodeDef& node) { 598 const string& op = node.op(); 599 // See `GetNodeClassForOp` in core/graph/graph.cc. 600 return op == "IteratorGetNext" || op == "IteratorGetNextSync" || 601 op == "DatasetToSingleElement" || op == "ReduceDataset"; 602 } 603 604 bool IsStateful(const NodeDef node, const OpRegistryInterface* op_registry) { 605 const OpDef* op_def = nullptr; 606 const string& op_name = node.op(); 607 Status status = op_registry->LookUpOpDef(op_name, &op_def); 608 if (!status.ok()) { 609 LOG(WARNING) << "Failed to lookup OpDef for " << op_name 610 << ". Error: " << status.error_message(); 611 return false; 612 } 613 return op_def->is_stateful(); 614 } 615 616 bool IsStateful(const NodeDef node) { 617 return IsStateful(node, OpRegistry::Global()); 618 } 619 620 bool IsFreeOfSideEffect(const NodeDef& node, 621 const OpRegistryInterface* op_registry) { 622 // Placeholders must be preserved to keep the graph feedable. 623 if (IsPlaceholder(node)) { 624 return false; 625 } 626 const OpDef* op_def = nullptr; 627 const string& op_name = node.op(); 628 Status status = op_registry->LookUpOpDef(op_name, &op_def); 629 if (!status.ok()) { 630 return false; 631 } 632 if (op_def->is_stateful()) { 633 return false; 634 } 635 // Nodes such as Assign or AssignAdd modify one of their inputs. 636 for (const auto& input : op_def->input_arg()) { 637 if (input.is_ref()) { 638 return false; 639 } 640 } 641 // Queue ops modify the queue which is a side effect. 642 if (node.op().find("Queue") != string::npos) { 643 return false; 644 } 645 // Sending a tensor via a network is a side effect. 646 if (IsSend(node)) { 647 return false; 648 } 649 return !ModifiesInputsInPlace(node); 650 } 651 652 bool IsFreeOfSideEffect(const NodeDef& node) { 653 return IsFreeOfSideEffect(node, OpRegistry::Global()); 654 } 655 656 bool ModifiesInputsInPlace(const NodeDef& node) { 657 // Some nodes do in-place updates on regular tensor inputs. 658 string op_name = node.op(); 659 660 // Ops that modify resource variables effectively modify one of their inputs. 661 if (op_name == "AssignVariableOp" || op_name == "AssignAddVariableOp" || 662 op_name == "AssignSubVariableOp" || op_name == "ResourceScatterUpdate" || 663 op_name == "ResourceScatterAdd" || op_name == "ResourceScatterSub" || 664 op_name == "ResourceScatterMul" || op_name == "ResourceScatterDiv" || 665 op_name == "ResourceScatterMin" || op_name == "ResourceScatterMax") { 666 return false; 667 } 668 669 std::transform(op_name.begin(), op_name.end(), op_name.begin(), ::tolower); 670 if (str_util::StrContains(op_name, "inplace")) { 671 return true; 672 } 673 return GetBoolAttr(node, "in_place") || GetBoolAttr(node, "inplace"); 674 } 675 676 bool ModifiesFrameInfo(const NodeDef& node) { 677 return IsEnter(node) || IsExit(node) || IsNextIteration(node); 678 } 679 680 #define OPDEF_PROPERTY_HELPER(PROPERTY_CAP, PROPERTY) \ 681 bool Is##PROPERTY_CAP(const NodeDef& node) { \ 682 if (node.op() == "Add") { \ 683 /* Workaround for "Add" not being marked is_commutative and */ \ 684 /* is_aggregate. (See cl/173915048). */ \ 685 const auto type = GetDataTypeFromAttr(node, "T"); \ 686 return type != DT_INVALID && type != DT_STRING; \ 687 } \ 688 const OpDef* op_def = nullptr; \ 689 Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); \ 690 return status.ok() && op_def->is_##PROPERTY(); \ 691 } 692 693 OPDEF_PROPERTY_HELPER(Aggregate, aggregate) 694 OPDEF_PROPERTY_HELPER(Commutative, commutative) 695 696 bool IsInvolution(const NodeDef& node) { 697 static const gtl::FlatSet<string>* const kInvolutionOps = 698 CHECK_NOTNULL((new gtl::FlatSet<string>{"Conj", "Reciprocal", "Invert", 699 "Neg", "LogicalNot"})); 700 return kInvolutionOps->count(node.op()) > 0; 701 } 702 703 bool IsValueAndOrderAndShapePreserving(const NodeDef& node) { 704 if (NumNonControlInputs(node) == 1 && IsAggregate(node)) { 705 return true; 706 } 707 static const gtl::FlatSet<string>* const kValueAndOrderAndShapePreservingOps = 708 CHECK_NOTNULL((new const gtl::FlatSet<string>{ 709 "CheckNumerics", 710 "DebugGradientIdentity", 711 "DeepCopy" 712 "Enter", 713 "Exit", 714 "PreventGradient", 715 "Print", 716 "Snapshot", 717 "StopGradient", 718 })); 719 return kValueAndOrderAndShapePreservingOps->count(node.op()) > 0 || 720 IsIdentity(node); 721 } 722 723 bool IsValueAndOrderPreserving(const NodeDef& node) { 724 if (NumNonControlInputs(node) == 1 && IsAggregate(node)) { 725 return true; 726 } 727 static const gtl::FlatSet<string>* const kValueAndOrderPreservingOps = 728 CHECK_NOTNULL((new const gtl::FlatSet<string>{ 729 "ExpandDims", 730 "Reshape", 731 "Squeeze", 732 })); 733 return kValueAndOrderPreservingOps->count(node.op()) > 0 || 734 IsValueAndOrderAndShapePreserving(node); 735 } 736 737 bool IsValuePreserving(const NodeDef& node) { 738 static const gtl::FlatSet<string>* const kValuePreservingOps = 739 CHECK_NOTNULL((new gtl::FlatSet<string>{ 740 "InvertPermutation", 741 "Reverse", 742 "ReverseV2", 743 "Roll", 744 "Transpose", 745 "DepthToSpace", 746 "SpaceToDepth", 747 "BatchToSpace", 748 "BatchToSpaceND", 749 "SpaceToBatch", 750 "SpaceToBatchND", 751 })); 752 return IsValueAndOrderPreserving(node) || 753 kValuePreservingOps->count(node.op()) > 0; 754 } 755 756 bool IsUnaryElementWise(const NodeDef& node) { 757 static const gtl::FlatSet<string>* const kElementWiseOps = 758 CHECK_NOTNULL((new gtl::FlatSet<string>{ 759 "Abs", 760 "Acos", 761 "Acosh", 762 "Asin", 763 "Asinh", 764 "Atan", 765 "Atanh", 766 "Ceil", 767 "ComplexAbs", 768 "Conj", 769 "Cos", 770 "Cosh", 771 "Digamma", 772 "Elu" 773 "Erf", 774 "Erfc", 775 "Exp", 776 "Expm1", 777 "Floor", 778 "Inv", 779 "Invert", 780 "Isinf", 781 "Isnan", 782 "Isfinite", 783 "Lgamma", 784 "Log", 785 "Log1p", 786 "LogicalNot", 787 "Neg", 788 "Reciprocal", 789 "Relu", 790 "Relu6", 791 "Rint", 792 "Round", 793 "Selu", 794 "Rsqrt", 795 "Sigmoid", 796 "Sign", 797 "Sin", 798 "SinH", 799 "Softplus", 800 "Softsign", 801 "Sqrt", 802 "Square", 803 "Tan" 804 "Tanh", 805 })); 806 return kElementWiseOps->count(node.op()) > 0 || 807 IsValueAndOrderAndShapePreserving(node); 808 } 809 810 bool HasOpDef(const NodeDef& node) { 811 const OpDef* op_def = nullptr; 812 return OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok(); 813 } 814 815 bool IsIdempotent(const NodeDef& node) { 816 return IsValueAndOrderAndShapePreserving(node) && IsFreeOfSideEffect(node) && 817 !ModifiesFrameInfo(node); 818 } 819 820 bool NeverForwardsInputs(const NodeDef& node) { 821 static const gtl::FlatSet<string>* const kNonForwardingOps = CHECK_NOTNULL( 822 (new gtl::FlatSet<string>{"ArgMax", 823 "ArgMin", 824 "AudioSpectrogram", 825 "BatchMatMul", 826 "BatchToSpace", 827 "BatchToSpaceND", 828 "Bincount", 829 "BroadcastArgs", 830 "BroadcastGradientArgs", 831 "CTCBeamSearchDecoder", 832 "CTCGreedyDecoder", 833 "CTCLoss", 834 "ComplexAbs", 835 "Concat", 836 "ConcatOffset", 837 "ConcatV2", 838 "Copy", 839 "CopyHost", 840 "Cross", 841 "CudnnRNN", 842 "CudnnRNNBackprop", 843 "CudnnRNNBackpropV2", 844 "CudnnRNNBackpropV3", 845 "CudnnRNNCanonicalToParams", 846 "CudnnRNNParamsSize", 847 "CudnnRNNParamsToCanonical", 848 "CudnnRNNV2", 849 "CudnnRNNV3", 850 "CumSum", 851 "CumProd", 852 "DebugNanCount", 853 "DebugNumericSummary", 854 "DecodeProtoV2", 855 "DecodeWav", 856 "DeepCopy", 857 "DepthToSpace", 858 "Dequantize", 859 "Diag", 860 "DiagPart", 861 "EditDistance", 862 "Empty", 863 "EncodeProtoV2", 864 "EncodeWav", 865 "ExtractImagePatches", 866 "ExtractVolumePatches", 867 "Fill", 868 "Gather", 869 "GatherNd", 870 "GatherV2", 871 "HistogramFixedWidth", 872 "InvertPermutation", 873 "IsInf", 874 "IsNan", 875 "Isfinite", 876 "LinSpace", 877 "LowerBound", 878 "MatMul", 879 "MatrixDiag", 880 "MatrixDiagPart", 881 "Mfcc", 882 "OneHot", 883 "Pack", 884 "PopulationCount", 885 "Range", 886 "Rank", 887 "ReverseSequence", 888 "Shape", 889 "ShapeN", 890 "Size", 891 "SpaceToBatch", 892 "SpaceToBatchND", 893 "SpaceToDepth", 894 "SparseMatMul", 895 "Split", 896 "SplitV", 897 "Unique", 898 "UniqueV2", 899 "UniqueWithCounts", 900 "UniqueWithCountsV2", 901 "Unpack", 902 "UnravelIndex", 903 "UpperBound", 904 "Where", 905 "CompareAndBitpack", 906 "Requantize", 907 "RequantizationRange", 908 "Bucketize", 909 "AvgPool", 910 "BatchNormWithGlobalNormalization", 911 "FusedBatchNorm", 912 "FusedBatchNormV2", 913 "Conv2D", 914 "RandomUniform", 915 "RandomUniformInt", 916 "RandomStandardNormal", 917 "ParameterizedTruncatedNormal", 918 "TruncatedNormal", 919 "Multinomial", 920 "RandomGamma", 921 "RandomPoisson", 922 "RandomPoissonV2"})); 923 const string& op_name = node.op(); 924 return kNonForwardingOps->count(op_name) > 0 || 925 str_util::StrContains(op_name, "Segment") || 926 str_util::StartsWith(op_name, "Quantize"); 927 } 928 929 } // namespace grappler 930 } // end namespace tensorflow 931