Home | History | Annotate | Download | only in stream_executor
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/stream_executor/plugin_registry.h"
     17 
     18 #include "tensorflow/stream_executor/lib/error.h"
     19 #include "tensorflow/stream_executor/lib/stringprintf.h"
     20 #include "tensorflow/stream_executor/multi_platform_manager.h"
     21 
     22 namespace stream_executor {
     23 
     24 const PluginId kNullPlugin = nullptr;
     25 
     26 // Returns the string representation of the specified PluginKind.
     27 string PluginKindString(PluginKind plugin_kind) {
     28   switch (plugin_kind) {
     29     case PluginKind::kBlas:
     30       return "BLAS";
     31     case PluginKind::kDnn:
     32       return "DNN";
     33     case PluginKind::kFft:
     34       return "FFT";
     35     case PluginKind::kRng:
     36       return "RNG";
     37     case PluginKind::kInvalid:
     38     default:
     39       return "kInvalid";
     40   }
     41 }
     42 
     43 PluginRegistry::DefaultFactories::DefaultFactories() :
     44     blas(kNullPlugin), dnn(kNullPlugin), fft(kNullPlugin), rng(kNullPlugin) { }
     45 
     46 static mutex& GetPluginRegistryMutex() {
     47   static mutex* mu = new mutex;
     48   return *mu;
     49 }
     50 
     51 /* static */ PluginRegistry* PluginRegistry::instance_ = nullptr;
     52 
     53 PluginRegistry::PluginRegistry() {}
     54 
     55 /* static */ PluginRegistry* PluginRegistry::Instance() {
     56   mutex_lock lock{GetPluginRegistryMutex()};
     57   if (instance_ == nullptr) {
     58     instance_ = new PluginRegistry();
     59   }
     60   return instance_;
     61 }
     62 
     63 void PluginRegistry::MapPlatformKindToId(PlatformKind platform_kind,
     64                                          Platform::Id platform_id) {
     65   platform_id_by_kind_[platform_kind] = platform_id;
     66 }
     67 
     68 template <typename FACTORY_TYPE>
     69 port::Status PluginRegistry::RegisterFactoryInternal(
     70     PluginId plugin_id, const string& plugin_name, FACTORY_TYPE factory,
     71     std::map<PluginId, FACTORY_TYPE>* factories) {
     72   mutex_lock lock{GetPluginRegistryMutex()};
     73 
     74   if (factories->find(plugin_id) != factories->end()) {
     75     return port::Status(
     76         port::error::ALREADY_EXISTS,
     77         port::Printf("Attempting to register factory for plugin %s when "
     78                      "one has already been registered",
     79                      plugin_name.c_str()));
     80   }
     81 
     82   (*factories)[plugin_id] = factory;
     83   plugin_names_[plugin_id] = plugin_name;
     84   return port::Status::OK();
     85 }
     86 
     87 template <typename FACTORY_TYPE>
     88 port::StatusOr<FACTORY_TYPE> PluginRegistry::GetFactoryInternal(
     89     PluginId plugin_id, const std::map<PluginId, FACTORY_TYPE>& factories,
     90     const std::map<PluginId, FACTORY_TYPE>& generic_factories) const {
     91   auto iter = factories.find(plugin_id);
     92   if (iter == factories.end()) {
     93     iter = generic_factories.find(plugin_id);
     94     if (iter == generic_factories.end()) {
     95       return port::Status(
     96           port::error::NOT_FOUND,
     97           port::Printf("Plugin ID %p not registered.", plugin_id));
     98     }
     99   }
    100 
    101   return iter->second;
    102 }
    103 
    104 bool PluginRegistry::SetDefaultFactory(Platform::Id platform_id,
    105                                        PluginKind plugin_kind,
    106                                        PluginId plugin_id) {
    107   if (!HasFactory(platform_id, plugin_kind, plugin_id)) {
    108     port::StatusOr<Platform*> status =
    109         MultiPlatformManager::PlatformWithId(platform_id);
    110     string platform_name = "<unregistered platform>";
    111     if (status.ok()) {
    112       platform_name = status.ValueOrDie()->Name();
    113     }
    114 
    115     LOG(ERROR) << "A factory must be registered for a platform before being "
    116                << "set as default! "
    117                << "Platform name: " << platform_name
    118                << ", PluginKind: " << PluginKindString(plugin_kind)
    119                << ", PluginId: " << plugin_id;
    120     return false;
    121   }
    122 
    123   switch (plugin_kind) {
    124     case PluginKind::kBlas:
    125       default_factories_[platform_id].blas = plugin_id;
    126       break;
    127     case PluginKind::kDnn:
    128       default_factories_[platform_id].dnn = plugin_id;
    129       break;
    130     case PluginKind::kFft:
    131       default_factories_[platform_id].fft = plugin_id;
    132       break;
    133     case PluginKind::kRng:
    134       default_factories_[platform_id].rng = plugin_id;
    135       break;
    136     default:
    137       LOG(ERROR) << "Invalid plugin kind specified: "
    138                  << static_cast<int>(plugin_kind);
    139       return false;
    140   }
    141 
    142   return true;
    143 }
    144 
    145 bool PluginRegistry::HasFactory(const PluginFactories& factories,
    146                                 PluginKind plugin_kind,
    147                                 PluginId plugin_id) const {
    148   switch (plugin_kind) {
    149     case PluginKind::kBlas:
    150       return factories.blas.find(plugin_id) != factories.blas.end();
    151     case PluginKind::kDnn:
    152       return factories.dnn.find(plugin_id) != factories.dnn.end();
    153     case PluginKind::kFft:
    154       return factories.fft.find(plugin_id) != factories.fft.end();
    155     case PluginKind::kRng:
    156       return factories.rng.find(plugin_id) != factories.rng.end();
    157     default:
    158       LOG(ERROR) << "Invalid plugin kind specified: "
    159                  << PluginKindString(plugin_kind);
    160       return false;
    161   }
    162 }
    163 
    164 bool PluginRegistry::HasFactory(Platform::Id platform_id,
    165                                 PluginKind plugin_kind,
    166                                 PluginId plugin_id) const {
    167   auto iter = factories_.find(platform_id);
    168   if (iter != factories_.end()) {
    169     if (HasFactory(iter->second, plugin_kind, plugin_id)) {
    170       return true;
    171     }
    172   }
    173 
    174   return HasFactory(generic_factories_, plugin_kind, plugin_id);
    175 }
    176 
    177 // Explicit instantiations to support types exposed in user/public API.
    178 #define EMIT_PLUGIN_SPECIALIZATIONS(FACTORY_TYPE, FACTORY_VAR, PLUGIN_STRING) \
    179   template port::StatusOr<PluginRegistry::FACTORY_TYPE>                       \
    180   PluginRegistry::GetFactoryInternal<PluginRegistry::FACTORY_TYPE>(           \
    181       PluginId plugin_id,                                                     \
    182       const std::map<PluginId, PluginRegistry::FACTORY_TYPE>& factories,      \
    183       const std::map<PluginId, PluginRegistry::FACTORY_TYPE>&                 \
    184           generic_factories) const;                                           \
    185                                                                               \
    186   template port::Status                                                       \
    187   PluginRegistry::RegisterFactoryInternal<PluginRegistry::FACTORY_TYPE>(      \
    188       PluginId plugin_id, const string& plugin_name,                          \
    189       PluginRegistry::FACTORY_TYPE factory,                                   \
    190       std::map<PluginId, PluginRegistry::FACTORY_TYPE>* factories);           \
    191                                                                               \
    192   template <>                                                                 \
    193   port::Status PluginRegistry::RegisterFactory<PluginRegistry::FACTORY_TYPE>( \
    194       Platform::Id platform_id, PluginId plugin_id, const string& name,       \
    195       PluginRegistry::FACTORY_TYPE factory) {                                 \
    196     return RegisterFactoryInternal(plugin_id, name, factory,                  \
    197                                    &factories_[platform_id].FACTORY_VAR);     \
    198   }                                                                           \
    199                                                                               \
    200   template <>                                                                 \
    201   port::Status PluginRegistry::RegisterFactoryForAllPlatforms<                \
    202       PluginRegistry::FACTORY_TYPE>(PluginId plugin_id, const string& name,   \
    203                                     PluginRegistry::FACTORY_TYPE factory) {   \
    204     return RegisterFactoryInternal(plugin_id, name, factory,                  \
    205                                    &generic_factories_.FACTORY_VAR);          \
    206   }                                                                           \
    207                                                                               \
    208   template <>                                                                 \
    209   port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory(    \
    210       Platform::Id platform_id, PluginId plugin_id) {                         \
    211     if (plugin_id == PluginConfig::kDefault) {                                \
    212       plugin_id = default_factories_[platform_id].FACTORY_VAR;                \
    213                                                                               \
    214       if (plugin_id == kNullPlugin) {                                         \
    215         return port::Status(                                                  \
    216             port::error::FAILED_PRECONDITION,                                 \
    217             "No suitable " PLUGIN_STRING                                      \
    218             " plugin registered. Have you linked in a " PLUGIN_STRING         \
    219             "-providing plugin?");                                            \
    220       } else {                                                                \
    221         VLOG(2) << "Selecting default " PLUGIN_STRING " plugin, "             \
    222                 << plugin_names_[plugin_id];                                  \
    223       }                                                                       \
    224     }                                                                         \
    225     return GetFactoryInternal(plugin_id, factories_[platform_id].FACTORY_VAR, \
    226                               generic_factories_.FACTORY_VAR);                \
    227   }                                                                           \
    228                                                                               \
    229   /* TODO(b/22689637): Also temporary WRT MultiPlatformManager */             \
    230   template <>                                                                 \
    231   port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory(    \
    232       PlatformKind platform_kind, PluginId plugin_id) {                       \
    233     auto iter = platform_id_by_kind_.find(platform_kind);                     \
    234     if (iter == platform_id_by_kind_.end()) {                                 \
    235       return port::Status(port::error::FAILED_PRECONDITION,                   \
    236                           port::Printf("Platform kind %d not registered.",    \
    237                                        static_cast<int>(platform_kind)));     \
    238     }                                                                         \
    239     return GetFactory<PluginRegistry::FACTORY_TYPE>(iter->second, plugin_id); \
    240   }
    241 
    242 EMIT_PLUGIN_SPECIALIZATIONS(BlasFactory, blas, "BLAS");
    243 EMIT_PLUGIN_SPECIALIZATIONS(DnnFactory, dnn, "DNN");
    244 EMIT_PLUGIN_SPECIALIZATIONS(FftFactory, fft, "FFT");
    245 EMIT_PLUGIN_SPECIALIZATIONS(RngFactory, rng, "RNG");
    246 
    247 }  // namespace stream_executor
    248