Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/computation_placer.h"
     17 
     18 #include <string>
     19 #include <utility>
     20 #include <vector>
     21 
     22 #include "tensorflow/compiler/xla/literal_util.h"
     23 #include "tensorflow/compiler/xla/ptr_util.h"
     24 #include "tensorflow/compiler/xla/shape_util.h"
     25 #include "tensorflow/compiler/xla/status.h"
     26 #include "tensorflow/compiler/xla/status_macros.h"
     27 #include "tensorflow/compiler/xla/statusor.h"
     28 #include "tensorflow/compiler/xla/types.h"
     29 #include "tensorflow/compiler/xla/util.h"
     30 #include "tensorflow/core/lib/core/errors.h"
     31 #include "tensorflow/core/lib/core/status.h"
     32 #include "tensorflow/core/platform/logging.h"
     33 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
     34 
     35 namespace se = ::perftools::gputools;
     36 
     37 namespace xla {
     38 
     39 Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const {
     40   proto->set_replica_count(replica_count());
     41   proto->set_computation_count(computation_count());
     42   for (int computation = 0; computation < computation_count(); ++computation) {
     43     DeviceAssignmentProto::ComputationDevice* computation_device =
     44         proto->add_computation_devices();
     45     for (int replica = 0; replica < replica_count(); ++replica) {
     46       computation_device->add_replica_device_ids((*this)(replica, computation));
     47     }
     48   }
     49   return Status::OK();
     50 }
     51 
     52 /* static */ StatusOr<std::unique_ptr<DeviceAssignment>>
     53 DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) {
     54   TF_RET_CHECK(proto.computation_devices_size() == proto.computation_count());
     55   if (proto.replica_count() <= 0 || proto.computation_count() <= 0) {
     56     return InvalidArgument(
     57         "Invalid device assignment topology: replica_count=%d, "
     58         "computation_count=%d",
     59         proto.replica_count(), proto.computation_count());
     60   }
     61   auto assignment = MakeUnique<DeviceAssignment>(proto.replica_count(),
     62                                                  proto.computation_count());
     63   for (int computation = 0; computation < proto.computation_count();
     64        ++computation) {
     65     const auto& computation_device = proto.computation_devices(computation);
     66     TF_RET_CHECK(computation_device.replica_device_ids_size() ==
     67                  proto.replica_count());
     68     for (int replica = 0; replica < proto.replica_count(); ++replica) {
     69       (*assignment)(replica, computation) =
     70           computation_device.replica_device_ids(replica);
     71     }
     72   }
     73   return std::move(assignment);
     74 }
     75 
     76 StatusOr<int> ComputationPlacer::DeviceId(int replica, int computation,
     77                                           int replica_count,
     78                                           int computation_count) {
     79   TF_RET_CHECK(replica < replica_count);
     80   TF_RET_CHECK(computation < computation_count);
     81 
     82   return computation * replica_count + replica;
     83 }
     84 
     85 StatusOr<DeviceAssignment> ComputationPlacer::AssignDevices(
     86     int replica_count, int computation_count) {
     87   DeviceAssignment assignment(replica_count, computation_count);
     88   for (int replica = 0; replica < replica_count; ++replica) {
     89     for (int computation = 0; computation < computation_count; ++computation) {
     90       TF_ASSIGN_OR_RETURN(
     91           int device_id,
     92           DeviceId(replica, computation, replica_count, computation_count));
     93       assignment(replica, computation) = device_id;
     94     }
     95   }
     96   return std::move(assignment);
     97 }
     98 
     99 /* static */ void ComputationPlacer::RegisterComputationPlacer(
    100     se::Platform::Id platform_id,
    101     ComputationPlacerCreationFunction creation_function) {
    102   tensorflow::mutex_lock lock(
    103       ComputationPlacer::platform_computation_placer_mutex_);
    104   auto* computation_placers = GetPlatformComputationPlacers();
    105   CHECK(computation_placers->find(platform_id) == computation_placers->end());
    106   (*computation_placers)[platform_id].creation_function = creation_function;
    107 }
    108 
    109 /* static */ StatusOr<ComputationPlacer*> ComputationPlacer::GetForPlatform(
    110     const se::Platform* platform) {
    111   tensorflow::mutex_lock lock(
    112       ComputationPlacer::platform_computation_placer_mutex_);
    113   auto* computation_placers = GetPlatformComputationPlacers();
    114 
    115   auto it = computation_placers->find(platform->id());
    116   if (it == computation_placers->end()) {
    117     return NotFound(
    118         "could not find registered computation placer for platform %s -- check "
    119         "target linkage",
    120         platform->Name().c_str());
    121   }
    122 
    123   if (it->second.placer == nullptr) {
    124     // Lazily create the computation placer the first time it is needed.
    125     it->second.placer = (*it->second.creation_function)();
    126   }
    127 
    128   return it->second.placer.get();
    129 }
    130 
// Guards the global platform-id -> placer registry. LINKER_INITIALIZED makes
// the mutex usable from static initializers (e.g. the InitModule registration
// below) regardless of construction order.
/* static */ tensorflow::mutex
    ComputationPlacer::platform_computation_placer_mutex_(
        tensorflow::LINKER_INITIALIZED);
    134 
    135 /* static */ std::map<perftools::gputools::Platform::Id,
    136                       ComputationPlacer::State>*
    137 ComputationPlacer::GetPlatformComputationPlacers() {
    138   static auto* r =
    139       new std::map<perftools::gputools::Platform::Id, ComputationPlacer::State>;
    140   return r;
    141 }
    142 
    143 }  // namespace xla
    144 
// Factory passed to RegisterComputationPlacer: builds the default (generic)
// ComputationPlacer used for both the host and CUDA platforms below.
static std::unique_ptr<xla::ComputationPlacer> CreateComputationPlacer() {
  return xla::MakeUnique<xla::ComputationPlacer>();
}
    148 
// Registers the default ComputationPlacer factory for the host and CUDA
// platforms. Runs once at static-initialization time via module_initialized;
// CHECK-fails inside RegisterComputationPlacer if either platform id was
// already registered.
static bool InitModule() {
  xla::ComputationPlacer::RegisterComputationPlacer(se::host::kHostPlatformId,
                                                    &CreateComputationPlacer);
  xla::ComputationPlacer::RegisterComputationPlacer(se::cuda::kCudaPlatformId,
                                                    &CreateComputationPlacer);
  return true;
}
// Forces InitModule() to run when this translation unit is linked in.
static bool module_initialized = InitModule();
    157