Home | History | Annotate | Download | only in stream_executor
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/stream_executor/plugin_registry.h"
     17 
     18 #include "tensorflow/stream_executor/lib/error.h"
     19 #include "tensorflow/stream_executor/lib/stringprintf.h"
     20 #include "tensorflow/stream_executor/multi_platform_manager.h"
     21 
     22 namespace perftools {
     23 namespace gputools {
     24 
     25 const PluginId kNullPlugin = nullptr;
     26 
     27 // Returns the string representation of the specified PluginKind.
     28 string PluginKindString(PluginKind plugin_kind) {
     29   switch (plugin_kind) {
     30     case PluginKind::kBlas:
     31       return "BLAS";
     32     case PluginKind::kDnn:
     33       return "DNN";
     34     case PluginKind::kFft:
     35       return "FFT";
     36     case PluginKind::kRng:
     37       return "RNG";
     38     case PluginKind::kInvalid:
     39     default:
     40       return "kInvalid";
     41   }
     42 }
     43 
     44 PluginRegistry::DefaultFactories::DefaultFactories() :
     45     blas(kNullPlugin), dnn(kNullPlugin), fft(kNullPlugin), rng(kNullPlugin) { }
     46 
     47 static mutex& GetPluginRegistryMutex() {
     48   static mutex* mu = new mutex;
     49   return *mu;
     50 }
     51 
     52 /* static */ PluginRegistry* PluginRegistry::instance_ = nullptr;
     53 
     54 PluginRegistry::PluginRegistry() {}
     55 
     56 /* static */ PluginRegistry* PluginRegistry::Instance() {
     57   mutex_lock lock{GetPluginRegistryMutex()};
     58   if (instance_ == nullptr) {
     59     instance_ = new PluginRegistry();
     60   }
     61   return instance_;
     62 }
     63 
     64 void PluginRegistry::MapPlatformKindToId(PlatformKind platform_kind,
     65                                          Platform::Id platform_id) {
     66   platform_id_by_kind_[platform_kind] = platform_id;
     67 }
     68 
     69 template <typename FACTORY_TYPE>
     70 port::Status PluginRegistry::RegisterFactoryInternal(
     71     PluginId plugin_id, const string& plugin_name, FACTORY_TYPE factory,
     72     std::map<PluginId, FACTORY_TYPE>* factories) {
     73   mutex_lock lock{GetPluginRegistryMutex()};
     74 
     75   if (factories->find(plugin_id) != factories->end()) {
     76     return port::Status{
     77         port::error::ALREADY_EXISTS,
     78         port::Printf("Attempting to register factory for plugin %s when "
     79                      "one has already been registered",
     80                      plugin_name.c_str())};
     81   }
     82 
     83   (*factories)[plugin_id] = factory;
     84   plugin_names_[plugin_id] = plugin_name;
     85   return port::Status::OK();
     86 }
     87 
     88 template <typename FACTORY_TYPE>
     89 port::StatusOr<FACTORY_TYPE> PluginRegistry::GetFactoryInternal(
     90     PluginId plugin_id, const std::map<PluginId, FACTORY_TYPE>& factories,
     91     const std::map<PluginId, FACTORY_TYPE>& generic_factories) const {
     92   auto iter = factories.find(plugin_id);
     93   if (iter == factories.end()) {
     94     iter = generic_factories.find(plugin_id);
     95     if (iter == generic_factories.end()) {
     96       return port::Status{
     97           port::error::NOT_FOUND,
     98           port::Printf("Plugin ID %p not registered.", plugin_id)};
     99     }
    100   }
    101 
    102   return iter->second;
    103 }
    104 
    105 bool PluginRegistry::SetDefaultFactory(Platform::Id platform_id,
    106                                        PluginKind plugin_kind,
    107                                        PluginId plugin_id) {
    108   if (!HasFactory(platform_id, plugin_kind, plugin_id)) {
    109     port::StatusOr<Platform*> status =
    110         MultiPlatformManager::PlatformWithId(platform_id);
    111     string platform_name = "<unregistered platform>";
    112     if (status.ok()) {
    113       platform_name = status.ValueOrDie()->Name();
    114     }
    115 
    116     LOG(ERROR) << "A factory must be registered for a platform before being "
    117                << "set as default! "
    118                << "Platform name: " << platform_name
    119                << ", PluginKind: " << PluginKindString(plugin_kind)
    120                << ", PluginId: " << plugin_id;
    121     return false;
    122   }
    123 
    124   switch (plugin_kind) {
    125     case PluginKind::kBlas:
    126       default_factories_[platform_id].blas = plugin_id;
    127       break;
    128     case PluginKind::kDnn:
    129       default_factories_[platform_id].dnn = plugin_id;
    130       break;
    131     case PluginKind::kFft:
    132       default_factories_[platform_id].fft = plugin_id;
    133       break;
    134     case PluginKind::kRng:
    135       default_factories_[platform_id].rng = plugin_id;
    136       break;
    137     default:
    138       LOG(ERROR) << "Invalid plugin kind specified: "
    139                  << static_cast<int>(plugin_kind);
    140       return false;
    141   }
    142 
    143   return true;
    144 }
    145 
    146 bool PluginRegistry::HasFactory(const PluginFactories& factories,
    147                                 PluginKind plugin_kind,
    148                                 PluginId plugin_id) const {
    149   switch (plugin_kind) {
    150     case PluginKind::kBlas:
    151       return factories.blas.find(plugin_id) != factories.blas.end();
    152     case PluginKind::kDnn:
    153       return factories.dnn.find(plugin_id) != factories.dnn.end();
    154     case PluginKind::kFft:
    155       return factories.fft.find(plugin_id) != factories.fft.end();
    156     case PluginKind::kRng:
    157       return factories.rng.find(plugin_id) != factories.rng.end();
    158     default:
    159       LOG(ERROR) << "Invalid plugin kind specified: "
    160                  << PluginKindString(plugin_kind);
    161       return false;
    162   }
    163 }
    164 
    165 bool PluginRegistry::HasFactory(Platform::Id platform_id,
    166                                 PluginKind plugin_kind,
    167                                 PluginId plugin_id) const {
    168   auto iter = factories_.find(platform_id);
    169   if (iter != factories_.end()) {
    170     if (HasFactory(iter->second, plugin_kind, plugin_id)) {
    171       return true;
    172     }
    173   }
    174 
    175   return HasFactory(generic_factories_, plugin_kind, plugin_id);
    176 }
    177 
// Explicit instantiations to support types exposed in user/public API.
//
// EMIT_PLUGIN_SPECIALIZATIONS(FACTORY_TYPE, FACTORY_VAR, PLUGIN_STRING) emits,
// for one factory type:
//   * explicit instantiations of GetFactoryInternal and
//     RegisterFactoryInternal for PluginRegistry::FACTORY_TYPE;
//   * RegisterFactory: registers into factories_[platform_id].FACTORY_VAR;
//   * RegisterFactoryForAllPlatforms: registers into
//     generic_factories_.FACTORY_VAR;
//   * GetFactory(Platform::Id, PluginId): when plugin_id is
//     PluginConfig::kDefault, substitutes the platform's default from
//     default_factories_ and returns FAILED_PRECONDITION if that default is
//     still kNullPlugin; then delegates to GetFactoryInternal;
//   * GetFactory(PlatformKind, PluginId): translates the kind through
//     platform_id_by_kind_ (FAILED_PRECONDITION if the kind is unmapped) and
//     forwards to the Platform::Id overload.
//
// Arguments:
//   FACTORY_TYPE  - name of the factory type nested in PluginRegistry.
//   FACTORY_VAR   - name of the matching member (blas, dnn, fft, rng) in the
//                   per-platform, generic and default factory structs.
//   PLUGIN_STRING - string literal spliced into the error/log messages.
//
// NOTE(review): the generated functions index default_factories_, factories_
// and plugin_names_ with operator[] and do not themselves take the registry
// mutex; if these members are std::map-like, a lookup for an unknown key
// creates an entry. Confirm against the header before relying on these
// functions being read-only or thread-safe.
#define EMIT_PLUGIN_SPECIALIZATIONS(FACTORY_TYPE, FACTORY_VAR, PLUGIN_STRING) \
  template port::StatusOr<PluginRegistry::FACTORY_TYPE>                       \
  PluginRegistry::GetFactoryInternal<PluginRegistry::FACTORY_TYPE>(           \
      PluginId plugin_id,                                                     \
      const std::map<PluginId, PluginRegistry::FACTORY_TYPE>& factories,      \
      const std::map<PluginId, PluginRegistry::FACTORY_TYPE>&                 \
          generic_factories) const;                                           \
                                                                              \
  template port::Status                                                       \
  PluginRegistry::RegisterFactoryInternal<PluginRegistry::FACTORY_TYPE>(      \
      PluginId plugin_id, const string& plugin_name,                          \
      PluginRegistry::FACTORY_TYPE factory,                                   \
      std::map<PluginId, PluginRegistry::FACTORY_TYPE>* factories);           \
                                                                              \
  template <>                                                                 \
  port::Status PluginRegistry::RegisterFactory<PluginRegistry::FACTORY_TYPE>( \
      Platform::Id platform_id, PluginId plugin_id, const string& name,       \
      PluginRegistry::FACTORY_TYPE factory) {                                 \
    return RegisterFactoryInternal(plugin_id, name, factory,                  \
                                   &factories_[platform_id].FACTORY_VAR);     \
  }                                                                           \
                                                                              \
  template <>                                                                 \
  port::Status PluginRegistry::RegisterFactoryForAllPlatforms<                \
      PluginRegistry::FACTORY_TYPE>(PluginId plugin_id, const string& name,   \
                                    PluginRegistry::FACTORY_TYPE factory) {   \
    return RegisterFactoryInternal(plugin_id, name, factory,                  \
                                   &generic_factories_.FACTORY_VAR);          \
  }                                                                           \
                                                                              \
  template <>                                                                 \
  port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory(    \
      Platform::Id platform_id, PluginId plugin_id) {                         \
    if (plugin_id == PluginConfig::kDefault) {                                \
      plugin_id = default_factories_[platform_id].FACTORY_VAR;                \
                                                                              \
      if (plugin_id == kNullPlugin) {                                         \
        return port::Status{port::error::FAILED_PRECONDITION,                 \
                            "No suitable " PLUGIN_STRING                      \
                            " plugin registered. Have you linked in a "       \
                            PLUGIN_STRING "-providing plugin?"};              \
      } else {                                                                \
        VLOG(2) << "Selecting default " PLUGIN_STRING " plugin, "             \
                << plugin_names_[plugin_id];                                  \
      }                                                                       \
    }                                                                         \
    return GetFactoryInternal(plugin_id, factories_[platform_id].FACTORY_VAR, \
                              generic_factories_.FACTORY_VAR);                \
  }                                                                           \
                                                                              \
  /* TODO(b/22689637): Also temporary WRT MultiPlatformManager */             \
  template <>                                                                 \
  port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory(    \
      PlatformKind platform_kind, PluginId plugin_id) {                       \
    auto iter = platform_id_by_kind_.find(platform_kind);                     \
    if (iter == platform_id_by_kind_.end()) {                                 \
      return port::Status{port::error::FAILED_PRECONDITION,                   \
                          port::Printf("Platform kind %d not registered.",    \
                                       static_cast<int>(platform_kind))};     \
    }                                                                         \
    return GetFactory<PluginRegistry::FACTORY_TYPE>(iter->second, plugin_id); \
  }

// Generate the instantiations/specializations for each supported plugin kind.
EMIT_PLUGIN_SPECIALIZATIONS(BlasFactory, blas, "BLAS");
EMIT_PLUGIN_SPECIALIZATIONS(DnnFactory, dnn, "DNN");
EMIT_PLUGIN_SPECIALIZATIONS(FftFactory, fft, "FFT");
EMIT_PLUGIN_SPECIALIZATIONS(RngFactory, rng, "RNG");
    246 
    247 }  // namespace gputools
    248 }  // namespace perftools
    249