1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/stream_executor/plugin_registry.h" 17 18 #include "tensorflow/stream_executor/lib/error.h" 19 #include "tensorflow/stream_executor/lib/stringprintf.h" 20 #include "tensorflow/stream_executor/multi_platform_manager.h" 21 22 namespace perftools { 23 namespace gputools { 24 25 const PluginId kNullPlugin = nullptr; 26 27 // Returns the string representation of the specified PluginKind. 28 string PluginKindString(PluginKind plugin_kind) { 29 switch (plugin_kind) { 30 case PluginKind::kBlas: 31 return "BLAS"; 32 case PluginKind::kDnn: 33 return "DNN"; 34 case PluginKind::kFft: 35 return "FFT"; 36 case PluginKind::kRng: 37 return "RNG"; 38 case PluginKind::kInvalid: 39 default: 40 return "kInvalid"; 41 } 42 } 43 44 PluginRegistry::DefaultFactories::DefaultFactories() : 45 blas(kNullPlugin), dnn(kNullPlugin), fft(kNullPlugin), rng(kNullPlugin) { } 46 47 static mutex& GetPluginRegistryMutex() { 48 static mutex* mu = new mutex; 49 return *mu; 50 } 51 52 /* static */ PluginRegistry* PluginRegistry::instance_ = nullptr; 53 54 PluginRegistry::PluginRegistry() {} 55 56 /* static */ PluginRegistry* PluginRegistry::Instance() { 57 mutex_lock lock{GetPluginRegistryMutex()}; 58 if (instance_ == nullptr) { 59 instance_ = new PluginRegistry(); 60 } 61 return instance_; 62 } 63 64 void PluginRegistry::MapPlatformKindToId(PlatformKind platform_kind, 65 Platform::Id platform_id) { 66 platform_id_by_kind_[platform_kind] = platform_id; 67 } 68 69 template <typename FACTORY_TYPE> 70 port::Status PluginRegistry::RegisterFactoryInternal( 71 PluginId plugin_id, const string& plugin_name, FACTORY_TYPE factory, 72 std::map<PluginId, FACTORY_TYPE>* factories) { 73 mutex_lock lock{GetPluginRegistryMutex()}; 74 75 if (factories->find(plugin_id) != factories->end()) { 76 return port::Status{ 77 port::error::ALREADY_EXISTS, 78 port::Printf("Attempting to register factory for plugin %s when " 79 "one has already been registered", 80 plugin_name.c_str())}; 81 } 82 83 (*factories)[plugin_id] = factory; 84 plugin_names_[plugin_id] = plugin_name; 85 return port::Status::OK(); 86 } 87 88 template <typename FACTORY_TYPE> 89 port::StatusOr<FACTORY_TYPE> PluginRegistry::GetFactoryInternal( 90 PluginId plugin_id, const std::map<PluginId, FACTORY_TYPE>& factories, 91 const std::map<PluginId, FACTORY_TYPE>& generic_factories) const { 92 auto iter = factories.find(plugin_id); 93 if (iter == factories.end()) { 94 iter = generic_factories.find(plugin_id); 95 if (iter == generic_factories.end()) { 96 return port::Status{ 97 port::error::NOT_FOUND, 98 port::Printf("Plugin ID %p not registered.", plugin_id)}; 99 } 100 } 101 102 return iter->second; 103 } 104 105 bool PluginRegistry::SetDefaultFactory(Platform::Id platform_id, 106 PluginKind plugin_kind, 107 PluginId plugin_id) { 108 if (!HasFactory(platform_id, plugin_kind, plugin_id)) { 109 port::StatusOr<Platform*> status = 110 MultiPlatformManager::PlatformWithId(platform_id); 111 string platform_name = "<unregistered platform>"; 112 if (status.ok()) { 113 platform_name = status.ValueOrDie()->Name(); 114 } 115 116 LOG(ERROR) << "A factory must be registered for a platform before being " 117 << "set as default! " 118 << "Platform name: " << platform_name 119 << ", PluginKind: " << PluginKindString(plugin_kind) 120 << ", PluginId: " << plugin_id; 121 return false; 122 } 123 124 switch (plugin_kind) { 125 case PluginKind::kBlas: 126 default_factories_[platform_id].blas = plugin_id; 127 break; 128 case PluginKind::kDnn: 129 default_factories_[platform_id].dnn = plugin_id; 130 break; 131 case PluginKind::kFft: 132 default_factories_[platform_id].fft = plugin_id; 133 break; 134 case PluginKind::kRng: 135 default_factories_[platform_id].rng = plugin_id; 136 break; 137 default: 138 LOG(ERROR) << "Invalid plugin kind specified: " 139 << static_cast<int>(plugin_kind); 140 return false; 141 } 142 143 return true; 144 } 145 146 bool PluginRegistry::HasFactory(const PluginFactories& factories, 147 PluginKind plugin_kind, 148 PluginId plugin_id) const { 149 switch (plugin_kind) { 150 case PluginKind::kBlas: 151 return factories.blas.find(plugin_id) != factories.blas.end(); 152 case PluginKind::kDnn: 153 return factories.dnn.find(plugin_id) != factories.dnn.end(); 154 case PluginKind::kFft: 155 return factories.fft.find(plugin_id) != factories.fft.end(); 156 case PluginKind::kRng: 157 return factories.rng.find(plugin_id) != factories.rng.end(); 158 default: 159 LOG(ERROR) << "Invalid plugin kind specified: " 160 << PluginKindString(plugin_kind); 161 return false; 162 } 163 } 164 165 bool PluginRegistry::HasFactory(Platform::Id platform_id, 166 PluginKind plugin_kind, 167 PluginId plugin_id) const { 168 auto iter = factories_.find(platform_id); 169 if (iter != factories_.end()) { 170 if (HasFactory(iter->second, plugin_kind, plugin_id)) { 171 return true; 172 } 173 } 174 175 return HasFactory(generic_factories_, plugin_kind, plugin_id); 176 } 177 178 // Explicit instantiations to support types exposed in user/public API. 179 #define EMIT_PLUGIN_SPECIALIZATIONS(FACTORY_TYPE, FACTORY_VAR, PLUGIN_STRING) \ 180 template port::StatusOr<PluginRegistry::FACTORY_TYPE> \ 181 PluginRegistry::GetFactoryInternal<PluginRegistry::FACTORY_TYPE>( \ 182 PluginId plugin_id, \ 183 const std::map<PluginId, PluginRegistry::FACTORY_TYPE>& factories, \ 184 const std::map<PluginId, PluginRegistry::FACTORY_TYPE>& \ 185 generic_factories) const; \ 186 \ 187 template port::Status \ 188 PluginRegistry::RegisterFactoryInternal<PluginRegistry::FACTORY_TYPE>( \ 189 PluginId plugin_id, const string& plugin_name, \ 190 PluginRegistry::FACTORY_TYPE factory, \ 191 std::map<PluginId, PluginRegistry::FACTORY_TYPE>* factories); \ 192 \ 193 template <> \ 194 port::Status PluginRegistry::RegisterFactory<PluginRegistry::FACTORY_TYPE>( \ 195 Platform::Id platform_id, PluginId plugin_id, const string& name, \ 196 PluginRegistry::FACTORY_TYPE factory) { \ 197 return RegisterFactoryInternal(plugin_id, name, factory, \ 198 &factories_[platform_id].FACTORY_VAR); \ 199 } \ 200 \ 201 template <> \ 202 port::Status PluginRegistry::RegisterFactoryForAllPlatforms< \ 203 PluginRegistry::FACTORY_TYPE>(PluginId plugin_id, const string& name, \ 204 PluginRegistry::FACTORY_TYPE factory) { \ 205 return RegisterFactoryInternal(plugin_id, name, factory, \ 206 &generic_factories_.FACTORY_VAR); \ 207 } \ 208 \ 209 template <> \ 210 port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory( \ 211 Platform::Id platform_id, PluginId plugin_id) { \ 212 if (plugin_id == PluginConfig::kDefault) { \ 213 plugin_id = default_factories_[platform_id].FACTORY_VAR; \ 214 \ 215 if (plugin_id == kNullPlugin) { \ 216 return port::Status{port::error::FAILED_PRECONDITION, \ 217 "No suitable " PLUGIN_STRING \ 218 " plugin registered. Have you linked in a " \ 219 PLUGIN_STRING "-providing plugin?"}; \ 220 } else { \ 221 VLOG(2) << "Selecting default " PLUGIN_STRING " plugin, " \ 222 << plugin_names_[plugin_id]; \ 223 } \ 224 } \ 225 return GetFactoryInternal(plugin_id, factories_[platform_id].FACTORY_VAR, \ 226 generic_factories_.FACTORY_VAR); \ 227 } \ 228 \ 229 /* TODO(b/22689637): Also temporary WRT MultiPlatformManager */ \ 230 template <> \ 231 port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory( \ 232 PlatformKind platform_kind, PluginId plugin_id) { \ 233 auto iter = platform_id_by_kind_.find(platform_kind); \ 234 if (iter == platform_id_by_kind_.end()) { \ 235 return port::Status{port::error::FAILED_PRECONDITION, \ 236 port::Printf("Platform kind %d not registered.", \ 237 static_cast<int>(platform_kind))}; \ 238 } \ 239 return GetFactory<PluginRegistry::FACTORY_TYPE>(iter->second, plugin_id); \ 240 } 241 242 EMIT_PLUGIN_SPECIALIZATIONS(BlasFactory, blas, "BLAS"); 243 EMIT_PLUGIN_SPECIALIZATIONS(DnnFactory, dnn, "DNN"); 244 EMIT_PLUGIN_SPECIALIZATIONS(FftFactory, fft, "FFT"); 245 EMIT_PLUGIN_SPECIALIZATIONS(RngFactory, rng, "RNG"); 246 247 } // namespace gputools 248 } // namespace perftools 249