Home | History | Annotate | Download | only in tf2xla
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_
     17 #define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_
     18 
#include <array>
#include <functional>
#include <memory>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/session_options.h"
     35 
     36 namespace tensorflow {
     37 
// Names of the XLA compilation devices. These are not user-visible, and are
// used internally by the Tensorflow/XLA bridge to perform symbolic execution of
// a Tensorflow graph.

extern const char* const DEVICE_CPU_XLA_JIT;  // "CPU_XLA_JIT"
extern const char* const DEVICE_GPU_XLA_JIT;  // "GPU_XLA_JIT"

// Names of the XLA devices themselves; defined in the corresponding .cc file.
extern const char* const DEVICE_XLA_CPU;
extern const char* const DEVICE_XLA_GPU;

// Floating-point types supported for compilation.
constexpr std::array<DataType, 4> kFloatTypes = {
    {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}};

// All numeric (integer, floating-point, and complex) types supported for
// compilation; excludes bool, quantized, and resource types.
constexpr std::array<DataType, 12> kNumericTypes = {
    {DT_UINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF,
     DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BFLOAT16}};

// All types supported by the CPU JIT device, including bool and the
// quantized integer types.
constexpr std::array<DataType, 16> kCpuAllTypes = {
    {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32,
     DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
     DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};

// All types supported by the GPU JIT device: the same set as kCpuAllTypes
// minus DT_COMPLEX128.
constexpr std::array<DataType, 15> kGpuAllTypes = {
    {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32,
     DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL,
     DT_BFLOAT16}};
     63 
// Class that manages registrations of operators and devices for the XLA JIT.
// Not thread-safe.
class XlaOpRegistry {
 public:
  // Signature of the factory that REGISTER_XLA_OP uses to construct the
  // OpKernel performing symbolic execution for an op.
  typedef OpKernel* (*Factory)(OpKernelConstruction*);

  // Describes when operators assigned to a device should be auto-clustered
  // for XLA compilation.
  enum class AutoclusteringPolicy {
    // Enable autoclustering if the user requests it, e.g., via
    // experimental_jit_scope. Does not autocluster if the JIT is enabled
    // globally (e.g., via the OptimizerOptions in the TF session
    // configuration.)
    kIfExplicitlyRequested,
    // Enable autoclustering if explicitly requested, or if the JIT is enabled
    // globally in the session options, or via TF_XLA_FLAGS=--tf_xla_auto_jit=N.
    kIfEnabledGlobally,
    // Always try to autocluster ops placed on this device.
    kAlways,
  };

  // Describes how to compile operators assigned to a device.
  struct DeviceRegistration {
    // The name of the XLA compilation device to use to compile code.
    string compilation_device_name;

    // When should we autocluster operators assigned to this device?
    AutoclusteringPolicy autoclustering_policy;

    // Enable compilation of operators that use DT_RESOURCE types?
    bool compile_all_resource_ops = false;
  };

  // Registers an XLA backend. `compilation_device_name` is the name of the
  // device used for symbolic execution during compilation. `supported_types`
  // is the list of non-resource types supported by the device. Each operator
  // will be registered for the intersection of the operator's supported types
  // and the device's supported types. `op_filter` is a function used
  // to exclude or modify operator registrations on the device; it may be
  // nullptr, in which case all ops are included.
  // `op_filter` should return true if the op should be registered on
  // the device; it may optionally modify the KernelDef.
  typedef bool (*BackendOpFilter)(KernelDef* kdef);
  static void RegisterBackend(const string& compilation_device_name,
                              absl::Span<const DataType> supported_types,
                              BackendOpFilter op_filter);

  // Returns the names of the registered backends.
  static std::vector<string> BackendNames();

  // Returns true iff a backend with the given name is registered.
  static bool IsBackendRegistered(const string& name);

  // Registers `device_name` for XLA compilation, using information from
  // `registration`.
  // Does nothing if a registration for `device_name` already exists.
  static void RegisterCompilationDevice(const string& device_name,
                                        const DeviceRegistration& registration);

  // Returns (via `*registration`) the DeviceRegistration associated with
  // `device_name`. Returns false and leaves `*registration` unchanged if no
  // matching compilation device is registered.
  static bool GetCompilationDevice(const string& device_name,
                                   const DeviceRegistration** registration);

  // Registers all JIT kernels on JIT devices, if not already registered.
  // Does nothing otherwise.
  static void RegisterCompilationKernels();

  // Returns KernelDefs for compilation ops registered on
  // 'compilation_device_name'. Kernels registered as CompilationOnly are
  // omitted when include_compilation_only_kernels is false.
  static std::vector<const KernelDef*> DeviceKernels(
      const string& compilation_device_name,
      bool include_compilation_only_kernels);

  // Returns all operations for which there are XLA kernels on any device.
  static std::vector<string> GetAllRegisteredOps();

  // Returns (via `result`) the indices of inputs to `node_def` that must be
  // compile-time constants. Returns an empty vector if the op is not
  // registered.
  //
  // `result` is sorted.
  static Status CompileTimeConstantInputs(const NodeDef& node_def,
                                          const OpDef& op_def,
                                          std::vector<int>* result) {
    return CompileTimeConstantInputs(node_def, /*op_kernel=*/nullptr, &op_def,
                                     result);
  }

  // Returns (via `result`) the indices of inputs to `op_kernel` that must be
  // compile-time constants.
  //
  // `result` is sorted.
  static Status CompileTimeConstantInputs(const OpKernel& op_kernel,
                                          std::vector<int>* result) {
    return CompileTimeConstantInputs(op_kernel.def(), /*op_kernel=*/&op_kernel,
                                     /*op_def=*/nullptr, result);
  }

  // Returns true if `op` is a "metadata" op, one that only looks at the shapes
  // of its operands and not their values.
  static bool IsMetadataOp(const string& op);

 private:
  friend class XlaBackendRegistrar;
  friend class XlaOpRegistrar;
  friend class XlaOpRegistrationBuilder;

  // Returns the registry singleton.
  static XlaOpRegistry& Instance();

  XlaOpRegistry();
  ~XlaOpRegistry();

  // Guards the registry state below (see the GUARDED_BY annotations).
  mutex mutex_;

  // Describes an XLA backend.
  struct Backend {
    // Which types are supported by this device?
    std::set<DataType> supported_types;

    // The per-backend operator filter function. See the comment on
    // RegisterBackend() for details.
    BackendOpFilter op_filter;

    // KernelDefs built by RegisterCompilationKernels() for each op supported
    // by the device.
    std::vector<std::unique_ptr<KernelDef>> kernel_defs;
  };

  // Map from compilation device names to a description of the backend.
  std::unordered_map<string, Backend> backends_ GUARDED_BY(mutex_);

  // Map from Tensorflow device names to the corresponding JIT device metadata.
  std::unordered_map<string, DeviceRegistration> compilation_devices_
      GUARDED_BY(mutex_);

  // A description of a Tensorflow operator that can be compiled to XLA.
  struct OpRegistration {
    string name;

    // Should this operator be registered only on compilation devices, without a
    // dummy kernel registered on the corresponding XLA device?
    bool compilation_only = false;

    // Should we allow resource types for type attributes? Used by _Arg to
    // allow DT_RESOURCE.
    bool allow_resource_types = false;

    // Should we allow variant types for type attributes? Used by While to
    // allow TensorList which is of type DT_VARIANT.
    bool allow_variant_types = false;

    // Mapping from attribute name to a list of supported types.
    std::unordered_map<string, std::set<DataType>> type_constraints;

    // An optional whitelist of devices. If there is no whitelist, all devices
    // are permitted.
    bool has_device_whitelist = false;
    std::unordered_set<string> device_whitelist;

    // Names of arguments that must be compile-time constants.
    std::unordered_set<string> compile_time_constant_inputs;

    // True if this is a "metadata" op, one that only looks at the shapes of its
    // operands and not their values.
    bool is_metadata_op = false;

    // Factory used to build OpKernels that perform symbolic execution.
    Factory factory;
  };

  // Returns true if registrations x and y can both be added to the registry.
  // This is always the case if they refer to different ops. If they refer to
  // the same op name, they must: have the same values for compilation_only,
  // allow_resource_types and allow_variant_types; use a device_whitelist; and
  // their whitelists must not intersect.
  static bool IsCompatible(const OpRegistration& x, const OpRegistration& y);

  // Shared implementation behind the two public CompileTimeConstantInputs
  // overloads; each public overload passes exactly one of `op_kernel` /
  // `op_def` non-null.
  static Status CompileTimeConstantInputs(const NodeDef& node_def,
                                          const OpKernel* op_kernel,
                                          const OpDef* op_def,
                                          std::vector<int>* result);

  // Map from operator name to OpRegistrations, populated by REGISTER_XLA_OP.
  // Registrations present under the same key must satisfy IsCompatible above,
  // and this is checked during registration.
  std::unordered_map<string, std::vector<std::unique_ptr<OpRegistration>>> ops_
      GUARDED_BY(mutex_);

  // Have we already registered the JIT kernels on the JIT devices?
  bool jit_kernels_registered_ = false;

  // Holds ownership of OpKernelRegistrars that represent the Tensorflow kernel
  // registrations created by RegisterCompilationKernels() and
  // RegisterDeviceKernels().
  std::vector<std::unique_ptr<kernel_factory::OpKernelRegistrar>>
      kernel_registrars_ GUARDED_BY(mutex_);
};
    265 
// REGISTER_XLA_OP() registers an XLA OpKernel by name, for example:
// REGISTER_XLA_OP(Name("Add"), AddOp);
// where 'AddOp' is the name of a JIT OpKernel class that implements "Add".
//
// NAME is an XlaOpRegistrationBuilder expression: a Name("...") call,
// optionally followed by chained builder calls such as .Device() or
// .TypeConstraint() (see REGISTER_XLA_OP_UNIQ below for the expansion).
//
// We don't use a variadic macro here because we don't expect JIT operators to
// be templated.

#define REGISTER_XLA_OP(NAME, OP) \
  REGISTER_XLA_OP_UNIQ_HELPER(__COUNTER__, NAME, OP)
    275 
// Builder used by REGISTER_XLA_OP to assemble an XlaOpRegistry::OpRegistration.
// Each setter returns *this so calls can be chained after Name().
class XlaOpRegistrationBuilder {
 public:
  // Starts an operator registration chain.
  static XlaOpRegistrationBuilder Name(absl::string_view name);

  // Specifies a whitelist of devices on which the operator may run.
  XlaOpRegistrationBuilder& Device(absl::string_view devices);
  XlaOpRegistrationBuilder& Device(absl::Span<const absl::string_view> devices);

  // Specifies a type constraint for a type variable attribute. Each constraint
  // specifies the set of types that the type variable may assume.
  XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name,
                                           DataType allowed);

  XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name,
                                           absl::Span<const DataType> allowed);

  // Specifies that a dummy copy of this operator should not be registered on
  // XLA_* devices, but may be used during compilation.
  XlaOpRegistrationBuilder& CompilationOnly();

  // Allow DT_RESOURCE types for type parameters.
  XlaOpRegistrationBuilder& AllowResourceTypes();

  // Allow DT_VARIANT types for type parameters.
  XlaOpRegistrationBuilder& AllowVariantTypes();

  // Mark 'input_name' as an argument whose value must be known at compile-time.
  XlaOpRegistrationBuilder& CompileTimeConstantInput(
      absl::string_view input_name);

  // Mark this op as a "metadata" op, one that only looks at the shapes of its
  // operands and not their values.
  XlaOpRegistrationBuilder& IsMetadataOp();

  // Finalizes the chain: attaches `factory` and releases the completed
  // registration to the caller.
  std::unique_ptr<XlaOpRegistry::OpRegistration> Build(
      XlaOpRegistry::Factory factory);

 private:
  XlaOpRegistrationBuilder(absl::string_view name);

  // The registration under construction; handed out by Build().
  std::unique_ptr<XlaOpRegistry::OpRegistration> registration_;
};
    319 
