// tensorflow/core/grappler/op_types.cc
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <unordered_set>
     17 
     18 #include "tensorflow/core/framework/attr_value.pb.h"
     19 #include "tensorflow/core/framework/op.h"
     20 #include "tensorflow/core/framework/types.h"
     21 #include "tensorflow/core/grappler/op_types.h"
     22 #include "tensorflow/core/grappler/utils.h"
     23 #include "tensorflow/core/lib/core/status.h"
     24 
     25 namespace tensorflow {
     26 namespace grappler {
     27 
     28 bool IsAdd(const NodeDef& node) {
     29   if (node.op() == "AddV2" || node.op() == "Add") {
     30     DataType type = node.attr().at("T").type();
     31     return type != DT_STRING;
     32   }
     33   return false;
     34 }
     35 
     36 bool IsAddN(const NodeDef& node) { return node.op() == "AddN"; }
     37 
     38 bool IsAll(const NodeDef& node) { return node.op() == "All"; }
     39 
     40 bool IsAngle(const NodeDef& node) { return node.op() == "Angle"; }
     41 
     42 bool IsAny(const NodeDef& node) { return node.op() == "Any"; }
     43 
     44 bool IsAnyDiv(const NodeDef& node) {
     45   return node.op() == "RealDiv" || node.op() == "Div" ||
     46          node.op() == "FloorDiv" || node.op() == "TruncateDiv";
     47 }
     48 
     49 bool IsApproximateEqual(const NodeDef& node) {
     50   return node.op() == "ApproximateEqual";
     51 }
     52 
     53 bool IsAvgPoolGrad(const NodeDef& node) { return node.op() == "AvgPoolGrad"; }
     54 
     55 bool IsAssert(const NodeDef& node) { return node.op() == "Assert"; }
     56 
     57 bool IsAtan2(const NodeDef& node) { return node.op() == "Atan2"; }
     58 
     59 bool IsBetainc(const NodeDef& node) { return node.op() == "Betainc"; }
     60 
     61 bool IsBiasAdd(const NodeDef& node) {
     62   return node.op() == "BiasAdd" || node.op() == "BiasAddV1";
     63 }
     64 
     65 bool IsBiasAddGrad(const NodeDef& node) { return node.op() == "BiasAddGrad"; }
     66 
     67 bool IsBitcast(const NodeDef& node) { return node.op() == "Bitcast"; }
     68 
     69 bool IsCast(const NodeDef& node) { return node.op() == "Cast"; }
     70 
     71 bool IsComplex(const NodeDef& node) { return node.op() == "Complex"; }
     72 
     73 bool IsComplexAbs(const NodeDef& node) { return node.op() == "ComplexAbs"; }
     74 
     75 bool IsConcatOffset(const NodeDef& node) { return node.op() == "ConcatOffset"; }
     76 
     77 bool IsConstant(const NodeDef& node) { return node.op() == "Const"; }
     78 
     79 bool IsConj(const NodeDef& node) { return node.op() == "Conj"; }
     80 
     81 bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; }
     82 
     83 bool IsConv2DBackpropFilter(const NodeDef& node) {
     84   return node.op() == "Conv2DBackpropFilter";
     85 }
     86 
     87 bool IsConv2DBackpropInput(const NodeDef& node) {
     88   return node.op() == "Conv2DBackpropInput";
     89 }
     90 
     91 bool IsDepthwiseConv2dNative(const NodeDef& node) {
     92   return node.op() == "DepthwiseConv2dNative";
     93 }
     94 
     95 bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node) {
     96   return node.op() == "DepthwiseConv2dNativeBackpropFilter";
     97 }
     98 
     99 bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node) {
    100   return node.op() == "DepthwiseConv2dNativeBackpropInput";
    101 }
    102 
    103 bool IsDequeueOp(const NodeDef& node) {
    104   const auto& op = node.op();
    105   return op == "QueueDequeueManyV2" || op == "QueueDequeueMany" ||
    106          op == "QueueDequeueV2" || op == "QueueDequeue" ||
    107          op == "QueueDequeueUpToV2" || op == "QueueDequeueUpTo";
    108 }
    109 
    110 bool IsDiv(const NodeDef& node) { return node.op() == "Div"; }
    111 
    112 bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad"; }
    113 
    114 bool IsEnter(const NodeDef& node) {
    115   const auto& op = node.op();
    116   return op == "Enter" || op == "RefEnter";
    117 }
    118 
    119 bool IsEqual(const NodeDef& node) { return node.op() == "Equal"; }
    120 
    121 bool IsExit(const NodeDef& node) {
    122   const auto& op = node.op();
    123   return op == "Exit" || op == "RefExit";
    124 }
    125 
    126 bool IsFill(const NodeDef& node) { return node.op() == "Fill"; }
    127 
    128 bool IsFloorDiv(const NodeDef& node) { return node.op() == "FloorDiv"; }
    129 
    130 bool IsFloorMod(const NodeDef& node) { return node.op() == "FloorMod"; }
    131 
    132 bool IsFusedBatchNormGrad(const NodeDef& node) {
    133   const auto& op = node.op();
    134   return op == "FusedBatchNormGrad" || op == "FusedBatchNormGradV2";
    135 }
    136 
    137 bool IsGreater(const NodeDef& node) { return node.op() == "Greater"; }
    138 
    139 bool IsGreaterEqual(const NodeDef& node) { return node.op() == "GreaterEqual"; }
    140 
    141 bool IsHistogramSummary(const NodeDef& node) {
    142   return node.op() == "HistogramSummary";
    143 }
    144 
    145 bool IsIdentity(const NodeDef& node) {
    146   const auto& op = node.op();
    147   return op == "Identity" || op == "RefIdentity";
    148 }
    149 
    150 bool IsIdentityN(const NodeDef& node) {
    151   const auto& op = node.op();
    152   return op == "IdentityN";
    153 }
    154 
    155 bool IsIgamma(const NodeDef& node) { return node.op() == "Igamma"; }
    156 
    157 bool IsIgammac(const NodeDef& node) { return node.op() == "Igammac"; }
    158 
    159 bool IsImag(const NodeDef& node) { return node.op() == "Imag"; }
    160 
    161 bool IsInvGrad(const NodeDef& node) { return node.op() == "InvGrad"; }
    162 
    163 bool IsLess(const NodeDef& node) { return node.op() == "Less"; }
    164 
    165 bool IsLessEqual(const NodeDef& node) { return node.op() == "LessEqual"; }
    166 
    167 bool IsLogicalAnd(const NodeDef& node) { return node.op() == "LogicalAnd"; }
    168 
    169 bool IsLogicalNot(const NodeDef& node) { return node.op() == "LogicalNot"; }
    170 
    171 bool IsLogicalOr(const NodeDef& node) { return node.op() == "LogicalOr"; }
    172 
    173 bool IsMatMul(const NodeDef& node) {
    174   const auto& op = node.op();
    175   return op == "MatMul" || op == "BatchMatMul" || op == "QuantizedMatMul" ||
    176          op == "SparseMatMul";
    177 }
    178 
    179 bool IsMax(const NodeDef& node) { return node.op() == "Max"; }
    180 
    181 bool IsMaximum(const NodeDef& node) { return node.op() == "Maximum"; }
    182 
    183 bool IsMean(const NodeDef& node) { return node.op() == "Mean"; }
    184 
    185 bool IsMerge(const NodeDef& node) {
    186   const auto& op = node.op();
    187   return op == "Merge" || op == "RefMerge";
    188 }
    189 
    190 bool IsMin(const NodeDef& node) { return node.op() == "Min"; }
    191 
    192 bool IsMinimum(const NodeDef& node) { return node.op() == "Minimum"; }
    193 
    194 bool IsMirrorPad(const NodeDef& node) { return node.op() == "MirrorPad"; }
    195 
    196 bool IsMirrorPadGrad(const NodeDef& node) {
    197   return node.op() == "MirrorPadGrad";
    198 }
    199 
    200 bool IsMod(const NodeDef& node) { return node.op() == "Mod"; }
    201 
    202 bool IsMul(const NodeDef& node) { return node.op() == "Mul"; }
    203 
    204 bool IsNoOp(const NodeDef& node) { return node.op() == "NoOp"; }
    205 
    206 bool IsNotEqual(const NodeDef& node) { return node.op() == "NotEqual"; }
    207 
    208 bool IsNextIteration(const NodeDef& node) {
    209   const auto& op = node.op();
    210   return op == "NextIteration" || op == "RefNextIteration";
    211 }
    212 
    213 bool IsPad(const NodeDef& node) {
    214   const auto& op = node.op();
    215   return op == "Pad" || op == "PadV2";
    216 }
    217 
    218 bool IsPlaceholder(const NodeDef& node) {
    219   const auto& op = node.op();
    220   return op == "Placeholder" || op == "PlaceholderV2" ||
    221          op == "PlaceholderWithDefault";
    222 }
    223 
    224 bool IsPolygamma(const NodeDef& node) { return node.op() == "Polygamma"; }
    225 
    226 bool IsPow(const NodeDef& node) { return node.op() == "Pow"; }
    227 
    228 bool IsProd(const NodeDef& node) { return node.op() == "Prod"; }
    229 
    230 bool IsReal(const NodeDef& node) { return node.op() == "Real"; }
    231 
    232 bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; }
    233 
    234 bool IsReciprocalGrad(const NodeDef& node) {
    235   return node.op() == "ReciprocalGrad";
    236 }
    237 
    238 bool IsRecv(const NodeDef& node) {
    239   return node.op() == "_Recv" || node.op() == "_HostRecv";
    240 }
    241 
    242 bool IsReduction(const NodeDef& node) {
    243   const auto& op = node.op();
    244   return op == "Sum" || op == "Prod" || op == "Min" || op == "Max" ||
    245          op == "Mean" || op == "Any" || op == "All";
    246 }
    247 
    248 bool IsReluGrad(const NodeDef& node) { return node.op() == "ReluGrad"; }
    249 
    250 bool IsRelu6Grad(const NodeDef& node) { return node.op() == "Relu6Grad"; }
    251 
    252 bool IsReshape(const NodeDef& node) { return (node.op() == "Reshape"); }
    253 
    254 bool IsRestore(const NodeDef& node) {
    255   return (node.op() == "Restore" || node.op() == "RestoreV2" ||
    256           node.op() == "RestoreSlice");
    257 }
    258 
    259 bool IsReverseV2(const NodeDef& node) { return node.op() == "ReverseV2"; }
    260 
    261 bool IsRsqrtGrad(const NodeDef& node) { return node.op() == "RsqrtGrad"; }
    262 
    263 bool IsSelect(const NodeDef& node) { return node.op() == "Select"; }
    264 
    265 bool IsSeluGrad(const NodeDef& node) { return node.op() == "SeluGrad"; }
    266 
    267 bool IsSend(const NodeDef& node) {
    268   return node.op() == "_Send" || node.op() == "_HostSend";
    269 }
    270 
    271 bool IsShape(const NodeDef& node) { return node.op() == "Shape"; }
    272 
    273 bool IsShapeN(const NodeDef& node) { return node.op() == "ShapeN"; }
    274 
    275 bool IsSigmoidGrad(const NodeDef& node) { return node.op() == "SigmoidGrad"; }
    276 
    277 bool IsSlice(const NodeDef& node) { return node.op() == "Slice"; }
    278 
    279 bool IsSoftplusGrad(const NodeDef& node) { return node.op() == "SoftplusGrad"; }
    280 
    281 bool IsSoftsignGrad(const NodeDef& node) { return node.op() == "SoftsignGrad"; }
    282 
    283 bool IsSplit(const NodeDef& node) { return node.op() == "Split"; }
    284 
    285 bool IsSplitV(const NodeDef& node) { return node.op() == "SplitV"; }
    286 
    287 bool IsSqrtGrad(const NodeDef& node) { return node.op() == "SqrtGrad"; }
    288 
    289 bool IsSquaredDifference(const NodeDef& node) {
    290   return node.op() == "SquaredDifference";
    291 }
    292 
    293 bool IsSqueeze(const NodeDef& node) { return node.op() == "Squeeze"; }
    294 
    295 bool IsStopGradient(const NodeDef& node) {
    296   const auto& op = node.op();
    297   return op == "StopGradient" || op == "PreventGradient";
    298 }
    299 
    300 bool IsStridedSlice(const NodeDef& node) { return node.op() == "StridedSlice"; }
    301 
    302 bool IsStridedSliceGrad(const NodeDef& node) {
    303   return node.op() == "StridedSliceGrad";
    304 }
    305 
    306 bool IsSub(const NodeDef& node) { return node.op() == "Sub"; }
    307 
    308 bool IsSum(const NodeDef& node) { return node.op() == "Sum"; }
    309 
    310 bool IsSwitch(const NodeDef& node) {
    311   const auto& op = node.op();
    312   return op == "Switch" || op == "RefSwitch";
    313 }
    314 
    315 bool IsTanhGrad(const NodeDef& node) { return node.op() == "TanhGrad"; }
    316 
    317 bool IsTile(const NodeDef& node) { return node.op() == "Tile"; }
    318 
    319 bool IsTranspose(const NodeDef& node) { return node.op() == "Transpose"; }
    320 
    321 bool IsTruncateDiv(const NodeDef& node) { return node.op() == "TruncateDiv"; }
    322 
    323 bool IsTruncateMod(const NodeDef& node) { return node.op() == "TruncateMod"; }
    324 
    325 bool IsVariable(const NodeDef& node) {
    326   const auto& op = node.op();
    327   return op == "Variable" || op == "VariableV2" || op == "AutoReloadVariable" ||
    328          op == "VarHandleOp" || op == "ReadVariableOp";
    329 }
    330 
    331 bool IsZeta(const NodeDef& node) { return node.op() == "Zeta"; }
    332 
    333 namespace {
    334 bool GetBoolAttr(const NodeDef& node, const string& name) {
    335   return node.attr().count(name) > 0 && node.attr().at(name).b();
    336 }
    337 }  // namespace
    338 
    339 bool IsPersistent(const NodeDef& node) {
    340   return IsConstant(node) || IsVariable(node);
    341 }
    342 
    343 bool IsFreeOfSideEffect(const NodeDef& node) {
    344   // Placeholders must be preserved to keep the graph feedable.
    345   if (IsPlaceholder(node)) {
    346     return false;
    347   }
    348   const OpDef* op_def = nullptr;
    349   Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
    350   if (!status.ok()) {
    351     return false;
    352   }
    353   if (op_def->is_stateful()) {
    354     return false;
    355   }
    356   // Nodes such as Assign or AssignAdd modify one of their inputs.
    357   for (const auto& input : op_def->input_arg()) {
    358     if (input.is_ref()) {
    359       return false;
    360     }
    361   }
    362   // Some nodes do in-place updates on regular tensor inputs.
    363   if (GetBoolAttr(node, "in_place") || GetBoolAttr(node, "inplace")) {
    364     return false;
    365   }
    366   return true;
    367 }
    368 
    369 bool ModifiesFrameInfo(const NodeDef& node) {
    370   return IsEnter(node) || IsExit(node) || IsNextIteration(node);
    371 }
    372 
