Home | History | Annotate | Download | only in client
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/client/client_library.h"
     17 
     18 #include "tensorflow/compiler/xla/service/backend.h"
     19 #include "tensorflow/compiler/xla/service/platform_util.h"
     20 #include "tensorflow/compiler/xla/status_macros.h"
     21 #include "tensorflow/compiler/xla/util.h"
     22 #include "tensorflow/core/platform/logging.h"
     23 
     24 namespace xla {
     25 
     26 LocalClientOptions::LocalClientOptions(perftools::gputools::Platform* platform,
     27                                        int number_of_replicas,
     28                                        int intra_op_parallelism_threads)
     29     : platform_(platform),
     30       number_of_replicas_(number_of_replicas),
     31       intra_op_parallelism_threads_(intra_op_parallelism_threads) {}
     32 
     33 LocalClientOptions& LocalClientOptions::set_platform(
     34     perftools::gputools::Platform* platform) {
     35   platform_ = platform;
     36   return *this;
     37 }
     38 
     39 perftools::gputools::Platform* LocalClientOptions::platform() const {
     40   return platform_;
     41 }
     42 
     43 LocalClientOptions& LocalClientOptions::set_number_of_replicas(
     44     int number_of_replicas) {
     45   number_of_replicas_ = number_of_replicas;
     46   return *this;
     47 }
     48 
     49 int LocalClientOptions::number_of_replicas() const {
     50   return number_of_replicas_;
     51 }
     52 
     53 LocalClientOptions& LocalClientOptions::set_intra_op_parallelism_threads(
     54     int num_threads) {
     55   intra_op_parallelism_threads_ = num_threads;
     56   return *this;
     57 }
     58 
     59 int LocalClientOptions::intra_op_parallelism_threads() const {
     60   return intra_op_parallelism_threads_;
     61 }
     62 
     63 /* static */ ClientLibrary& ClientLibrary::Singleton() {
     64   static ClientLibrary* c = new ClientLibrary;
     65   return *c;
     66 }
     67 
     68 ClientLibrary::ClientLibrary() = default;
     69 ClientLibrary::~ClientLibrary() = default;
     70 
     71 /* static */ StatusOr<LocalClient*> ClientLibrary::GetOrCreateLocalClient(
     72     perftools::gputools::Platform* platform) {
     73   LocalClientOptions default_options;
     74   default_options.set_platform(platform);
     75   return GetOrCreateLocalClient(default_options);
     76 }
     77 
     78 /* static */ StatusOr<LocalClient*> ClientLibrary::GetOrCreateLocalClient(
     79     const LocalClientOptions& options) {
     80   perftools::gputools::Platform* platform = options.platform();
     81   int replica_count = options.number_of_replicas();
     82   ClientLibrary& client_library = Singleton();
     83   tensorflow::mutex_lock lock(client_library.service_mutex_);
     84 
     85   if (platform == nullptr) {
     86     TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
     87   }
     88 
     89   auto it = client_library.local_instances_.find(platform->id());
     90   if (it != client_library.local_instances_.end()) {
     91     return it->second->client.get();
     92   }
     93 
     94   ServiceOptions service_options;
     95   service_options.set_platform(platform);
     96   service_options.set_number_of_replicas(replica_count);
     97   service_options.set_intra_op_parallelism_threads(
     98       options.intra_op_parallelism_threads());
     99 
    100   auto instance = MakeUnique<LocalInstance>();
    101   TF_ASSIGN_OR_RETURN(instance->service,
    102                       LocalService::NewService(service_options));
    103   instance->client = MakeUnique<LocalClient>(instance->service.get());
    104   LocalClient* cl = instance->client.get();
    105 
    106   client_library.local_instances_.insert(
    107       std::make_pair(platform->id(), std::move(instance)));
    108   return cl;
    109 }
    110 
    111 /* static */ LocalClient* ClientLibrary::LocalClientOrDie() {
    112   auto client_status = GetOrCreateLocalClient();
    113   TF_CHECK_OK(client_status.status());
    114   return client_status.ValueOrDie();
    115 }
    116 
    117 /* static */ LocalService* ClientLibrary::GetXlaService(
    118     perftools::gputools::Platform* platform) {
    119   ClientLibrary& client_library = Singleton();
    120   tensorflow::mutex_lock lock(client_library.service_mutex_);
    121   auto it = client_library.local_instances_.find(platform->id());
    122   CHECK(it != client_library.local_instances_.end());
    123   return it->second->service.get();
    124 }
    125 
    126 /* static */ StatusOr<CompileOnlyClient*>
    127 ClientLibrary::GetOrCreateCompileOnlyClient(
    128     perftools::gputools::Platform* platform) {
    129   ClientLibrary& client_library = Singleton();
    130   tensorflow::mutex_lock lock(client_library.service_mutex_);
    131 
    132   if (platform == nullptr) {
    133     TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
    134   }
    135 
    136   auto it = client_library.compile_only_instances_.find(platform->id());
    137   if (it != client_library.compile_only_instances_.end()) {
    138     return it->second->client.get();
    139   }
    140 
    141   auto instance = MakeUnique<CompileOnlyInstance>();
    142   TF_ASSIGN_OR_RETURN(instance->service,
    143                       CompileOnlyService::NewService(platform));
    144   instance->client = MakeUnique<CompileOnlyClient>(instance->service.get());
    145   CompileOnlyClient* cl = instance->client.get();
    146 
    147   client_library.compile_only_instances_.insert(
    148       std::make_pair(platform->id(), std::move(instance)));
    149   return cl;
    150 }
    151 
    152 /* static */ void ClientLibrary::DestroyLocalInstances() {
    153   ClientLibrary& client_library = Singleton();
    154   tensorflow::mutex_lock lock(client_library.service_mutex_);
    155 
    156   client_library.local_instances_.clear();
    157   client_library.compile_only_instances_.clear();
    158 }
    159 
    160 }  // namespace xla
    161