Home | History | Annotate | Download | only in jit
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // ALGORITHM OVERVIEW
     17 // ==================
     18 //
     19 // An XLA cluster hoists all resource reads to be beginning of the cluster
     20 // execution and all the resource writes to the end.  This means it cannot
     21 // enforce arbitrary ordering dependencies (via control or data edges) between
     22 // resource operations.  Since all resource reads happen before all resource
     23 // writes, edges constraining resource reads to happen before resource writes
     24 // are fine, but all other kinds of edges are problematic.  This analysis
     25 // computes the set of pairs of resource operations that cannot be put in the
     26 // same cluster because XLA cannot respect the dependencies between them in the
     27 // TensorFlow program.
     28 //
     29 // TODO(b/112856632): We can, in theory, support Read->Read and Write->Write
     30 // dependencies.
     31 //
     32 // Specifically the result computed by this analysis contains the edge {W, R}
     33 // iff all of these hold true:
     34 //
     35 //   - In the graph (g - {edges from NextIteration to Merge}) there is a path
     36 //     from W to R.
     37 //   - IsEdgeSafe(W, R) == False [defined below]
     38 //   - W != R (note: some resource operations both read from and write to
     39 //     resource variables).
     40 //
     41 // The result is incorrect around loops because we ignore edges from
     42 // NextIteration to Merge.  For instance, in:
     43 //
     44 // Init -----> Merge <-------+
     45 //               |           |
     46 //               v           |
     47 //             Read          |
     48 //               |           |
     49 //               v           |
     50 //             Write         |
     51 //               |           |
     52 //               v           |
     53 //           NextIteration --+
     54 //
     55 // we won't put (Read, Write) in the returned set.  This is fine if
     56 // auto-clustering can only cluster the Read->Write edge, but it is a problem if
     57 // it clusters the Write->NextIteration->Merge->Read edges instead.  So we rely
     58 // on auto-clustering to not cluster NextIteration->Merge edges.  The same
     59 // problem is present for the functional version of the loop above and we also
     60 // rely on auto-clustering not clustering functional while loops containing
     61 // resource operations.
     62 //
     63 // One way to think about this is that we only care about cases where two nodes,
     64 // A and B, would normally have been put in the same cluster but cannot legally
     65 // be in the same cluster because of resourcevar-dependencies.  If A and B would
     66 // normally have been put in the same cluster then all paths between A and B
     67 // would have to be clusterable (otherwise we'd have introduced a cycle).  Ergo
     68 // there could not have been a NextIteration->Merge edge between A and B since
     69 // we don't cluster these edges.
     70 //
     71 // IMPLEMENTATION
     72 // --------------
     73 //
     74 // We traverse the graph minus backedges in reverse post order, mapping each
     75 // node to the set of resource operation reaching that node.  Since we visit
     76 // producers before consumers, we can construct the set of reaching operations
     77 // by taking the union of the operations reaching the input nodes.  These
     78 // "reaching resource operations" can then be used to create the pairs of
     79 // incompatible nodes using `IsEdgeSafe`.
     80 
     81 #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
     82 
     83 #include "absl/container/flat_hash_set.h"
     84 #include "absl/memory/memory.h"
     85 #include "absl/strings/str_join.h"
     86 #include "absl/types/optional.h"
     87 #include "tensorflow/compiler/tf2xla/resource_operation_table.h"
     88 #include "tensorflow/core/framework/node_def.pb.h"
     89 #include "tensorflow/core/graph/algorithm.h"
     90 #include "tensorflow/core/graph/tensor_id.h"
     91 #include "tensorflow/core/lib/hash/hash.h"
     92 #include "tensorflow/core/util/ptr_util.h"
     93 
     94 namespace tensorflow {
     95 namespace {
     96 // Returns true if `n` may call a function.
     97 Status MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def,
     98                        bool* out_result) {
     99   if (flib_def->Contains(n.type_string())) {
    100     *out_result = true;
    101   } else {
    102     *out_result =
    103         std::any_of(n.def().attr().begin(), n.def().attr().end(),
    104                     [](const std::pair<string, AttrValue>& name_attr_pair) {
    105                       return name_attr_pair.second.has_func();
    106                     });
    107   }
    108 
    109   return Status::OK();
    110 }
    111 
    112 // Maps `n` to the XlaResourceOpKind corresponding to its operation.  If `n` is
    113 // not a resource operation recognized by XLA then sets `out_resource_op_kind`
    114 // to nullopt.
    115 Status XlaResourceOpKindForNode(
    116     const Node& n, const FunctionLibraryDefinition* flib_def,
    117     const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
    118     absl::optional<XlaResourceOpKind>* out_resource_op_kind) {
    119   bool should_ignore = false;
    120   if (resource_ops_to_ignore) {
    121     TF_RETURN_IF_ERROR(resource_ops_to_ignore(n, &should_ignore));
    122   }
    123   if (should_ignore) {
    124     *out_resource_op_kind = absl::nullopt;
    125     return Status::OK();
    126   }
    127 
    128   const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n.type_string());
    129   if (op_info) {
    130     *out_resource_op_kind = op_info->kind();
    131     return Status::OK();
    132   }
    133 
    134   // We conservatively assume that functions will both read and write resource
    135   // variables.  In the future we may consider doing some form of
    136   // inter-procedural analysis.
    137   bool may_call_function;
    138   TF_RETURN_IF_ERROR(MayCallFunction(n, flib_def, &may_call_function));
    139   if (may_call_function) {
    140     *out_resource_op_kind = XlaResourceOpKind::kReadWrite;
    141   } else {
    142     *out_resource_op_kind = absl::nullopt;
    143   }
    144 
    145   return Status::OK();
    146 }
    147 
    148 // Returns true if a control or data dependence from a TensorFlow operation of
    149 // resource op kind `from` to a TensorFlow operation of resource op kind `to`
    150 // can be represented by an XLA cluster and needs no special handling around
    151 // auto-jit.
    152 bool IsEdgeSafe(XlaResourceOpKind from, XlaResourceOpKind to) {
    153   // XLA clusters force all reads to happen before all writes.  Moreover the set
    154   // of reads are executed as one atomic operation, and the set of writes are as
    155   // another atomic operation.  This means we can faithfully represent the
    156   // following edges: Read->*, *->Write.
    157 
    158   return from == XlaResourceOpKind::kRead || to == XlaResourceOpKind::kWrite;
    159 }
    160 
    161 using ResourceOp = std::pair<int, XlaResourceOpKind>;
    162 
    163 string ResourceOpToString(const ResourceOp& resource_op) {
    164   return absl::StrCat(
    165       resource_op.first, ": ",
    166       XlaResourceOpInfo::XlaResourceOpKindToString(resource_op.second));
    167 }
    168 