// Expands to a predicate Is<PROPERTY_CAP>(node) that returns true iff the
// OpDef registered for node.op() has its is_<PROPERTY>() bit set. "Add" is
// special-cased because its OpDef is not marked is_commutative/is_aggregate
// even though it behaves that way for every non-string type; DT_INVALID
// (missing "T" attribute) is rejected as well.
#define OPDEF_PROPERTY_HELPER(PROPERTY_CAP, PROPERTY)                      \
  bool Is##PROPERTY_CAP(const NodeDef& node) {                             \
    if (node.op() == "Add") {                                              \
      /* Workaround for "Add" not being marked is_commutative and */       \
      /* is_aggregate. (See cl/173915048). */                              \
      const auto type = GetDataTypeFromAttr(node, "T");                    \
      return type != DT_INVALID && type != DT_STRING;                      \
    }                                                                      \
    const OpDef* op_def = nullptr;                                         \
    Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); \
    return status.ok() && op_def->is_##PROPERTY();                         \
  }

// Defines IsAggregate(node) and IsCommutative(node).
OPDEF_PROPERTY_HELPER(Aggregate, aggregate)
OPDEF_PROPERTY_HELPER(Commutative, commutative)
    388 
    389 bool IsInvolution(const NodeDef& node) {
    390   const std::unordered_set<string> involution_ops{
    391       "Conj", "Reciprocal", "Invert", "Neg", "LogicalNot"};
    392   return involution_ops.count(node.op()) > 0;
    393 }
    394 
    395 bool IsValuePreserving(const NodeDef& node) {
    396   if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
    397     return true;
    398   }
    399   const std::unordered_set<string> value_preserving_ops{
    400       "Transpose",  "Reshape",      "Identity",        "InvertPermutation",
    401       "Reverse",    "StopGradient", "PreventGradient", "CheckNumerics",
    402       "ExpandDims", "Squeeze"};
    403   return value_preserving_ops.count(node.op()) > 0;
    404 }
    405 
    406 bool HasOpDef(const NodeDef& node) {
    407   const OpDef* op_def = nullptr;
    408   return OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok();
    409 }
    410 
    411 }  // namespace grappler
    412 }  // end namespace tensorflow
    413