Home | History | Annotate | Download | only in stream_executor
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_STREAM_EXECUTOR_PLUGIN_H_
     17 #define TENSORFLOW_STREAM_EXECUTOR_PLUGIN_H_
     18 
     19 namespace perftools {
     20 namespace gputools {
     21 
     22 // A plugin ID is a unique identifier for each registered plugin type.
     23 typedef void* PluginId;
     24 
     25 // Helper macro to define a plugin ID. To be used only inside plugin
     26 // implementation files. Works by "reserving" an address/value (guaranteed to be
     27 // unique) inside a process space.
     28 #define PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(ID_VAR_NAME) \
     29   namespace {                                         \
     30   int plugin_id_value;                                \
     31   }                                                   \
     32   const PluginId ID_VAR_NAME = &plugin_id_value;
     33 
     34 // kNullPlugin denotes an invalid plugin identifier.
     35 extern const PluginId kNullPlugin;
     36 
     37 // Enumeration to list the supported types of plugins / support libraries.
     38 enum class PluginKind {
     39   kInvalid,
     40   kBlas,
     41   kDnn,
     42   kFft,
     43   kRng,
     44 };
     45 
     46 // A PluginConfig describes the set of plugins to be used by a StreamExecutor
     47 // instance. Each plugin is defined by an arbitrary identifier, usually best set
     48 // to the address static member in the implementation (to avoid conflicts).
     49 //
     50 // A PluginConfig may be passed to the StreamExecutor constructor - the plugins
     51 // described therein will be used to provide BLAS, DNN, FFT, and RNG
     52 // functionality. Platform-appropriate defaults will be used for any un-set
     53 // libraries. If a platform does not support a specified plugin (ex. cuBLAS on
     54 // an OpenCL executor), then an error will be logged and no plugin operations
     55 // will succeed.
     56 //
     57 // The StreamExecutor BUILD target does not link ANY plugin libraries - even
     58 // common host fallbacks! Any plugins must be explicitly linked by dependent
     59 // targets. See the cuda, opencl and host BUILD files for implemented plugin
     60 // support (search for "plugin").
     61 class PluginConfig {
     62  public:
     63   // Value specifying the platform's default option for that plugin.
     64   static const PluginId kDefault;
     65 
     66   // Initializes all members to the default options.
     67   PluginConfig();
     68 
     69   bool operator==(const PluginConfig& rhs) const;
     70 
     71   // Sets the appropriate library kind to that passed in.
     72   PluginConfig& SetBlas(PluginId blas);
     73   PluginConfig& SetDnn(PluginId dnn);
     74   PluginConfig& SetFft(PluginId fft);
     75   PluginConfig& SetRng(PluginId rng);
     76 
     77   PluginId blas() const { return blas_; }
     78   PluginId dnn() const { return dnn_; }
     79   PluginId fft() const { return fft_; }
     80   PluginId rng() const { return rng_; }
     81 
     82  private:
     83   PluginId blas_, dnn_, fft_, rng_;
     84 };
     85 
     86 }  // namespace gputools
     87 }  // namespace perftools
     88 
     89 #endif  // TENSORFLOW_STREAM_EXECUTOR_PLUGIN_H_
     90