// A copy-on-write set used to store the set of ResourceOps reaching a node in a
// TensorFlow graph.
//
// Multiple ResourceOpSet instances may alias a single backing hash set; a set
// whose storage is aliased by another set is "frozen" and must not be mutated
// again, since the aliasing sets expect its contents to stay stable.
//
// TODO(sanjoy): It may be useful to pull this out into its own header at some
// point.
class ResourceOpSet {
 private:
  using Impl = absl::flat_hash_set<ResourceOp>;

 public:
  ResourceOpSet() = default;

  // Adds all ResourceOp s in `other` to this set.
  void Add(const ResourceOpSet& other) {
    CHECK(!frozen_);
    if (other.impl_ == impl_) {
      // Both sets already view the same storage (or both are empty): there is
      // nothing to merge, but `other` must be frozen because its contents are
      // now observed through `this`.
      other.frozen_ = true;
      return;
    }

    if (!impl_) {
      // This set is empty: alias `other`'s storage instead of copying it, and
      // freeze `other` so the shared storage cannot change underneath us.
      other.frozen_ = true;
      impl_ = other.impl_;
      return;
    }

    // General case: insert element-wise.  The first insertion that actually
    // changes the set triggers a copy of any aliased storage (see Add below).
    for (ResourceOp resource_op : other) {
      Add(resource_op);
    }
  }

  void Add(const ResourceOp& resource_op) {
    CHECK(!frozen_);
    if (!IsCopy() && Contains(resource_op)) {
      // We can avoid the copy if the item we want to insert already exists.
      return;
    }

    EnsureIsCopied();
    impl_->insert(resource_op);
  }

  Impl::const_iterator begin() const {
    // Empty sets have no backing storage; iterate a shared empty instance.
    return impl_ ? impl_->begin() : GetEmptyImpl()->begin();
  }

  Impl::const_iterator end() const {
    return impl_ ? impl_->end() : GetEmptyImpl()->end();
  }

  bool Contains(const ResourceOp& resource_op) const {
    return impl_ != nullptr && impl_->count(resource_op);
  }

 private:
  // True iff this set owns its backing storage (i.e. it is safe to mutate
  // impl_ without affecting other ResourceOpSet instances).
  bool IsCopy() const { return storage_ != nullptr; }

  // Makes this set own its storage, copying the (possibly aliased) current
  // contents into a freshly allocated Impl if needed.
  void EnsureIsCopied() {
    if (storage_ == nullptr) {
      storage_ = absl::make_unique<Impl>();
      for (ResourceOp op : *this) {
        storage_->insert(op);
      }
      impl_ = storage_.get();
    }
  }

