1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/client/client_library.h" 17 18 #include "tensorflow/compiler/xla/service/backend.h" 19 #include "tensorflow/compiler/xla/service/platform_util.h" 20 #include "tensorflow/compiler/xla/status_macros.h" 21 #include "tensorflow/compiler/xla/util.h" 22 #include "tensorflow/core/platform/logging.h" 23 24 namespace xla { 25 26 LocalClientOptions::LocalClientOptions(perftools::gputools::Platform* platform, 27 int number_of_replicas, 28 int intra_op_parallelism_threads) 29 : platform_(platform), 30 number_of_replicas_(number_of_replicas), 31 intra_op_parallelism_threads_(intra_op_parallelism_threads) {} 32 33 LocalClientOptions& LocalClientOptions::set_platform( 34 perftools::gputools::Platform* platform) { 35 platform_ = platform; 36 return *this; 37 } 38 39 perftools::gputools::Platform* LocalClientOptions::platform() const { 40 return platform_; 41 } 42 43 LocalClientOptions& LocalClientOptions::set_number_of_replicas( 44 int number_of_replicas) { 45 number_of_replicas_ = number_of_replicas; 46 return *this; 47 } 48 49 int LocalClientOptions::number_of_replicas() const { 50 return number_of_replicas_; 51 } 52 53 LocalClientOptions& LocalClientOptions::set_intra_op_parallelism_threads( 54 int num_threads) { 55 intra_op_parallelism_threads_ = num_threads; 56 return *this; 57 } 58 59 int LocalClientOptions::intra_op_parallelism_threads() const { 60 return intra_op_parallelism_threads_; 61 } 62 63 /* static */ ClientLibrary& ClientLibrary::Singleton() { 64 static ClientLibrary* c = new ClientLibrary; 65 return *c; 66 } 67 68 ClientLibrary::ClientLibrary() = default; 69 ClientLibrary::~ClientLibrary() = default; 70 71 /* static */ StatusOr<LocalClient*> ClientLibrary::GetOrCreateLocalClient( 72 perftools::gputools::Platform* platform) { 73 LocalClientOptions default_options; 74 default_options.set_platform(platform); 75 return GetOrCreateLocalClient(default_options); 76 } 77 78 /* static */ StatusOr<LocalClient*> ClientLibrary::GetOrCreateLocalClient( 79 const LocalClientOptions& options) { 80 perftools::gputools::Platform* platform = options.platform(); 81 int replica_count = options.number_of_replicas(); 82 ClientLibrary& client_library = Singleton(); 83 tensorflow::mutex_lock lock(client_library.service_mutex_); 84 85 if (platform == nullptr) { 86 TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); 87 } 88 89 auto it = client_library.local_instances_.find(platform->id()); 90 if (it != client_library.local_instances_.end()) { 91 return it->second->client.get(); 92 } 93 94 ServiceOptions service_options; 95 service_options.set_platform(platform); 96 service_options.set_number_of_replicas(replica_count); 97 service_options.set_intra_op_parallelism_threads( 98 options.intra_op_parallelism_threads()); 99 100 auto instance = MakeUnique<LocalInstance>(); 101 TF_ASSIGN_OR_RETURN(instance->service, 102 LocalService::NewService(service_options)); 103 instance->client = MakeUnique<LocalClient>(instance->service.get()); 104 LocalClient* cl = instance->client.get(); 105 106 client_library.local_instances_.insert( 107 std::make_pair(platform->id(), std::move(instance))); 108 return cl; 109 } 110 111 /* static */ LocalClient* ClientLibrary::LocalClientOrDie() { 112 auto client_status = GetOrCreateLocalClient(); 113 TF_CHECK_OK(client_status.status()); 114 return client_status.ValueOrDie(); 115 } 116 117 /* static */ LocalService* ClientLibrary::GetXlaService( 118 perftools::gputools::Platform* platform) { 119 ClientLibrary& client_library = Singleton(); 120 tensorflow::mutex_lock lock(client_library.service_mutex_); 121 auto it = client_library.local_instances_.find(platform->id()); 122 CHECK(it != client_library.local_instances_.end()); 123 return it->second->service.get(); 124 } 125 126 /* static */ StatusOr<CompileOnlyClient*> 127 ClientLibrary::GetOrCreateCompileOnlyClient( 128 perftools::gputools::Platform* platform) { 129 ClientLibrary& client_library = Singleton(); 130 tensorflow::mutex_lock lock(client_library.service_mutex_); 131 132 if (platform == nullptr) { 133 TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); 134 } 135 136 auto it = client_library.compile_only_instances_.find(platform->id()); 137 if (it != client_library.compile_only_instances_.end()) { 138 return it->second->client.get(); 139 } 140 141 auto instance = MakeUnique<CompileOnlyInstance>(); 142 TF_ASSIGN_OR_RETURN(instance->service, 143 CompileOnlyService::NewService(platform)); 144 instance->client = MakeUnique<CompileOnlyClient>(instance->service.get()); 145 CompileOnlyClient* cl = instance->client.get(); 146 147 client_library.compile_only_instances_.insert( 148 std::make_pair(platform->id(), std::move(instance))); 149 return cl; 150 } 151 152 /* static */ void ClientLibrary::DestroyLocalInstances() { 153 ClientLibrary& client_library = Singleton(); 154 tensorflow::mutex_lock lock(client_library.service_mutex_); 155 156 client_library.local_instances_.clear(); 157 client_library.compile_only_instances_.clear(); 158 } 159 160 } // namespace xla 161