1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/computation_tracker.h" 17 18 #include <list> 19 #include <string> 20 #include <utility> 21 #include <vector> 22 23 #include "tensorflow/compiler/xla/ptr_util.h" 24 #include "tensorflow/compiler/xla/service/hlo_computation.h" 25 #include "tensorflow/compiler/xla/status_macros.h" 26 #include "tensorflow/compiler/xla/types.h" 27 #include "tensorflow/compiler/xla/util.h" 28 #include "tensorflow/core/lib/strings/strcat.h" 29 #include "tensorflow/core/lib/strings/stringprintf.h" 30 #include "tensorflow/core/platform/logging.h" 31 32 using ::tensorflow::strings::Appendf; 33 34 namespace xla { 35 36 ComputationTracker::ComputationTracker() : next_computation_(1) {} 37 38 ComputationHandle ComputationTracker::NewComputation( 39 const string& computation_name) { 40 tensorflow::mutex_lock lock(computation_mutex_); 41 ComputationHandle computation_handle; 42 int64 handle_value = next_computation_++; 43 computation_handle.set_handle(handle_value); 44 opaque_to_computation_[handle_value] = 45 MakeUnique<UserComputation>(computation_name, computation_handle); 46 return computation_handle; 47 } 48 49 StatusOr<ComputationHandle> ComputationTracker::LoadSessionModule( 50 const SessionModule& session_module) { 51 tensorflow::mutex_lock lock(computation_mutex_); 52 53 // For each embedded computation, create a new computation based on its 54 // serialized data, and place the mapping from the old computation handle to 55 // the new computation handle. 56 57 // Build a mapping from old embedded computation handles to new computation 58 // handles. We build the ID mapping first since the embedded computations are 59 // in no particular order and may refer to each other. 60 std::map<int64, ComputationHandle> old_to_new; 61 for (const SessionComputation& computation : 62 session_module.embedded_computations()) { 63 const int64 old_handle = computation.computation_handle().handle(); 64 if (!old_to_new.emplace(old_handle, AllocateHandle()).second) { 65 return InvalidArgument("Duplicate embedded computation handle %lld", 66 old_handle); 67 } 68 } 69 70 // Create a new computation from each serialized embedded computation. 71 for (const SessionComputation& computation : 72 session_module.embedded_computations()) { 73 const int64 old_handle = computation.computation_handle().handle(); 74 const ComputationHandle& new_handle = old_to_new[old_handle]; 75 TF_ASSIGN_OR_RETURN(opaque_to_computation_[new_handle.handle()], 76 UserComputation::MakeWithRemapping( 77 computation, new_handle, old_to_new)); 78 } 79 80 // Finally, place the entry computation in the tracker with all of the 81 // remappings populated from the above. 82 const int64 old_handle = session_module.entry().computation_handle().handle(); 83 TF_ASSIGN_OR_RETURN( 84 old_to_new[old_handle], 85 LoadSessionComputation(session_module.entry(), &old_to_new)); 86 return old_to_new[old_handle]; 87 } 88 89 StatusOr<std::unique_ptr<SessionModule>> 90 ComputationTracker::SnapshotComputation(const ComputationHandle& computation) { 91 TF_ASSIGN_OR_RETURN(UserComputation * user_computation, Resolve(computation)); 92 const VersionedComputationHandle entry_versioned_handle = 93 user_computation->GetVersionedHandle(); 94 std::set<VersionedComputationHandle> visited; 95 std::list<VersionedComputationHandle> post_order; 96 { 97 tensorflow::mutex_lock lock(computation_mutex_); 98 ComputeComputationPostOrder(entry_versioned_handle, &visited, &post_order); 99 } 100 auto session_module = MakeUnique<SessionModule>(); 101 *session_module->mutable_entry() = 102 Resolve(entry_versioned_handle.handle) 103 .ValueOrDie() 104 ->CloneSessionComputation(entry_versioned_handle.version); 105 for (auto it = ++post_order.rbegin(); it != post_order.rend(); ++it) { 106 *session_module->add_embedded_computations() = 107 Resolve(it->handle).ValueOrDie()->CloneSessionComputation(it->version); 108 } 109 return std::move(session_module); 110 } 111 112 StatusOr<UserComputation*> ComputationTracker::Resolve( 113 const ComputationHandle& computation) const { 114 tensorflow::mutex_lock lock(computation_mutex_); 115 return ResolveInternal(computation); 116 } 117 118 ComputationHandle ComputationTracker::AllocateHandle() { 119 int64 handle_value = next_computation_++; 120 ComputationHandle result; 121 result.set_handle(handle_value); 122 return result; 123 } 124 125 StatusOr<ComputationHandle> ComputationTracker::LoadSessionComputation( 126 const SessionComputation& session_computation, 127 std::map<int64, ComputationHandle>* old_to_new) { 128 TF_RET_CHECK(old_to_new != nullptr); 129 const ComputationHandle new_handle = AllocateHandle(); 130 (*old_to_new)[session_computation.computation_handle().handle()] = new_handle; 131 TF_ASSIGN_OR_RETURN(opaque_to_computation_[new_handle.handle()], 132 UserComputation::MakeWithRemapping( 133 session_computation, new_handle, *old_to_new)); 134 return new_handle; 135 } 136 137 StatusOr<UserComputation*> ComputationTracker::ResolveInternal( 138 const ComputationHandle& computation) const { 139 auto it = opaque_to_computation_.find(computation.handle()); 140 if (it == opaque_to_computation_.end()) { 141 return NotFound("computation handle not found: %lld", computation.handle()); 142 } 143 UserComputation* user_computation = it->second.get(); 144 return user_computation; 145 } 146 147 void ComputationTracker::ComputeComputationPostOrder( 148 const VersionedComputationHandle& versioned_handle, 149 std::set<VersionedComputationHandle>* visited, 150 std::list<VersionedComputationHandle>* post_order) const { 151 if (visited->count(versioned_handle) > 0) { 152 CHECK_EQ(1, visited->count(versioned_handle)); 153 return; 154 } 155 156 UserComputation* computation = 157 ResolveInternal(versioned_handle.handle).ValueOrDie(); 158 std::vector<VersionedComputationHandle> embedded_handles = 159 computation->GetEmbeddedComputations(versioned_handle.version); 160 161 for (const auto& embedded_handle : embedded_handles) { 162 ComputeComputationPostOrder(embedded_handle, visited, post_order); 163 } 164 165 visited->insert(versioned_handle); 166 post_order->push_back(versioned_handle); 167 } 168 169 StatusOr<std::unique_ptr<HloModule>> ComputationTracker::BuildHloModule( 170 const VersionedComputationHandle& entry_handle, 171 const HloModuleConfig& config, 172 bool include_unreachable_instructions) const { 173 tensorflow::mutex_lock lock(computation_mutex_); 174 175 VLOG(1) << "BuildHloModule(" << entry_handle 176 << ", include_unreachable_instructions=" 177 << include_unreachable_instructions << ")"; 178 XLA_VLOG_LINES(1, ToStringInternal()); 179 180 TF_ASSIGN_OR_RETURN(UserComputation * entry_computation, 181 ResolveInternal(entry_handle.handle)); 182 183 // Build a topological sort of the entry and any embedded computations as a 184 // list. The root of the computation will be the last element in the list. 185 std::set<VersionedComputationHandle> visited; 186 std::list<VersionedComputationHandle> post_order; 187 ComputeComputationPostOrder(entry_handle, &visited, &post_order); 188 189 // Map from ComputationHandle value and computation version to HloComputation. 190 std::map<VersionedComputationHandle, HloComputation*> hlo_computations; 191 192 // The resolver lambda resolves VersionedHandles to embedded 193 // HloComputation*. This is required by UserComputation::BuildHloComputation 194 // when lowering calling operations (map, reduce etc). 195 auto resolver = [&hlo_computations]( 196 const VersionedComputationHandle& versioned_handle) -> HloComputation* { 197 CHECK_GT(hlo_computations.count(versioned_handle), 0); 198 return hlo_computations.at(versioned_handle); 199 }; 200 201 // Print the post-order list for this entry computation. 202 if (VLOG_IS_ON(2)) { 203 VLOG(2) << "Visiting UserComputations in post order:"; 204 for (const VersionedComputationHandle& versioned_handle : post_order) { 205 VLOG(2) << " " << versioned_handle; 206 } 207 } 208 209 string module_name = 210 tensorflow::strings::StrCat(entry_computation->name(), "_module"); 211 auto module = MakeUnique<HloModule>(module_name, entry_handle, config); 212 for (auto versioned_handle : post_order) { 213 UserComputation* computation = 214 ResolveInternal(versioned_handle.handle).ValueOrDie(); 215 216 TF_ASSIGN_OR_RETURN( 217 std::unique_ptr<HloComputation> hlo_computation, 218 computation->BuildHloComputation(versioned_handle.version, resolver, 219 config.debug_options(), 220 include_unreachable_instructions)); 221 222 // Add the newly created computation to VersionedHandle-to-HloComputation 223 // map. 224 DCHECK_EQ(0, hlo_computations.count(versioned_handle)); 225 hlo_computations[versioned_handle] = hlo_computation.get(); 226 227 if (computation == entry_computation) { 228 module->AddEntryComputation(std::move(hlo_computation)); 229 } else { 230 module->AddEmbeddedComputation(std::move(hlo_computation)); 231 } 232 } 233 234 return std::move(module); 235 } 236 237 string ComputationTracker::ToString() const { 238 tensorflow::mutex_lock lock(computation_mutex_); 239 return ToStringInternal(); 240 } 241 242 string ComputationTracker::ToStringInternal() const { 243 string out; 244 Appendf(&out, "ComputationTracker(%p):\n", this); 245 for (const auto& handle_computation : opaque_to_computation_) { 246 int64 handle = handle_computation.first; 247 const std::unique_ptr<UserComputation>& computation = 248 handle_computation.second; 249 Appendf(&out, " %4lld : %s \"%s\"\n", handle, 250 computation->GetVersionedHandle().ToString().c_str(), 251 computation->name().c_str()); 252 } 253 return out; 254 } 255 256 } // namespace xla 257