/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// ALGORITHM OVERVIEW
// ==================
//
// An XLA cluster hoists all resource reads to the beginning of the cluster
// execution and all the resource writes to the end.  This means it cannot
// enforce arbitrary ordering dependencies (via control or data edges) between
// resource operations.  Since all resource reads happen before all resource
// writes, edges constraining resource reads to happen before resource writes
// are fine, but all other kinds of edges are problematic.  This analysis
// computes the set of pairs of resource operations that cannot be put in the
// same cluster because XLA cannot respect the dependencies between them in the
// TensorFlow program.
//
// TODO(b/112856632): We can, in theory, support Read->Read and Write->Write
// dependencies.
//
// Specifically the result computed by this analysis contains the edge {W, R}
// iff all of these hold true:
//
//   - In the graph (g - {edges from NextIteration to Merge}) there is a path
//     from W to R.
//   - IsEdgeSafe(W, R) == False [defined below]
//   - W != R (note: some resource operations both read from and write to
//     resource variables).
40 // 41 // The result is incorrect around loops because we ignore edges from 42 // NextIteration to Merge. For instance, in: 43 // 44 // Init -----> Merge <-------+ 45 // | | 46 // v | 47 // Read | 48 // | | 49 // v | 50 // Write | 51 // | | 52 // v | 53 // NextIteration --+ 54 // 55 // we won't put (Read, Write) in the returned set. This is fine if 56 // auto-clustering can only cluster the Read->Write edge, but it is a problem if 57 // it clusters the Write->NextIteration->Merge->Read edges instead. So we rely 58 // on auto-clustering to not cluster NextIteration->Merge edges. The same 59 // problem is present for the functional version of the loop above and we also 60 // rely on auto-clustering not clustering functional while loops containing 61 // resource operations. 62 // 63 // One way to think about this is that we only care about cases where two nodes, 64 // A and B, would normally have been put in the same cluster but cannot legally 65 // be in the same cluster because of resourcevar-dependencies. If A and B would 66 // normally have been put in the same cluster then all paths between A and B 67 // would have to be clusterable (otherwise we'd have introduced a cycle). Ergo 68 // there could not have been a NextIteration->Merge edge between A and B since 69 // we don't cluster these edges. 70 // 71 // IMPLEMENTATION 72 // -------------- 73 // 74 // We traverse the graph minus backedges in reverse post order, mapping each 75 // node to the set of resource operation reaching that node. Since we visit 76 // producers before consumers, we can construct the set of reaching operations 77 // by taking the union of the operations reaching the input nodes. These 78 // "reaching resource operations" can then be used to create the pairs of 79 // incompatible nodes using `IsEdgeSafe`. 
80 81 #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" 82 83 #include "absl/container/flat_hash_set.h" 84 #include "absl/memory/memory.h" 85 #include "absl/strings/str_join.h" 86 #include "absl/types/optional.h" 87 #include "tensorflow/compiler/tf2xla/resource_operation_table.h" 88 #include "tensorflow/core/framework/node_def.pb.h" 89 #include "tensorflow/core/graph/algorithm.h" 90 #include "tensorflow/core/graph/tensor_id.h" 91 #include "tensorflow/core/lib/hash/hash.h" 92 #include "tensorflow/core/util/ptr_util.h" 93 94 namespace tensorflow { 95 namespace { 96 // Returns true if `n` may call a function. 97 Status MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def, 98 bool* out_result) { 99 if (flib_def->Contains(n.type_string())) { 100 *out_result = true; 101 } else { 102 *out_result = 103 std::any_of(n.def().attr().begin(), n.def().attr().end(), 104 [](const std::pair<string, AttrValue>& name_attr_pair) { 105 return name_attr_pair.second.has_func(); 106 }); 107 } 108 109 return Status::OK(); 110 } 111 112 // Maps `n` to the XlaResourceOpKind corresponding to its operation. If `n` is 113 // not a resource operation recognized by XLA then sets `out_resource_op_kind` 114 // to nullopt. 
Status XlaResourceOpKindForNode(
    const Node& n, const FunctionLibraryDefinition* flib_def,
    const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
    absl::optional<XlaResourceOpKind>* out_resource_op_kind) {
  // The caller-provided predicate gets the first say; ignored nodes are
  // treated as if they were not resource operations at all.
  bool should_ignore = false;
  if (resource_ops_to_ignore) {
    TF_RETURN_IF_ERROR(resource_ops_to_ignore(n, &should_ignore));
  }
  if (should_ignore) {
    *out_resource_op_kind = absl::nullopt;
    return Status::OK();
  }

  // Ops known to XLA are classified by the static table queried through
  // GetResourceOpInfoForOp.
  const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n.type_string());
  if (op_info) {
    *out_resource_op_kind = op_info->kind();
    return Status::OK();
  }

  // We conservatively assume that functions will both read and write resource
  // variables.  In the future we may consider doing some form of
  // inter-procedural analysis.
  bool may_call_function;
  TF_RETURN_IF_ERROR(MayCallFunction(n, flib_def, &may_call_function));
  if (may_call_function) {
    *out_resource_op_kind = XlaResourceOpKind::kReadWrite;
  } else {
    *out_resource_op_kind = absl::nullopt;
  }

  return Status::OK();
}

