1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/computation_placer.h" 17 18 #include <string> 19 #include <utility> 20 #include <vector> 21 22 #include "tensorflow/compiler/xla/literal_util.h" 23 #include "tensorflow/compiler/xla/ptr_util.h" 24 #include "tensorflow/compiler/xla/shape_util.h" 25 #include "tensorflow/compiler/xla/status.h" 26 #include "tensorflow/compiler/xla/status_macros.h" 27 #include "tensorflow/compiler/xla/statusor.h" 28 #include "tensorflow/compiler/xla/types.h" 29 #include "tensorflow/compiler/xla/util.h" 30 #include "tensorflow/core/lib/core/errors.h" 31 #include "tensorflow/core/lib/core/status.h" 32 #include "tensorflow/core/platform/logging.h" 33 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 34 35 namespace se = ::perftools::gputools; 36 37 namespace xla { 38 39 Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const { 40 proto->set_replica_count(replica_count()); 41 proto->set_computation_count(computation_count()); 42 for (int computation = 0; computation < computation_count(); ++computation) { 43 DeviceAssignmentProto::ComputationDevice* computation_device = 44 proto->add_computation_devices(); 45 for (int replica = 0; replica < replica_count(); ++replica) { 46 computation_device->add_replica_device_ids((*this)(replica, computation)); 47 } 48 } 49 return Status::OK(); 50 } 51 52 /* static */ StatusOr<std::unique_ptr<DeviceAssignment>> 53 DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) { 54 TF_RET_CHECK(proto.computation_devices_size() == proto.computation_count()); 55 if (proto.replica_count() <= 0 || proto.computation_count() <= 0) { 56 return InvalidArgument( 57 "Invalid device assignment topology: replica_count=%d, " 58 "computation_count=%d", 59 proto.replica_count(), proto.computation_count()); 60 } 61 auto assignment = MakeUnique<DeviceAssignment>(proto.replica_count(), 62 proto.computation_count()); 63 for (int computation = 0; computation < proto.computation_count(); 64 ++computation) { 65 const auto& computation_device = proto.computation_devices(computation); 66 TF_RET_CHECK(computation_device.replica_device_ids_size() == 67 proto.replica_count()); 68 for (int replica = 0; replica < proto.replica_count(); ++replica) { 69 (*assignment)(replica, computation) = 70 computation_device.replica_device_ids(replica); 71 } 72 } 73 return std::move(assignment); 74 } 75 76 StatusOr<int> ComputationPlacer::DeviceId(int replica, int computation, 77 int replica_count, 78 int computation_count) { 79 TF_RET_CHECK(replica < replica_count); 80 TF_RET_CHECK(computation < computation_count); 81 82 return computation * replica_count + replica; 83 } 84 85 StatusOr<DeviceAssignment> ComputationPlacer::AssignDevices( 86 int replica_count, int computation_count) { 87 DeviceAssignment assignment(replica_count, computation_count); 88 for (int replica = 0; replica < replica_count; ++replica) { 89 for (int computation = 0; computation < computation_count; ++computation) { 90 TF_ASSIGN_OR_RETURN( 91 int device_id, 92 DeviceId(replica, computation, replica_count, computation_count)); 93 assignment(replica, computation) = device_id; 94 } 95 } 96 return std::move(assignment); 97 } 98 99 /* static */ void ComputationPlacer::RegisterComputationPlacer( 100 se::Platform::Id platform_id, 101 ComputationPlacerCreationFunction creation_function) { 102 tensorflow::mutex_lock lock( 103 ComputationPlacer::platform_computation_placer_mutex_); 104 auto* computation_placers = GetPlatformComputationPlacers(); 105 CHECK(computation_placers->find(platform_id) == computation_placers->end()); 106 (*computation_placers)[platform_id].creation_function = creation_function; 107 } 108 109 /* static */ StatusOr<ComputationPlacer*> ComputationPlacer::GetForPlatform( 110 const se::Platform* platform) { 111 tensorflow::mutex_lock lock( 112 ComputationPlacer::platform_computation_placer_mutex_); 113 auto* computation_placers = GetPlatformComputationPlacers(); 114 115 auto it = computation_placers->find(platform->id()); 116 if (it == computation_placers->end()) { 117 return NotFound( 118 "could not find registered computation placer for platform %s -- check " 119 "target linkage", 120 platform->Name().c_str()); 121 } 122 123 if (it->second.placer == nullptr) { 124 // Lazily create the computation placer the first time it is needed. 125 it->second.placer = (*it->second.creation_function)(); 126 } 127 128 return it->second.placer.get(); 129 } 130 131 /* static */ tensorflow::mutex 132 ComputationPlacer::platform_computation_placer_mutex_( 133 tensorflow::LINKER_INITIALIZED); 134 135 /* static */ std::map<perftools::gputools::Platform::Id, 136 ComputationPlacer::State>* 137 ComputationPlacer::GetPlatformComputationPlacers() { 138 static auto* r = 139 new std::map<perftools::gputools::Platform::Id, ComputationPlacer::State>; 140 return r; 141 } 142 143 } // namespace xla 144 145 static std::unique_ptr<xla::ComputationPlacer> CreateComputationPlacer() { 146 return xla::MakeUnique<xla::ComputationPlacer>(); 147 } 148 149 static bool InitModule() { 150 xla::ComputationPlacer::RegisterComputationPlacer(se::host::kHostPlatformId, 151 &CreateComputationPlacer); 152 xla::ComputationPlacer::RegisterComputationPlacer(se::cuda::kCudaPlatformId, 153 &CreateComputationPlacer); 154 return true; 155 } 156 static bool module_initialized = InitModule(); 157