/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/clone_constants_for_better_clustering.h"

#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"

namespace tensorflow {

using se::port::StatusOr;

string CloneConstantsForBetterClusteringPass::GenerateUniqueName(
    const absl::flat_hash_set<string>& name_set, absl::string_view prefix) {
  string candidate;
  do {
    candidate = absl::StrCat(prefix, "/clone_", unique_name_counter_++);
  } while (name_set.contains(candidate));
  return candidate;
}

StatusOr<Node*> CloneConstantsForBetterClusteringPass::CloneNode(
    Graph* g, const absl::flat_hash_set<string>& name_set, Node* n) {
  NodeDef new_in_def = n->def();
  new_in_def.clear_input();
  new_in_def.set_name(GenerateUniqueName(name_set, new_in_def.name()));
  Status s;
  Node* new_in = g->AddNode(new_in_def, &s);
  TF_RETURN_IF_ERROR(s);

  for (const Edge* e : n->in_edges()) {
    if (e->IsControlEdge()) {
      g->AddControlEdge(e->src(), new_in);
    } else {
      g->AddEdge(e->src(), e->src_output(), new_in, e->dst_input());
    }
  }

  new_in->set_assigned_device_name(n->assigned_device_name());
  return new_in;
}

namespace {
// We only clone host constants for now since we want to avoid increasing
// memory pressure on GPUs.
StatusOr<bool> IsSmallHostConstant(Node* n) {
  if (!n->IsConstant()) {
    return false;
  }

  DeviceNameUtils::ParsedName parsed;
  TF_RET_CHECK(
      DeviceNameUtils::ParseFullName(n->assigned_device_name(), &parsed));
  if (parsed.type != DEVICE_CPU) {
    return false;
  }

  const TensorProto* proto = nullptr;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "value", &proto));

  // TODO(sanjoy): It may make sense to combine this threshold with XLA's
  // "large constant" threshold, if there is one.
  const int kSmallTensorThreshold = 16;
  int64 total_elements = 1;
  for (const auto& dim : proto->tensor_shape().dim()) {
    if (dim.size() < 0) {
      return errors::Internal("Unknown dimension size in constant tensor ",
                              n->name());
    }
    total_elements *= dim.size();
  }
  return total_elements < kSmallTensorThreshold;
}

bool IsInPlaceOp(absl::string_view op_name) {
  return op_name == "InplaceUpdate" || op_name == "InplaceAdd" ||
         op_name == "InplaceSub";
}
}  // namespace

Status CloneConstantsForBetterClusteringPass::CloneSmallHostConstantInputs(
    Graph* g, const absl::flat_hash_set<string>& name_set, Node* n) {
  std::vector<const Edge*> in_edges;
  absl::c_copy(n->in_edges(), std::back_inserter(in_edges));
  for (const Edge* e : in_edges) {
    Node* input = e->src();
    TF_ASSIGN_OR_RETURN(bool is_small_host_constant,
                        IsSmallHostConstant(input));
    if (is_small_host_constant && input->out_edges().size() != 1) {
      VLOG(2) << "Cloning small host constant " << input->name();
      TF_ASSIGN_OR_RETURN(Node* const input_cloned,
                          CloneNode(g, name_set, input));
      if (e->IsControlEdge()) {
        g->AddControlEdge(input_cloned, e->dst());
      } else {
        int dst_input = e->dst_input();
        TF_RET_CHECK(e->src_output() == 0)
            << "expected constant to have exactly one non-control output, but "
               "found output index = "
            << e->src_output();
        g->RemoveEdge(e);
        g->AddEdge(input_cloned, 0, n, dst_input);
      }
    }
  }
  return Status::OK();
}

Status CloneConstantsForBetterClusteringPass::Run(
    const GraphOptimizationPassOptions& options) {
  if (GetGlobalJitLevel(options) == OptimizerOptions::OFF) {
    return Status::OK();
  }

  Graph* g = options.graph->get();
  absl::flat_hash_set<string> name_set;
  absl::c_transform(g->nodes(), std::inserter(name_set, name_set.begin()),
                    [](Node* n) { return n->name(); });
  std::vector<Node*> nodes;
  for (Node* n : g->nodes()) {
    // We rely on the immutability of Tensors to safely clone Const operations.
    // However, "in place" ops do not respect the immutability of Tensors so we
    // avoid this transformation when such ops are present in the graph.
    //
    // In-place operations are problematic because they break the semantic
    // illusion that tensorflow::Tensor instances are immutable. For instance
    // if we have the following graph:
    //
    // digraph {
    //   SRC -> Const
    //   SRC -> I
    //   SRC -> V
    //   Const -> Identity
    //   Const -> InplaceAdd [label="x"]
    //   I -> InplaceAdd [label="i"]
    //   V -> InplaceAdd [label="v"]
    //   InplaceAdd -> Identity [style=dotted]
    // }
    //
    // then the value produced by `Identity` is Const+I*V since InplaceAdd
    // modifies the tensor in place. However, if we clone `Const` and turn the
    // graph into:
    //
    // digraph {
    //   SRC -> "Const/clone_1"
    //   SRC -> "Const/clone_2"
    //   SRC -> I
    //   SRC -> V
    //   "Const/clone_1" -> Identity
    //   "Const/clone_2" -> InplaceAdd [label="x"]
    //   I -> InplaceAdd [label="i"]
    //   V -> InplaceAdd [label="v"]
    //   InplaceAdd -> Identity [style=dotted]
    // }
    //
    // then `Identity` no longer produces Const+I*V because the InplaceAdd
    // operation only modifies Const/clone_2 in place.

    if (IsInPlaceOp(n->type_string())) {
      return Status::OK();
    }
    nodes.push_back(n);
  }

  // Iterate over a copy of the nodes to avoid iterating over g->nodes() while
  // creating more nodes.
  for (Node* n : nodes) {
    TF_RETURN_IF_ERROR(CloneSmallHostConstantInputs(g, name_set, n));
  }
  return Status::OK();
}

}  // namespace tensorflow