// Returns true if a control or data dependence from a TensorFlow operation of
// resource op kind `from` to a TensorFlow operation of resource op kind `to`
// can be represented by an XLA cluster and needs no special handling around
// auto-jit.
bool IsEdgeSafe(XlaResourceOpKind from, XlaResourceOpKind to) {
  // XLA clusters force all reads to happen before all writes.  Moreover the
  // set of reads are executed as one atomic operation, and the set of writes
  // are as another atomic operation.  This means we can faithfully represent
  // the following edges: Read->*, *->Write.
  return from == XlaResourceOpKind::kRead || to == XlaResourceOpKind::kWrite;
}

// A resource operation in the graph: the node's id paired with the kind of
// resource access (read/write/read-write) it performs.
using ResourceOp = std::pair<int, XlaResourceOpKind>;

// Human-readable form of `resource_op`, e.g. "42: Read"; used for VLOG output.
string ResourceOpToString(const ResourceOp& resource_op) {
  return absl::StrCat(
      resource_op.first, ": ",
      XlaResourceOpInfo::XlaResourceOpKindToString(resource_op.second));
}

// A copy-on-write set used to store the set of ResourceOps reaching a node in a
// TensorFlow graph.
//
// TODO(sanjoy): It may be useful to pull this out into its own header at some
// point.
class ResourceOpSet {
 private:
  using Impl = absl::flat_hash_set<ResourceOp>;

 public:
  ResourceOpSet() = default;

  // Adds all ResourceOps in `other` to this set.
  void Add(const ResourceOpSet& other) {
    CHECK(!frozen_);
    if (other.impl_ == impl_) {
      // We already share `other`'s storage, so there is nothing to add; but
      // mark `other` frozen since we continue to alias its contents.
      other.frozen_ = true;
      return;
    }

    if (!impl_) {
      // We are empty: alias `other`'s storage instead of copying it.  `other`
      // must now stay stable, so freeze it.
      other.frozen_ = true;
      impl_ = other.impl_;
      return;
    }

    // General case: insert element-by-element.  The single-element Add below
    // triggers the copy-on-write if we are currently aliasing another set.
    for (ResourceOp resource_op : other) {
      Add(resource_op);
    }
  }

  void Add(const ResourceOp& resource_op) {
    CHECK(!frozen_);
    if (!IsCopy() && Contains(resource_op)) {
      // We can avoid the copy if the item we want to insert already exists.
      return;
    }

    EnsureIsCopied();
    impl_->insert(resource_op);
  }

  // begin()/end() fall back to a shared empty set so iteration works even when
  // impl_ is null (i.e. this set has never been written to or aliased).
  Impl::const_iterator begin() const {
    return impl_ ? impl_->begin() : GetEmptyImpl()->begin();
  }

  Impl::const_iterator end() const {
    return impl_ ? impl_->end() : GetEmptyImpl()->end();
  }

  bool Contains(const ResourceOp& resource_op) const {
    return impl_ != nullptr && impl_->count(resource_op);
  }

 private:
  // True if this set owns its storage (as opposed to being empty or aliasing
  // another set's storage).
  bool IsCopy() const { return storage_ != nullptr; }

  // Makes this set own its storage, copying over any currently-aliased
  // contents, so that it can be mutated without disturbing other sets.
  void EnsureIsCopied() {
    if (storage_ == nullptr) {
      storage_ = absl::make_unique<Impl>();
      for (ResourceOp op : *this) {
        storage_->insert(op);
      }
      impl_ = storage_.get();
    }
  }

