// tensorflow/compiler/tf2xla/xla_op_registry.h
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_
     17 #define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_
     18 
#include <array>
#include <functional>
#include <memory>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/session_options.h"
     35 
     36 namespace tensorflow {
     37 
// Names of the XLA compilation devices. These are not user-visible, and are
// used internally by the Tensorflow/XLA bridge to perform symbolic execution of
// a Tensorflow graph.

extern const char* const DEVICE_CPU_XLA_JIT;  // "CPU_XLA_JIT"
extern const char* const DEVICE_GPU_XLA_JIT;  // "GPU_XLA_JIT"

// Names of the XLA devices that correspond to the JIT devices above; the
// values are defined in the accompanying .cc file.
extern const char* const DEVICE_XLA_CPU;
extern const char* const DEVICE_XLA_GPU;

// Floating-point DataTypes handled by the bridge.
constexpr std::array<DataType, 4> kFloatTypes = {
    {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}};
// Integral, floating-point, and complex DataTypes handled by the bridge.
constexpr std::array<DataType, 9> kNumericTypes = {
    {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
     DT_COMPLEX64, DT_BFLOAT16}};

// Types supported by the CPU JIT backend, intended as the `supported_types`
// argument of a REGISTER_XLA_BACKEND registration.
constexpr std::array<DataType, 8> kCpuAllTypes = {
    {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
     DT_COMPLEX64, DT_BOOL}};

// Types supported by the GPU JIT backend, intended as the `supported_types`
// argument of a REGISTER_XLA_BACKEND registration.
constexpr std::array<DataType, 8> kGpuAllTypes = {
    {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
     DT_COMPLEX64, DT_BOOL}};
     61 
// Class that manages registrations of operators and devices for the XLA JIT.
// Not thread-safe.
class XlaOpRegistry {
 public:
  // Factory function that constructs the OpKernel used for symbolic execution
  // of a registered operator.
  typedef OpKernel* (*Factory)(OpKernelConstruction*);

  // Describes how to compile operators assigned to a device.
  struct DeviceRegistration {
    // The name of an XLA compilation device to use to compile code.
    string compilation_device_name;

    // Do operators assigned to this device require compilation?
    bool requires_compilation;

    // If !requires_compilation, should we try to JIT operators on this device
    // when XLA JIT compilation is enabled globally via the SessionOptions?
    // (It is still possible to explicitly mark operators to JIT compile, even
    // if enable_jit_by_default is false.)
    bool enable_jit_by_default;

    // Enable compilation of operators that use DT_RESOURCE types?
    bool compile_resource_ops = false;
  };

  // Registers an XLA backend. `compilation_device_name` is the name of the
  // device used for symbolic execution during compilation. `supported_types`
  // is the list of non-resource types supported by the device. Each operator
  // will be registered for the intersection of the operator's supported types
  // and the device's supported types. `op_filter` is a function used
  // to exclude or modify operator registrations on the device; it may be
  // nullptr, in which case all ops are included.
  // `op_filter` should return true if the op should be registered on
  // the device; it may optionally modify the KernelDef.
  typedef bool (*BackendOpFilter)(KernelDef* kdef);
  static void RegisterBackend(const string& compilation_device_name,
                              gtl::ArraySlice<DataType> supported_types,
                              BackendOpFilter op_filter);

  // Returns the names of the registered backends.
  static std::vector<string> BackendNames();

  // Returns true iff a backend with the given name is registered.
  static bool IsBackendRegistered(const string& name);

  // Registers `device_name` for XLA compilation, using information from
  // `registration`.
  static void RegisterCompilationDevice(const string& device_name,
                                        const DeviceRegistration& registration);

  // Sets *registration to point at the DeviceRegistration associated with
  // 'device_name'. Returns false and leaves the output unchanged if no
  // matching JIT device is registered.
  // (*registration)->enable_jit_by_default is true if we should try to JIT
  // using this device when the JIT is enabled via the Session
  // OptimizerOptions.
  static bool GetCompilationDevice(const string& device_name,
                                   const DeviceRegistration** registration);

  // Registers all JIT kernels on JIT devices, if not already registered.
  // Does nothing otherwise.
  static void RegisterCompilationKernels();

  // Returns KernelDefs for compilation ops registered on
  // 'compilation_device_name'. Kernels registered as CompilationOnly are
  // excluded when include_compilation_only_kernels is false.
  static std::vector<const KernelDef*> DeviceKernels(
      const string& compilation_device_name,
      bool include_compilation_only_kernels);

  // Returns the set of compile-time constant inputs to 'op'. Returns nullptr
  // if the op is not registered.
  static const std::unordered_set<string>* CompileTimeConstantInputs(
      const string& op);

 private:
  // The registrar helpers populate the registry via the private state below.
  friend class XlaBackendRegistrar;
  friend class XlaOpRegistrar;
  friend class XlaOpRegistrationBuilder;

  // Returns the process-wide singleton registry.
  static XlaOpRegistry& Instance();

  XlaOpRegistry();
  ~XlaOpRegistry();

  // Guards the registry state declared below (see GUARDED_BY annotations).
  mutex mutex_;

  // Describes an XLA backend.
  struct Backend {
    // Which types are supported by this device?
    std::set<DataType> supported_types;

    // The per-backend operator filter function. See the comment on
    // RegisterBackend() for details.
    BackendOpFilter op_filter;

    // KernelDefs built by RegisterCompilationKernels() for each op supported
    // by the device.
    std::vector<std::unique_ptr<KernelDef>> kernel_defs;
  };

  // Map from compilation device names to a description of the backend.
  std::unordered_map<string, Backend> backends_ GUARDED_BY(mutex_);

  // Map from Tensorflow device names to the corresponding JIT device metadata.
  std::unordered_map<string, DeviceRegistration> compilation_devices_
      GUARDED_BY(mutex_);

  // A description of a Tensorflow operator that can be compiled to XLA.
  struct OpRegistration {
    // Operator name this registration applies to.
    string name;

    // Should this operator be registered only on compilation devices, without a
    // dummy kernel registered on the corresponding XLA device?
    bool compilation_only = false;

    // Should we allow resource types for type attributes? Used by _Arg to
    // allow DT_RESOURCE.
    bool allow_resource_types = false;

    // Mapping from attribute name to a list of supported types.
    std::unordered_map<string, std::set<DataType>> type_constraints;

    // An optional whitelist of devices. If there is no whitelist, all devices
    // are permitted.
    bool has_device_whitelist = false;
    std::unordered_set<string> device_whitelist;

    // Names of arguments that must be compile-time constants.
    std::unordered_set<string> compile_time_constant_inputs;

    // Factory used to build OpKernels that perform symbolic execution.
    Factory factory;
  };

  // Returns true if registrations x and y can both be added to the registry.
  // This is always the case if they refer to different ops. If they refer to
  // the same op name, they must: have the same values for compilation_only and
  // allow_resource_types; use a device_whitelist; and their
  // whitelists must not intersect.
  static bool IsCompatible(const OpRegistration& x, const OpRegistration& y);

  // Map from operator name to OpRegistrations, populated by REGISTER_XLA_OP.
  // Registrations present under the same key must satisfy IsCompatible above,
  // and this is checked during registration.
  std::unordered_multimap<string, std::unique_ptr<OpRegistration>> ops_
      GUARDED_BY(mutex_);

  // Have we already registered the JIT kernels on the JIT devices?
  bool jit_kernels_registered_ = false;

  // Holds ownership of OpKernelRegistrars that represent the Tensorflow kernel
  // registrations created by RegisterCompilationKernels() and
  // RegisterDeviceKernels().
  std::vector<std::unique_ptr<kernel_factory::OpKernelRegistrar>>
      kernel_registrars_ GUARDED_BY(mutex_);
};
    218 
// REGISTER_XLA_OP() registers an XLA OpKernel by name, for example:
// REGISTER_XLA_OP(Name("Add"), AddOp);
// where 'AddOp' is the name of a JIT OpKernel class that implements "Add".
//
// We don't use a variadic macro here because we don't expect JIT operators to
// be templated.
//
// NAME is an XlaOpRegistrationBuilder expression (starting with Name(...));
// __COUNTER__ gives each expansion a unique static registrar object.
#define REGISTER_XLA_OP(NAME, OP) \
  REGISTER_XLA_OP_UNIQ_HELPER(__COUNTER__, NAME, OP)
    228 
// Fluent builder used by REGISTER_XLA_OP to construct an
// XlaOpRegistry::OpRegistration, e.g.
//   XlaOpRegistrationBuilder::Name("Add").TypeConstraint("T", kFloatTypes)
// Each setter returns *this so calls can be chained; Build() finalizes the
// registration.
class XlaOpRegistrationBuilder {
 public:
  // Starts an operator registration chain.
  static XlaOpRegistrationBuilder Name(StringPiece name);

  // Specifies a whitelist of devices on which the operator may run.
  XlaOpRegistrationBuilder& Device(StringPiece devices);
  XlaOpRegistrationBuilder& Device(gtl::ArraySlice<StringPiece> devices);

  // Specifies a type constraint for a type variable attribute. Each constraint
  // specifies the set of types that the type variable may assume.
  XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name,
                                           DataType allowed);

  XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name,
                                           gtl::ArraySlice<DataType> allowed);

  // Specifies that a dummy copy of this operator should not be registered on
  // XLA_* devices, but may be used during compilation.
  XlaOpRegistrationBuilder& CompilationOnly();

  // Allow DT_RESOURCE types for type parameters.
  XlaOpRegistrationBuilder& AllowResourceTypes();

  // Mark 'input_name' as an argument whose value must be known at compile-time.
  XlaOpRegistrationBuilder& CompileTimeConstInput(StringPiece input_name);

  // Completes the chain, attaching 'factory' as the OpKernel constructor, and
  // returns the accumulated registration.
  std::unique_ptr<XlaOpRegistry::OpRegistration> Build(
      XlaOpRegistry::Factory factory);

 private:
  // Private: registrations must start with the static Name() entry point.
  XlaOpRegistrationBuilder(StringPiece name);

  // Registration being accumulated by the builder calls above.
  std::unique_ptr<XlaOpRegistry::OpRegistration> registration_;
};
    264 
