1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/client/client_library.h" 17 18 #include "absl/memory/memory.h" 19 #include "tensorflow/compiler/xla/service/backend.h" 20 #include "tensorflow/compiler/xla/service/platform_util.h" 21 #include "tensorflow/compiler/xla/status_macros.h" 22 #include "tensorflow/compiler/xla/util.h" 23 #include "tensorflow/core/platform/logging.h" 24 25 namespace xla { 26 27 LocalClientOptions::LocalClientOptions( 28 se::Platform* platform, int number_of_replicas, 29 int intra_op_parallelism_threads, 30 const absl::optional<std::set<int>>& allowed_devices) 31 : platform_(platform), 32 number_of_replicas_(number_of_replicas), 33 intra_op_parallelism_threads_(intra_op_parallelism_threads), 34 allowed_devices_(allowed_devices) {} 35 36 LocalClientOptions& LocalClientOptions::set_platform(se::Platform* platform) { 37 platform_ = platform; 38 return *this; 39 } 40 41 se::Platform* LocalClientOptions::platform() const { return platform_; } 42 43 LocalClientOptions& LocalClientOptions::set_number_of_replicas( 44 int number_of_replicas) { 45 number_of_replicas_ = number_of_replicas; 46 return *this; 47 } 48 49 int LocalClientOptions::number_of_replicas() const { 50 return number_of_replicas_; 51 } 52 53 LocalClientOptions& LocalClientOptions::set_intra_op_parallelism_threads( 54 int num_threads) { 55 intra_op_parallelism_threads_ = num_threads; 56 return *this; 57 } 58 59 int LocalClientOptions::intra_op_parallelism_threads() const { 60 return intra_op_parallelism_threads_; 61 } 62 63 LocalClientOptions& LocalClientOptions::set_allowed_devices( 64 const absl::optional<std::set<int>>& allowed_devices) { 65 allowed_devices_ = allowed_devices; 66 return *this; 67 } 68 69 const absl::optional<std::set<int>>& LocalClientOptions::allowed_devices() 70 const { 71 return allowed_devices_; 72 } 73 74 /* static */ ClientLibrary& ClientLibrary::Singleton() { 75 static ClientLibrary* c = new ClientLibrary; 76 return *c; 77 } 78 79 ClientLibrary::ClientLibrary() = default; 80 ClientLibrary::~ClientLibrary() = default; 81 82 /* static */ StatusOr<LocalClient*> ClientLibrary::GetOrCreateLocalClient( 83 se::Platform* platform, const absl::optional<std::set<int>>& device_set) { 84 LocalClientOptions default_options; 85 default_options.set_platform(platform); 86 default_options.set_allowed_devices(device_set); 87 return GetOrCreateLocalClient(default_options); 88 } 89 90 /* static */ StatusOr<LocalClient*> ClientLibrary::GetOrCreateLocalClient( 91 const LocalClientOptions& options) { 92 se::Platform* platform = options.platform(); 93 int replica_count = options.number_of_replicas(); 94 ClientLibrary& client_library = Singleton(); 95 tensorflow::mutex_lock lock(client_library.service_mutex_); 96 97 if (platform == nullptr) { 98 TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); 99 } 100 101 auto it = client_library.local_instances_.find(platform->id()); 102 if (it != client_library.local_instances_.end()) { 103 return it->second->client.get(); 104 } 105 106 ServiceOptions service_options; 107 service_options.set_platform(platform); 108 service_options.set_number_of_replicas(replica_count); 109 service_options.set_intra_op_parallelism_threads( 110 options.intra_op_parallelism_threads()); 111 service_options.set_allowed_devices(options.allowed_devices()); 112 auto instance = absl::make_unique<LocalInstance>(); 113 TF_ASSIGN_OR_RETURN(instance->service, 114 LocalService::NewService(service_options)); 115 instance->client = absl::make_unique<LocalClient>(instance->service.get()); 116 LocalClient* cl = instance->client.get(); 117 118 client_library.local_instances_.insert( 119 std::make_pair(platform->id(), std::move(instance))); 120 return cl; 121 } 122 123 /* static */ LocalClient* ClientLibrary::LocalClientOrDie() { 124 auto client_status = GetOrCreateLocalClient(); 125 TF_CHECK_OK(client_status.status()); 126 return client_status.ValueOrDie(); 127 } 128 129 /* static */ LocalService* ClientLibrary::GetXlaService( 130 se::Platform* platform) { 131 ClientLibrary& client_library = Singleton(); 132 tensorflow::mutex_lock lock(client_library.service_mutex_); 133 auto it = client_library.local_instances_.find(platform->id()); 134 CHECK(it != client_library.local_instances_.end()); 135 return it->second->service.get(); 136 } 137 138 /* static */ StatusOr<CompileOnlyClient*> 139 ClientLibrary::GetOrCreateCompileOnlyClient(se::Platform* platform) { 140 ClientLibrary& client_library = Singleton(); 141 tensorflow::mutex_lock lock(client_library.service_mutex_); 142 143 if (platform == nullptr) { 144 TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); 145 } 146 147 auto it = client_library.compile_only_instances_.find(platform->id()); 148 if (it != client_library.compile_only_instances_.end()) { 149 return it->second->client.get(); 150 } 151 152 auto instance = absl::make_unique<CompileOnlyInstance>(); 153 TF_ASSIGN_OR_RETURN(instance->service, 154 CompileOnlyService::NewService(platform)); 155 instance->client = 156 absl::make_unique<CompileOnlyClient>(instance->service.get()); 157 CompileOnlyClient* cl = instance->client.get(); 158 159 client_library.compile_only_instances_.insert( 160 std::make_pair(platform->id(), std::move(instance))); 161 return cl; 162 } 163 164 /* static */ void ClientLibrary::DestroyLocalInstances() { 165 ClientLibrary& client_library = Singleton(); 166 tensorflow::mutex_lock lock(client_library.service_mutex_); 167 168 client_library.local_instances_.clear(); 169 client_library.compile_only_instances_.clear(); 170 } 171 172 } // namespace xla 173