// REGISTER_XLA_BACKEND() registers an XLA backend. Example usage:
// REGISTER_XLA_BACKEND(DEVICE_GPU_XLA_JIT, kGpuAllTypes, GpuOpFilter);
#define REGISTER_XLA_BACKEND(NAME, ...) \
  REGISTER_XLA_BACKEND_UNIQ_HELPER(__COUNTER__, NAME, __VA_ARGS__)

// Implementation details.

// Helper whose constructor hands a completed OpRegistration to the registry
// (it is a friend of XlaOpRegistry). A static instance is created by
// REGISTER_XLA_OP_UNIQ so registration happens at static-initialization time.
class XlaOpRegistrar {
 public:
  XlaOpRegistrar(std::unique_ptr<XlaOpRegistry::OpRegistration> registration);
};
    331 
// Indirection so that __COUNTER__ is expanded before token pasting.
#define REGISTER_XLA_OP_UNIQ_HELPER(COUNTER, BUILDER, OP) \
  REGISTER_XLA_OP_UNIQ(COUNTER, BUILDER, OP)

// Defines a uniquely-named static XlaOpRegistrar whose registration is built
// from the BUILDER expression plus a factory lambda that news up an OP.
#define REGISTER_XLA_OP_UNIQ(CTR, BUILDER, OP)                                 \
  static ::tensorflow::XlaOpRegistrar xla_op_registrar__body__##CTR##__object( \
      ::tensorflow::XlaOpRegistrationBuilder::BUILDER.Build(                   \
          [](::tensorflow::OpKernelConstruction* context)                      \
              -> ::tensorflow::OpKernel* { return new OP(context); }));

// Helper whose constructor registers an XLA backend (it is a friend of
// XlaOpRegistry). A static instance is created by REGISTER_XLA_BACKEND_UNIQ.
class XlaBackendRegistrar {
 public:
  XlaBackendRegistrar(absl::string_view name, absl::Span<const DataType> types,
                      XlaOpRegistry::BackendOpFilter op_filter = nullptr);
};
    346 
// Indirection so that __COUNTER__ is expanded before token pasting.
#define REGISTER_XLA_BACKEND_UNIQ_HELPER(COUNTER, NAME, ...) \
  REGISTER_XLA_BACKEND_UNIQ(COUNTER, NAME, __VA_ARGS__)

// Defines a uniquely-named static XlaBackendRegistrar for backend NAME.
#define REGISTER_XLA_BACKEND_UNIQ(CTR, NAME, ...) \
  static ::tensorflow::XlaBackendRegistrar        \
      xla_backend_registrar__body__##CTR##__object(NAME, __VA_ARGS__);
    353 
    354 }  // namespace tensorflow
    355 
    356 #endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_
    357