// REGISTER_XLA_BACKEND() registers an XLA backend. Example usage:
// REGISTER_XLA_BACKEND(DEVICE_GPU_XLA_JIT, kGpuAllTypes, GpuOpFilter);
// __COUNTER__ gives each expansion a unique static registrar object.
#define REGISTER_XLA_BACKEND(NAME, ...) \
  REGISTER_XLA_BACKEND_UNIQ_HELPER(__COUNTER__, NAME, __VA_ARGS__)
    269 
// Implementation details.

// Helper whose constructor adds 'registration' to the XlaOpRegistry (it is a
// friend of XlaOpRegistry). A static instance is created by each
// REGISTER_XLA_OP expansion at static-initialization time.
class XlaOpRegistrar {
 public:
  XlaOpRegistrar(std::unique_ptr<XlaOpRegistry::OpRegistration> registration);
};
    276 
// Extra level of macro indirection so that COUNTER (typically __COUNTER__) is
// expanded to its value before token pasting in REGISTER_XLA_OP_UNIQ.
#define REGISTER_XLA_OP_UNIQ_HELPER(COUNTER, BUILDER, OP) \
  REGISTER_XLA_OP_UNIQ(COUNTER, BUILDER, OP)

// Defines a uniquely-named static XlaOpRegistrar whose registration is built
// from the BUILDER chain plus a factory lambda that constructs OP.
#define REGISTER_XLA_OP_UNIQ(CTR, BUILDER, OP)                                 \
  static ::tensorflow::XlaOpRegistrar xla_op_registrar__body__##CTR##__object( \
      XlaOpRegistrationBuilder::BUILDER.Build(                                 \
          [](::tensorflow::OpKernelConstruction* context)                      \
              -> ::tensorflow::OpKernel* { return new OP(context); }));
    285 