  // Shared, immutable empty set used by begin()/end() when impl_ is null.
  static Impl* GetEmptyImpl() {
    static Impl* empty_impl = new Impl;
    return empty_impl;
  }

  // Points either at storage_ (owned) or at another set's storage (aliased);
  // null means "empty".
  Impl* impl_ = nullptr;
  std::unique_ptr<Impl> storage_;

  // frozen_ is true if there is another set pointing to this set's impl_.  We
  // can no longer add elements to this set in that case since the sets pointing
  // to this set expect the contents of this set to be stable.
  mutable bool frozen_ = false;

  TF_DISALLOW_COPY_AND_ASSIGN(ResourceOpSet);
};
    251 
    252 string ResourceOpSetToString(const ResourceOpSet& resource_op_set) {
    253   std::vector<string> elements_debug_string;
    254   std::transform(resource_op_set.begin(), resource_op_set.end(),
    255                  std::back_inserter(elements_debug_string), ResourceOpToString);
    256   return absl::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}");
    257 }
    258 
    259 string NodeToString(const Node& n, XlaResourceOpKind resource_op_kind) {
    260   return absl::StrCat(
    261       "[", n.name(), ": ", n.type_string(), "(",
    262       XlaResourceOpInfo::XlaResourceOpKindToString(resource_op_kind), ")", "]");
    263 }
    264 }  // namespace
    265 
// Computes all pairs of resource operation node ids (W, R) that must not be
// placed in the same XLA cluster because the W->R dependency cannot be
// honored by XLA (see the file comment for the full algorithm).  `result`
// must be empty on entry; on success it is sorted and duplicate-free.
Status ComputeIncompatibleResourceOperationPairs(
    const Graph& g, const FunctionLibraryDefinition* flib_def,
    const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
    std::vector<std::pair<int, int>>* result) {
  CHECK(result->empty());

  // Reverse post-order over the graph minus NextIteration back-edges
  // guarantees we visit producers before consumers, so each node's reaching
  // set can be built from its (already computed) inputs.  The stable
  // comparator makes the output deterministic across runs.
  std::vector<Node*> rpo;
  GetReversePostOrder(g, &rpo, /*stable_comparator=*/NodeComparatorName(),
                      /*edge_filter=*/[](const Edge& edge) {
                        return !edge.src()->IsNextIteration();
                      });

  // Indexed by node id: the set of resource operations reaching each node.
  auto resource_op_set_for_node =
      absl::make_unique<ResourceOpSet[]>(g.num_node_ids());

  const bool vlog = VLOG_IS_ON(2);

  for (Node* n : rpo) {
    // nullopt means `n` is not a resource operation (or is ignored).
    absl::optional<XlaResourceOpKind> op_kind;
    TF_RETURN_IF_ERROR(XlaResourceOpKindForNode(
        *n, flib_def, resource_ops_to_ignore, &op_kind));

    ResourceOpSet* resource_op_set = &resource_op_set_for_node[n->id()];

    // Merge the reaching resource operations for all the incoming edges to
    // create the set of all possible resource ops reaching `n`.
    for (const Edge* e : n->in_edges()) {
      if (n->IsMerge() && e->src()->IsNextIteration()) {
        // Ignore back-edges (see file comment).
        continue;
      }

      const ResourceOpSet& incoming_op_set =
          resource_op_set_for_node[e->src()->id()];
      resource_op_set->Add(incoming_op_set);
    }

    // Add to the "incompatible resource ops" set if necessary.
    if (op_kind) {
      for (ResourceOp incoming_op : *resource_op_set) {
        // Edges XLA can faithfully represent (Read->*, *->Write) are safe.
        if (IsEdgeSafe(incoming_op.second, *op_kind)) {
          continue;
        }

        if (vlog) {
          VLOG(2) << "Unsafe edge: "
                  << NodeToString(*g.FindNodeId(incoming_op.first),
                                  incoming_op.second)
                  << " -> " << NodeToString(*n, *op_kind);
        }
        result->push_back({incoming_op.first, n->id()});
      }

      // `n` itself is a resource op, so it now reaches everything downstream.
      resource_op_set->Add({n->id(), *op_kind});
    }

    if (vlog) {
      VLOG(3) << n->name() << " -> " << ResourceOpSetToString(*resource_op_set);
    }
  }

  // Deterministic, sorted output; by construction each pair is emitted at
  // most once, which the std::unique check asserts (it only moves elements
  // when adjacent duplicates exist, so a passing check leaves `result`
  // untouched).
  std::sort(result->begin(), result->end());
  CHECK(std::unique(result->begin(), result->end()) == result->end());

  return Status::OK();
}
    332 }  // namespace tensorflow
    333