1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_STREAM_EXECUTOR_PLUGIN_REGISTRY_H_ 17 #define TENSORFLOW_STREAM_EXECUTOR_PLUGIN_REGISTRY_H_ 18 19 #include <map> 20 21 #include "tensorflow/stream_executor/blas.h" 22 #include "tensorflow/stream_executor/dnn.h" 23 #include "tensorflow/stream_executor/fft.h" 24 #include "tensorflow/stream_executor/lib/status.h" 25 #include "tensorflow/stream_executor/lib/statusor.h" 26 #include "tensorflow/stream_executor/platform.h" 27 #include "tensorflow/stream_executor/platform/mutex.h" 28 #include "tensorflow/stream_executor/plugin.h" 29 #include "tensorflow/stream_executor/rng.h" 30 31 namespace perftools { 32 namespace gputools { 33 34 namespace internal { 35 class StreamExecutorInterface; 36 } 37 38 // The PluginRegistry is a singleton that maintains the set of registered 39 // "support library" plugins. Currently, there are four kinds of plugins: 40 // BLAS, DNN, FFT, and RNG. Each interface is defined in the corresponding 41 // gpu_{kind}.h header. 42 // 43 // At runtime, a StreamExecutor object will query the singleton registry to 44 // retrieve the plugin kind that StreamExecutor was configured with (refer to 45 // the StreamExecutor and PluginConfig declarations). 46 // 47 // Plugin libraries are best registered using REGISTER_MODULE_INITIALIZER, 48 // but can be registered at any time. When registering a DSO-backed plugin, it 49 // is usually a good idea to load the DSO at registration time, to prevent 50 // late-loading from distorting performance/benchmarks as much as possible. 51 class PluginRegistry { 52 public: 53 typedef blas::BlasSupport* (*BlasFactory)(internal::StreamExecutorInterface*); 54 typedef dnn::DnnSupport* (*DnnFactory)(internal::StreamExecutorInterface*); 55 typedef fft::FftSupport* (*FftFactory)(internal::StreamExecutorInterface*); 56 typedef rng::RngSupport* (*RngFactory)(internal::StreamExecutorInterface*); 57 58 // Gets (and creates, if necessary) the singleton PluginRegistry instance. 59 static PluginRegistry* Instance(); 60 61 // Registers the specified factory with the specified platform. 62 // Returns a non-successful status if the factory has already been registered 63 // with that platform (but execution should be otherwise unaffected). 64 template <typename FactoryT> 65 port::Status RegisterFactory(Platform::Id platform_id, PluginId plugin_id, 66 const string& name, FactoryT factory); 67 68 // Registers the specified factory as usable by _all_ platform types. 69 // Reports errors just as RegisterFactory. 70 template <typename FactoryT> 71 port::Status RegisterFactoryForAllPlatforms(PluginId plugin_id, 72 const string& name, 73 FactoryT factory); 74 75 // TODO(b/22689637): Setter for temporary mapping until all users are using 76 // MultiPlatformManager / PlatformId. 77 void MapPlatformKindToId(PlatformKind platform_kind, 78 Platform::Id platform_id); 79 80 // Potentially sets the plugin identified by plugin_id to be the default 81 // for the specified platform and plugin kind. If this routine is called 82 // multiple types for the same PluginKind, the PluginId given in the last call 83 // will be used. 84 bool SetDefaultFactory(Platform::Id platform_id, PluginKind plugin_kind, 85 PluginId plugin_id); 86 87 // Return true if the factory/id has been registered for the 88 // specified platform and plugin kind and false otherwise. 89 bool HasFactory(Platform::Id platform_id, PluginKind plugin_kind, 90 PluginId plugin) const; 91 92 // Retrieves the factory registered for the specified kind, 93 // or a port::Status on error. 94 template <typename FactoryT> 95 port::StatusOr<FactoryT> GetFactory(Platform::Id platform_id, 96 PluginId plugin_id); 97 98 // TODO(b/22689637): Deprecated/temporary. Will be deleted once all users are 99 // on MultiPlatformManager / PlatformId. 100 template <typename FactoryT> 101 port::StatusOr<FactoryT> GetFactory(PlatformKind platform_kind, 102 PluginId plugin_id); 103 104 private: 105 // Containers for the sets of registered factories, by plugin kind. 106 struct PluginFactories { 107 std::map<PluginId, BlasFactory> blas; 108 std::map<PluginId, DnnFactory> dnn; 109 std::map<PluginId, FftFactory> fft; 110 std::map<PluginId, RngFactory> rng; 111 }; 112 113 // Simple structure to hold the currently configured default plugins (for a 114 // particular Platform). 115 struct DefaultFactories { 116 DefaultFactories(); 117 PluginId blas, dnn, fft, rng; 118 }; 119 120 PluginRegistry(); 121 122 // Actually performs the work of registration. 123 template <typename FactoryT> 124 port::Status RegisterFactoryInternal(PluginId plugin_id, 125 const string& plugin_name, 126 FactoryT factory, 127 std::map<PluginId, FactoryT>* factories); 128 129 // Actually performs the work of factory retrieval. 130 template <typename FactoryT> 131 port::StatusOr<FactoryT> GetFactoryInternal( 132 PluginId plugin_id, const std::map<PluginId, FactoryT>& factories, 133 const std::map<PluginId, FactoryT>& generic_factories) const; 134 135 // Returns true if the specified plugin has been registered with the specified 136 // platform factories. Unlike the other overload of this method, this does 137 // not implicitly examine the default factory lists. 138 bool HasFactory(const PluginFactories& factories, PluginKind plugin_kind, 139 PluginId plugin) const; 140 141 // The singleton itself. 142 static PluginRegistry* instance_; 143 144 // TODO(b/22689637): Temporary mapping until all users are using 145 // MultiPlatformManager / PlatformId. 146 std::map<PlatformKind, Platform::Id> platform_id_by_kind_; 147 148 // The set of registered factories, keyed by platform ID. 149 std::map<Platform::Id, PluginFactories> factories_; 150 151 // Plugins supported for all platform kinds. 152 PluginFactories generic_factories_; 153 154 // The sets of default factories, keyed by platform ID. 155 std::map<Platform::Id, DefaultFactories> default_factories_; 156 157 // Lookup table for plugin names. 158 std::map<PluginId, string> plugin_names_; 159 160 SE_DISALLOW_COPY_AND_ASSIGN(PluginRegistry); 161 }; 162 163 } // namespace gputools 164 } // namespace perftools 165 166 #endif // TENSORFLOW_STREAM_EXECUTOR_PLUGIN_REGISTRY_H_ 167