// Helper whose constructor registers an XLA backend (name, supported types,
// and optional op filter) with the XlaOpRegistry (it is a friend of
// XlaOpRegistry). A static instance is created by each REGISTER_XLA_BACKEND
// expansion at static-initialization time.
class XlaBackendRegistrar {
 public:
  XlaBackendRegistrar(StringPiece name, gtl::ArraySlice<DataType> types,
                      XlaOpRegistry::BackendOpFilter op_filter = nullptr);
};
    291 
// Extra level of macro indirection so that COUNTER (typically __COUNTER__) is
// expanded to its value before token pasting in REGISTER_XLA_BACKEND_UNIQ.
#define REGISTER_XLA_BACKEND_UNIQ_HELPER(COUNTER, NAME, ...) \
  REGISTER_XLA_BACKEND_UNIQ(COUNTER, NAME, __VA_ARGS__)

// Defines a uniquely-named static XlaBackendRegistrar constructed from NAME
// and the remaining REGISTER_XLA_BACKEND arguments (types and op filter).
#define REGISTER_XLA_BACKEND_UNIQ(CTR, NAME, ...) \
  static ::tensorflow::XlaBackendRegistrar        \
      xla_backend_registrar__body__##CTR##__object(NAME, __VA_ARGS__);
    298 
    299 }  // namespace tensorflow
    300 
    301 #endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_
    302