1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <unordered_set> 17 18 #include "tensorflow/core/framework/attr_value.pb.h" 19 #include "tensorflow/core/framework/op.h" 20 #include "tensorflow/core/framework/types.h" 21 #include "tensorflow/core/grappler/op_types.h" 22 #include "tensorflow/core/grappler/utils.h" 23 #include "tensorflow/core/lib/core/status.h" 24 25 namespace tensorflow { 26 namespace grappler { 27 28 bool IsAdd(const NodeDef& node) { 29 if (node.op() == "AddV2" || node.op() == "Add") { 30 DataType type = node.attr().at("T").type(); 31 return type != DT_STRING; 32 } 33 return false; 34 } 35 36 bool IsAddN(const NodeDef& node) { return node.op() == "AddN"; } 37 38 bool IsAll(const NodeDef& node) { return node.op() == "All"; } 39 40 bool IsAngle(const NodeDef& node) { return node.op() == "Angle"; } 41 42 bool IsAny(const NodeDef& node) { return node.op() == "Any"; } 43 44 bool IsAnyDiv(const NodeDef& node) { 45 return node.op() == "RealDiv" || node.op() == "Div" || 46 node.op() == "FloorDiv" || node.op() == "TruncateDiv"; 47 } 48 49 bool IsApproximateEqual(const NodeDef& node) { 50 return node.op() == "ApproximateEqual"; 51 } 52 53 bool IsAvgPoolGrad(const NodeDef& node) { return node.op() == "AvgPoolGrad"; } 54 55 bool IsAssert(const NodeDef& node) { return node.op() == "Assert"; } 56 57 bool IsAtan2(const NodeDef& node) { return node.op() == "Atan2"; } 58 59 bool IsBetainc(const NodeDef& node) { return node.op() == "Betainc"; } 60 61 bool IsBiasAdd(const NodeDef& node) { 62 return node.op() == "BiasAdd" || node.op() == "BiasAddV1"; 63 } 64 65 bool IsBiasAddGrad(const NodeDef& node) { return node.op() == "BiasAddGrad"; } 66 67 bool IsBitcast(const NodeDef& node) { return node.op() == "Bitcast"; } 68 69 bool IsCast(const NodeDef& node) { return node.op() == "Cast"; } 70 71 bool IsComplex(const NodeDef& node) { return node.op() == "Complex"; } 72 73 bool IsComplexAbs(const NodeDef& node) { return node.op() == "ComplexAbs"; } 74 75 bool IsConcatOffset(const NodeDef& node) { return node.op() == "ConcatOffset"; } 76 77 bool IsConstant(const NodeDef& node) { return node.op() == "Const"; } 78 79 bool IsConj(const NodeDef& node) { return node.op() == "Conj"; } 80 81 bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; } 82 83 bool IsConv2DBackpropFilter(const NodeDef& node) { 84 return node.op() == "Conv2DBackpropFilter"; 85 } 86 87 bool IsConv2DBackpropInput(const NodeDef& node) { 88 return node.op() == "Conv2DBackpropInput"; 89 } 90 91 bool IsDepthwiseConv2dNative(const NodeDef& node) { 92 return node.op() == "DepthwiseConv2dNative"; 93 } 94 95 bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node) { 96 return node.op() == "DepthwiseConv2dNativeBackpropFilter"; 97 } 98 99 bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node) { 100 return node.op() == "DepthwiseConv2dNativeBackpropInput"; 101 } 102 103 bool IsDequeueOp(const NodeDef& node) { 104 const auto& op = node.op(); 105 return op == "QueueDequeueManyV2" || op == "QueueDequeueMany" || 106 op == "QueueDequeueV2" || op == "QueueDequeue" || 107 op == "QueueDequeueUpToV2" || op == "QueueDequeueUpTo"; 108 } 109 110 bool IsDiv(const NodeDef& node) { return node.op() == "Div"; } 111 112 bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad"; } 113 114 bool IsEnter(const NodeDef& node) { 115 const auto& op = node.op(); 116 return op == "Enter" || op == "RefEnter"; 117 } 118 119 bool IsEqual(const NodeDef& node) { return node.op() == "Equal"; } 120 121 bool IsExit(const NodeDef& node) { 122 const auto& op = node.op(); 123 return op == "Exit" || op == "RefExit"; 124 } 125 126 bool IsFill(const NodeDef& node) { return node.op() == "Fill"; } 127 128 bool IsFloorDiv(const NodeDef& node) { return node.op() == "FloorDiv"; } 129 130 bool IsFloorMod(const NodeDef& node) { return node.op() == "FloorMod"; } 131 132 bool IsFusedBatchNormGrad(const NodeDef& node) { 133 const auto& op = node.op(); 134 return op == "FusedBatchNormGrad" || op == "FusedBatchNormGradV2"; 135 } 136 137 bool IsGreater(const NodeDef& node) { return node.op() == "Greater"; } 138 139 bool IsGreaterEqual(const NodeDef& node) { return node.op() == "GreaterEqual"; } 140 141 bool IsHistogramSummary(const NodeDef& node) { 142 return node.op() == "HistogramSummary"; 143 } 144 145 bool IsIdentity(const NodeDef& node) { 146 const auto& op = node.op(); 147 return op == "Identity" || op == "RefIdentity"; 148 } 149 150 bool IsIdentityN(const NodeDef& node) { 151 const auto& op = node.op(); 152 return op == "IdentityN"; 153 } 154 155 bool IsIgamma(const NodeDef& node) { return node.op() == "Igamma"; } 156 157 bool IsIgammac(const NodeDef& node) { return node.op() == "Igammac"; } 158 159 bool IsImag(const NodeDef& node) { return node.op() == "Imag"; } 160 161 bool IsInvGrad(const NodeDef& node) { return node.op() == "InvGrad"; } 162 163 bool IsLess(const NodeDef& node) { return node.op() == "Less"; } 164 165 bool IsLessEqual(const NodeDef& node) { return node.op() == "LessEqual"; } 166 167 bool IsLogicalAnd(const NodeDef& node) { return node.op() == "LogicalAnd"; } 168 169 bool IsLogicalNot(const NodeDef& node) { return node.op() == "LogicalNot"; } 170 171 bool IsLogicalOr(const NodeDef& node) { return node.op() == "LogicalOr"; } 172 173 bool IsMatMul(const NodeDef& node) { 174 const auto& op = node.op(); 175 return op == "MatMul" || op == "BatchMatMul" || op == "QuantizedMatMul" || 176 op == "SparseMatMul"; 177 } 178 179 bool IsMax(const NodeDef& node) { return node.op() == "Max"; } 180 181 bool IsMaximum(const NodeDef& node) { return node.op() == "Maximum"; } 182 183 bool IsMean(const NodeDef& node) { return node.op() == "Mean"; } 184 185 bool IsMerge(const NodeDef& node) { 186 const auto& op = node.op(); 187 return op == "Merge" || op == "RefMerge"; 188 } 189 190 bool IsMin(const NodeDef& node) { return node.op() == "Min"; } 191 192 bool IsMinimum(const NodeDef& node) { return node.op() == "Minimum"; } 193 194 bool IsMirrorPad(const NodeDef& node) { return node.op() == "MirrorPad"; } 195 196 bool IsMirrorPadGrad(const NodeDef& node) { 197 return node.op() == "MirrorPadGrad"; 198 } 199 200 bool IsMod(const NodeDef& node) { return node.op() == "Mod"; } 201 202 bool IsMul(const NodeDef& node) { return node.op() == "Mul"; } 203 204 bool IsNoOp(const NodeDef& node) { return node.op() == "NoOp"; } 205 206 bool IsNotEqual(const NodeDef& node) { return node.op() == "NotEqual"; } 207 208 bool IsNextIteration(const NodeDef& node) { 209 const auto& op = node.op(); 210 return op == "NextIteration" || op == "RefNextIteration"; 211 } 212 213 bool IsPad(const NodeDef& node) { 214 const auto& op = node.op(); 215 return op == "Pad" || op == "PadV2"; 216 } 217 218 bool IsPlaceholder(const NodeDef& node) { 219 const auto& op = node.op(); 220 return op == "Placeholder" || op == "PlaceholderV2" || 221 op == "PlaceholderWithDefault"; 222 } 223 224 bool IsPolygamma(const NodeDef& node) { return node.op() == "Polygamma"; } 225 226 bool IsPow(const NodeDef& node) { return node.op() == "Pow"; } 227 228 bool IsProd(const NodeDef& node) { return node.op() == "Prod"; } 229 230 bool IsReal(const NodeDef& node) { return node.op() == "Real"; } 231 232 bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; } 233 234 bool IsReciprocalGrad(const NodeDef& node) { 235 return node.op() == "ReciprocalGrad"; 236 } 237 238 bool IsRecv(const NodeDef& node) { 239 return node.op() == "_Recv" || node.op() == "_HostRecv"; 240 } 241 242 bool IsReduction(const NodeDef& node) { 243 const auto& op = node.op(); 244 return op == "Sum" || op == "Prod" || op == "Min" || op == "Max" || 245 op == "Mean" || op == "Any" || op == "All"; 246 } 247 248 bool IsReluGrad(const NodeDef& node) { return node.op() == "ReluGrad"; } 249 250 bool IsRelu6Grad(const NodeDef& node) { return node.op() == "Relu6Grad"; } 251 252 bool IsReshape(const NodeDef& node) { return (node.op() == "Reshape"); } 253 254 bool IsRestore(const NodeDef& node) { 255 return (node.op() == "Restore" || node.op() == "RestoreV2" || 256 node.op() == "RestoreSlice"); 257 } 258 259 bool IsReverseV2(const NodeDef& node) { return node.op() == "ReverseV2"; } 260 261 bool IsRsqrtGrad(const NodeDef& node) { return node.op() == "RsqrtGrad"; } 262 263 bool IsSelect(const NodeDef& node) { return node.op() == "Select"; } 264 265 bool IsSeluGrad(const NodeDef& node) { return node.op() == "SeluGrad"; } 266 267 bool IsSend(const NodeDef& node) { 268 return node.op() == "_Send" || node.op() == "_HostSend"; 269 } 270 271 bool IsShape(const NodeDef& node) { return node.op() == "Shape"; } 272 273 bool IsShapeN(const NodeDef& node) { return node.op() == "ShapeN"; } 274 275 bool IsSigmoidGrad(const NodeDef& node) { return node.op() == "SigmoidGrad"; } 276 277 bool IsSlice(const NodeDef& node) { return node.op() == "Slice"; } 278 279 bool IsSoftplusGrad(const NodeDef& node) { return node.op() == "SoftplusGrad"; } 280 281 bool IsSoftsignGrad(const NodeDef& node) { return node.op() == "SoftsignGrad"; } 282 283 bool IsSplit(const NodeDef& node) { return node.op() == "Split"; } 284 285 bool IsSplitV(const NodeDef& node) { return node.op() == "SplitV"; } 286 287 bool IsSqrtGrad(const NodeDef& node) { return node.op() == "SqrtGrad"; } 288 289 bool IsSquaredDifference(const NodeDef& node) { 290 return node.op() == "SquaredDifference"; 291 } 292 293 bool IsSqueeze(const NodeDef& node) { return node.op() == "Squeeze"; } 294 295 bool IsStopGradient(const NodeDef& node) { 296 const auto& op = node.op(); 297 return op == "StopGradient" || op == "PreventGradient"; 298 } 299 300 bool IsStridedSlice(const NodeDef& node) { return node.op() == "StridedSlice"; } 301 302 bool IsStridedSliceGrad(const NodeDef& node) { 303 return node.op() == "StridedSliceGrad"; 304 } 305 306 bool IsSub(const NodeDef& node) { return node.op() == "Sub"; } 307 308 bool IsSum(const NodeDef& node) { return node.op() == "Sum"; } 309 310 bool IsSwitch(const NodeDef& node) { 311 const auto& op = node.op(); 312 return op == "Switch" || op == "RefSwitch"; 313 } 314 315 bool IsTanhGrad(const NodeDef& node) { return node.op() == "TanhGrad"; } 316 317 bool IsTile(const NodeDef& node) { return node.op() == "Tile"; } 318 319 bool IsTranspose(const NodeDef& node) { return node.op() == "Transpose"; } 320 321 bool IsTruncateDiv(const NodeDef& node) { return node.op() == "TruncateDiv"; } 322 323 bool IsTruncateMod(const NodeDef& node) { return node.op() == "TruncateMod"; } 324 325 bool IsVariable(const NodeDef& node) { 326 const auto& op = node.op(); 327 return op == "Variable" || op == "VariableV2" || op == "AutoReloadVariable" || 328 op == "VarHandleOp" || op == "ReadVariableOp"; 329 } 330 331 bool IsZeta(const NodeDef& node) { return node.op() == "Zeta"; } 332 333 namespace { 334 bool GetBoolAttr(const NodeDef& node, const string& name) { 335 return node.attr().count(name) > 0 && node.attr().at(name).b(); 336 } 337 } // namespace 338 339 bool IsPersistent(const NodeDef& node) { 340 return IsConstant(node) || IsVariable(node); 341 } 342 343 bool IsFreeOfSideEffect(const NodeDef& node) { 344 // Placeholders must be preserved to keep the graph feedable. 345 if (IsPlaceholder(node)) { 346 return false; 347 } 348 const OpDef* op_def = nullptr; 349 Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); 350 if (!status.ok()) { 351 return false; 352 } 353 if (op_def->is_stateful()) { 354 return false; 355 } 356 // Nodes such as Assign or AssignAdd modify one of their inputs. 357 for (const auto& input : op_def->input_arg()) { 358 if (input.is_ref()) { 359 return false; 360 } 361 } 362 // Some nodes do in-place updates on regular tensor inputs. 363 if (GetBoolAttr(node, "in_place") || GetBoolAttr(node, "inplace")) { 364 return false; 365 } 366 return true; 367 } 368 369 bool ModifiesFrameInfo(const NodeDef& node) { 370 return IsEnter(node) || IsExit(node) || IsNextIteration(node); 371 } 372 373 #define OPDEF_PROPERTY_HELPER(PROPERTY_CAP, PROPERTY) \ 374 bool Is##PROPERTY_CAP(const NodeDef& node) { \ 375 if (node.op() == "Add") { \ 376 /* Workaround for "Add" not being marked is_commutative and */ \ 377 /* is_aggregate. (See cl/173915048). */ \ 378 const auto type = GetDataTypeFromAttr(node, "T"); \ 379 return type != DT_INVALID && type != DT_STRING; \ 380 } \ 381 const OpDef* op_def = nullptr; \ 382 Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); \ 383 return status.ok() && op_def->is_##PROPERTY(); \ 384 } 385 386 OPDEF_PROPERTY_HELPER(Aggregate, aggregate) 387 OPDEF_PROPERTY_HELPER(Commutative, commutative) 388 389 bool IsInvolution(const NodeDef& node) { 390 const std::unordered_set<string> involution_ops{ 391 "Conj", "Reciprocal", "Invert", "Neg", "LogicalNot"}; 392 return involution_ops.count(node.op()) > 0; 393 } 394 395 bool IsValuePreserving(const NodeDef& node) { 396 if (NumNonControlInputs(node) == 1 && IsAggregate(node)) { 397 return true; 398 } 399 const std::unordered_set<string> value_preserving_ops{ 400 "Transpose", "Reshape", "Identity", "InvertPermutation", 401 "Reverse", "StopGradient", "PreventGradient", "CheckNumerics", 402 "ExpandDims", "Squeeze"}; 403 return value_preserving_ops.count(node.op()) > 0; 404 } 405 406 bool HasOpDef(const NodeDef& node) { 407 const OpDef* op_def = nullptr; 408 return OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok(); 409 } 410 411 } // namespace grappler 412 } // end namespace tensorflow 413