  // Shared, never-mutated storage backing all empty sets; intentionally leaked
  // (function-local static) so it is safe to use at any time.
  static Impl* GetEmptyImpl() {
    static Impl* empty_impl = new Impl;
    return empty_impl;
  }

  Impl* impl_ = nullptr;
  std::unique_ptr<Impl> storage_;

  // frozen_ is true if there is another set pointing to this set's impl_.  We
  // can no longer add elements to this set in that case since the sets pointing
  // to this set expect the contents of this set to be stable.
  mutable bool frozen_ = false;

  TF_DISALLOW_COPY_AND_ASSIGN(ResourceOpSet);
};

// Human-readable form of `resource_op_set`, e.g. "{3: Read,7: Write}"; used
// for VLOG output.
string ResourceOpSetToString(const ResourceOpSet& resource_op_set) {
  std::vector<string> elements_debug_string;
  std::transform(resource_op_set.begin(), resource_op_set.end(),
                 std::back_inserter(elements_debug_string), ResourceOpToString);
  return absl::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}");
}

// Human-readable form of node `n` with its resource op kind, e.g.
// "[var_read: ReadVariableOp(Read)]"; used for VLOG output.
string NodeToString(const Node& n, XlaResourceOpKind resource_op_kind) {
  return absl::StrCat(
      "[", n.name(), ": ", n.type_string(), "(",
      XlaResourceOpInfo::XlaResourceOpKindToString(resource_op_kind), ")", "]");
}
}  // namespace

Status ComputeIncompatibleResourceOperationPairs(
    const Graph& g, const FunctionLibraryDefinition* flib_def,
    const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
    std::vector<std::pair<int, int>>* result) {
  CHECK(result->empty());

  // Traverse in reverse post order with NextIteration->* edges filtered out
  // (see file comment) so that every producer is visited before its consumers.
  std::vector<Node*> rpo;
  GetReversePostOrder(g, &rpo, /*stable_comparator=*/NodeComparatorName(),
                      /*edge_filter=*/[](const Edge& edge) {
                        return !edge.src()->IsNextIteration();
                      });

  // Indexed by node id: the set of resource operations reaching each node.
  auto resource_op_set_for_node =
      absl::make_unique<ResourceOpSet[]>(g.num_node_ids());

  const bool vlog = VLOG_IS_ON(2);

  for (Node* n : rpo) {
    absl::optional<XlaResourceOpKind> op_kind;
    TF_RETURN_IF_ERROR(XlaResourceOpKindForNode(
        *n, flib_def, resource_ops_to_ignore, &op_kind));

    ResourceOpSet* resource_op_set = &resource_op_set_for_node[n->id()];

    // Merge the reaching resource operations for all the incoming edges to
    // create the set of all possible resource ops reaching `n`.
    for (const Edge* e : n->in_edges()) {
      if (n->IsMerge() && e->src()->IsNextIteration()) {
        // Ignore back-edges (see file comment).
        continue;
      }

      const ResourceOpSet& incoming_op_set =
          resource_op_set_for_node[e->src()->id()];
      resource_op_set->Add(incoming_op_set);
    }

    // Add to the "incompatible resource ops" set if necessary.
    if (op_kind) {
      for (ResourceOp incoming_op : *resource_op_set) {
        if (IsEdgeSafe(incoming_op.second, *op_kind)) {
          continue;
        }

        if (vlog) {
          VLOG(2) << "Unsafe edge: "
                  << NodeToString(*g.FindNodeId(incoming_op.first),
                                  incoming_op.second)
                  << " -> " << NodeToString(*n, *op_kind);
        }
        result->push_back({incoming_op.first, n->id()});
      }

      // `n` is itself a resource op, so it is part of the reaching set for
      // everything downstream of it.
      resource_op_set->Add({n->id(), *op_kind});
    }

    if (vlog) {
      VLOG(3) << n->name() << " -> " << ResourceOpSetToString(*resource_op_set);
    }
  }

  // Each unsafe pair is discovered exactly once (each node is visited once and
  // reaching sets contain no duplicates), so the sorted result must already be
  // duplicate-free.
  std::sort(result->begin(), result->end());
  CHECK(std::unique(result->begin(), result->end()) == result->end());

  return Status::OK();
}
}  // namespace tensorflow