Home | History | Annotate | Download | only in tf2xla
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
     17 
     18 #include <functional>
     19 #include <memory>
     20 
     21 #include "tensorflow/compiler/tf2xla/type_util.h"
     22 #include "tensorflow/compiler/tf2xla/xla_context.h"
     23 #include "tensorflow/compiler/xla/client/client_library.h"
     24 #include "tensorflow/core/common_runtime/device_factory.h"
     25 #include "tensorflow/core/common_runtime/local_device.h"
     26 #include "tensorflow/core/framework/device_base.h"
     27 #include "tensorflow/core/framework/kernel_def.pb.h"
     28 #include "tensorflow/core/framework/node_def.pb.h"
     29 #include "tensorflow/core/framework/op_def_util.h"
     30 #include "tensorflow/core/platform/mem.h"
     31 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
     32 
     33 namespace tensorflow {
     34 
     35 const char* const DEVICE_CPU_XLA_JIT = "XLA_CPU_JIT";
     36 const char* const DEVICE_GPU_XLA_JIT = "XLA_GPU_JIT";
     37 const char* const DEVICE_XLA_CPU = "XLA_CPU";
     38 const char* const DEVICE_XLA_GPU = "XLA_GPU";
     39 
     40 static Status LaunchOpHasKernelForDevice(const DeviceType& device_type) {
     41   const OpDef* op_def;
     42   TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef("_XlaLaunch", &op_def));
     43   NodeDef node_def;
     44   node_def.set_name("_XlaLaunch-op");
     45   node_def.set_op("_XlaLaunch");
     46   string kernel_class_name;
     47   TF_RETURN_IF_ERROR(FindKernelDef(device_type, node_def, /*KernelDef*/ nullptr,
     48                                    &kernel_class_name));
     49   VLOG(1) << "LaunchOpHasKernelForDevice"
     50           << " kernel_class_name: " << kernel_class_name;
     51   return Status::OK();
     52 }
     53 
     54 XlaOpRegistry::XlaOpRegistry() = default;
     55 XlaOpRegistry::~XlaOpRegistry() = default;
     56 
     57 // TODO(b/64575122) consider adding more sophisticated definitions of
     58 // compatibility if needed by future use cases.
     59 /* static */ bool XlaOpRegistry::IsCompatible(const OpRegistration& x,
     60                                               const OpRegistration& y) {
     61   if (x.name != y.name) return true;
     62   // The registrations refer to the same Op: ensures they are compatible and
     63   // are restricted to different device whitelists.
     64   if (x.compilation_only != y.compilation_only) {
     65     LOG(WARNING) << "Registrations of " << x.name
     66                  << " have incompatible compilation_only settings.";
     67     return false;
     68   }
     69   if (x.allow_resource_types != y.allow_resource_types) {
     70     LOG(WARNING) << "Registrations of " << x.name
     71                  << " have incompatible allow_resource_types settings.";
     72     return false;
     73   }
     74   if (!x.has_device_whitelist || !y.has_device_whitelist) {
     75     LOG(WARNING) << "Registrations of " << x.name
     76                  << " do not both have device whitelists.";
     77     return false;
     78   }
     79   for (const auto& device : x.device_whitelist) {
     80     if (y.device_whitelist.count(device) != 0) {
     81       LOG(WARNING) << "Multiple registrations of " << x.name << " on device "
     82                    << device;
     83       return false;
     84     }
     85   }
     86   if (x.compile_time_constant_inputs != y.compile_time_constant_inputs) {
     87     LOG(WARNING) << "Registrations of " << x.name
     88                  << " have incompatible compile time constant inputs.";
     89     return false;
     90   }
     91   return true;
     92 }
     93 
     94 /* static */ void XlaOpRegistry::RegisterCompilationDevice(
     95     const string& device_name, const DeviceRegistration& registration) {
     96   XlaOpRegistry& registry = Instance();
     97   mutex_lock lock(registry.mutex_);
     98   auto result =
     99       registry.compilation_devices_.emplace(device_name, registration);
    100   CHECK(result.second || result.first->second.compilation_device_name ==
    101                              registration.compilation_device_name);
    102 }
    103 
    104 /* static */ void XlaOpRegistry::RegisterBackend(
    105     const string& compilation_device_name,
    106     gtl::ArraySlice<DataType> supported_types, BackendOpFilter op_filter) {
    107   XlaOpRegistry& registry = Instance();
    108   mutex_lock lock(registry.mutex_);
    109   auto result = registry.backends_.emplace(compilation_device_name, Backend());
    110   CHECK(result.second) << "Duplicate XLA backend registration "
    111                        << compilation_device_name;
    112   result.first->second.supported_types.insert(supported_types.begin(),
    113                                               supported_types.end());
    114   result.first->second.op_filter = op_filter;
    115 }
    116 
    117 /* static */ bool XlaOpRegistry::GetCompilationDevice(
    118     const string& device_name, const DeviceRegistration** registration) {
    119   XlaOpRegistry& registry = Instance();
    120 
    121   // Lazily register the CPU and GPU JIT devices the first time
    122   // GetCompilationDevice is called.
    123   static void* registration_init = [&registry]() {
    124     mutex_lock lock(registry.mutex_);
    125     if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_CPU)).ok()) {
    126       DeviceRegistration& registration =
    127           registry.compilation_devices_[DEVICE_CPU];
    128       registration.compilation_device_name = DEVICE_CPU_XLA_JIT;
    129       registration.requires_compilation = false;
    130       registration.enable_jit_by_default = false;
    131       registration.compile_resource_ops = false;
    132     }
    133     if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_GPU)).ok()) {
    134       DeviceRegistration& registration =
    135           registry.compilation_devices_[DEVICE_GPU];
    136       registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
    137       registration.requires_compilation = false;
    138       registration.enable_jit_by_default = true;
    139       registration.compile_resource_ops = false;
    140     }
    141     return nullptr;
    142   }();
    143   (void)registration_init;
    144 
    145   mutex_lock lock(registry.mutex_);
    146   auto it = registry.compilation_devices_.find(device_name);
    147   if (it == registry.compilation_devices_.end()) return false;
    148   *registration = &it->second;
    149   return true;
    150 }
    151 
    152 void XlaOpRegistry::RegisterCompilationKernels() {
    153   XlaOpRegistry& registry = Instance();
    154   mutex_lock lock(registry.mutex_);
    155 
    156   if (registry.jit_kernels_registered_) return;
    157   registry.jit_kernels_registered_ = true;
    158 
    159   OpRegistryInterface* op_registry = OpRegistry::Global();
    160   for (const auto& op : registry.ops_) {
    161     const string& op_name = op.first;
    162     const std::unique_ptr<OpRegistration>& op_registration = op.second;
    163     const OpDef* op_def;
    164     Status lookup_status = op_registry->LookUpOpDef(op_name, &op_def);
    165     if (!lookup_status.ok()) {
    166       LOG(ERROR) << lookup_status.error_message();
    167       XLA_LOG_LINES(
    168           ERROR, "Ops registered: \n" +
    169                      dynamic_cast<OpRegistry*>(op_registry)->DebugString(true));
    170     }
    171     TF_CHECK_OK(lookup_status);
    172 
    173     std::unordered_set<string> type_attrs;
    174     for (const OpDef::AttrDef& attr_def : op_def->attr()) {
    175       if (attr_def.type() == "type" || attr_def.type() == "list(type)") {
    176         type_attrs.insert(attr_def.name());
    177       }
    178     }
    179 
    180     // Checks there are no type constraints referring to unknown attributes.
    181     for (const auto& constraint : op_registration->type_constraints) {
    182       if (type_attrs.find(constraint.first) == type_attrs.end()) {
    183         LOG(FATAL) << "Unknown type attribute " << constraint.first
    184                    << " in XLA op registration for " << op_name;
    185       }
    186     }
    187 
    188     for (auto& backend : registry.backends_) {
    189       // If the operator has a device whitelist, only register on whitelisted
    190       // devices.
    191       if (op_registration->has_device_whitelist &&
    192           op_registration->device_whitelist.find(backend.first) ==
    193               op_registration->device_whitelist.end()) {
    194         continue;
    195       }
    196 
    197       std::unique_ptr<KernelDef> kdef(new KernelDef);
    198       kdef->set_op(op_registration->name);
    199       kdef->set_device_type(backend.first);
    200 
    201       // Constrain each type attribute to the intersection of:
    202       // a) the types supported by the backend, and
    203       // b) the types allowed by the OpDef, and
    204       // c) the type constraints.
    205       for (const string& type_attr : type_attrs) {
    206         KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint();
    207         attr_constraint->set_name(type_attr);
    208         auto* allowed_values =
    209             attr_constraint->mutable_allowed_values()->mutable_list();
    210 
    211         const OpDef::AttrDef& op_def_attr = *FindAttr(type_attr, *op_def);
    212         const auto* op_def_allowed_types =
    213             op_def_attr.has_allowed_values()
    214                 ? &op_def_attr.allowed_values().list().type()
    215                 : nullptr;
    216         auto constraint_it = op_registration->type_constraints.find(type_attr);
    217         const std::set<DataType>* type_constraints =
    218             constraint_it != op_registration->type_constraints.end()
    219                 ? &constraint_it->second
    220                 : nullptr;
    221         for (DataType dtype : backend.second.supported_types) {
    222           // Filter out types that aren't allowed by the OpDef.
    223           if (op_def_allowed_types != nullptr &&
    224               std::find(op_def_allowed_types->begin(),
    225                         op_def_allowed_types->end(),
    226                         dtype) == op_def_allowed_types->end()) {
    227             continue;
    228           }
    229           // Filter out types based on the type constraints.
    230           if (type_constraints != nullptr &&
    231               type_constraints->find(dtype) == type_constraints->end()) {
    232             continue;
    233           }
    234           // Passed all the filters, this type is allowed.
    235           allowed_values->add_type(dtype);
    236         }
    237         if (op_registration->allow_resource_types) {
    238           allowed_values->add_type(DT_RESOURCE);
    239         }
    240       }
    241       if (backend.second.op_filter != nullptr &&
    242           !backend.second.op_filter(kdef.get())) {
    243         continue;
    244       }
    245       VLOG(2) << "XLA op registration: device: " << backend.first
    246               << " op: " << op_name;
    247       registry.kernel_registrars_.emplace_back(
    248           new kernel_factory::OpKernelRegistrar(
    249               new KernelDef(*kdef), "XlaJitOp", op_registration->factory));
    250       backend.second.kernel_defs.push_back(std::move(kdef));
    251     }
    252   }
    253 }
    254 
    255 std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
    256     const string& compilation_device_name,
    257     bool include_compilation_only_kernels) {
    258   // Ensure compilation kernels registered.
    259   RegisterCompilationKernels();
    260   std::vector<const KernelDef*> kernels;
    261   XlaOpRegistry& registry = Instance();
    262   mutex_lock lock(registry.mutex_);
    263   auto it = registry.backends_.find(compilation_device_name);
    264   CHECK(it != registry.backends_.end())
    265       << "Unknown backend " << compilation_device_name;
    266   for (const std::unique_ptr<KernelDef>& k : it->second.kernel_defs) {
    267     auto op_iter = registry.ops_.find(k->op());
    268     CHECK(op_iter != registry.ops_.end());
    269     // The test in IsCompatible ensures that if there are multiple matching
    270     // registrations for this op name, they all have the same value of
    271     // compilation_only, so only the first match needs to be tested.
    272     if (include_compilation_only_kernels ||
    273         !op_iter->second->compilation_only) {
    274       kernels.push_back(k.get());
    275     }
    276   }
    277   return kernels;
    278 }
    279 
    280 /* static */ const std::unordered_set<string>*
    281 XlaOpRegistry::CompileTimeConstantInputs(const string& op) {
    282   XlaOpRegistry& registry = Instance();
    283   mutex_lock lock(registry.mutex_);
    284   auto it = registry.ops_.find(op);
    285   if (it == registry.ops_.end()) {
    286     return nullptr;
    287   }
    288   return &it->second->compile_time_constant_inputs;
    289 }
    290 
    291 std::vector<string> XlaOpRegistry::BackendNames() {
    292   std::vector<string> names;
    293   XlaOpRegistry& registry = Instance();
    294   mutex_lock lock(registry.mutex_);
    295   for (const auto& backend_pair : registry.backends_) {
    296     names.push_back(backend_pair.first);
    297   }
    298   return names;
    299 }
    300 
    301 bool XlaOpRegistry::IsBackendRegistered(const string& name) {
    302   XlaOpRegistry& registry = Instance();
    303   mutex_lock lock(registry.mutex_);
    304   return registry.backends_.find(name) != registry.backends_.end();
    305 }
    306 
    307 XlaOpRegistry& XlaOpRegistry::Instance() {
    308   static XlaOpRegistry* r = new XlaOpRegistry;
    309   return *r;
    310 }
    311 
    312 XlaOpRegistrationBuilder::XlaOpRegistrationBuilder(StringPiece name) {
    313   registration_.reset(new XlaOpRegistry::OpRegistration);
    314   registration_->name = name.ToString();
    315 }
    316 
    317 XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name(StringPiece name) {
    318   XlaOpRegistrationBuilder registration(name);
    319   return registration;
    320 }
    321 
    322 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(
    323     gtl::ArraySlice<StringPiece> devices) {
    324   registration_->has_device_whitelist = true;
    325   for (StringPiece device : devices) {
    326     registration_->device_whitelist.insert(device.ToString());
    327   }
    328   return *this;
    329 }
    330 
    331 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(StringPiece device) {
    332   registration_->has_device_whitelist = true;
    333   registration_->device_whitelist.insert(device.ToString());
    334   return *this;
    335 }
    336 
    337 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompilationOnly() {
    338   registration_->compilation_only = true;
    339   return *this;
    340 }
    341 
    342 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::AllowResourceTypes() {
    343   registration_->allow_resource_types = true;
    344   return *this;
    345 }
    346 
    347 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint(
    348     StringPiece attr_name, DataType allowed) {
    349   std::set<DataType>& types =
    350       registration_->type_constraints[attr_name.ToString()];
    351   types.insert(allowed);
    352   return *this;
    353 }
    354 
    355 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint(
    356     StringPiece attr_name, gtl::ArraySlice<DataType> allowed) {
    357   std::set<DataType>& types =
    358       registration_->type_constraints[attr_name.ToString()];
    359   for (DataType t : allowed) {
    360     types.insert(t);
    361   }
    362   return *this;
    363 }
    364 
    365 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstInput(
    366     StringPiece input_name) {
    367   registration_->compile_time_constant_inputs.insert(input_name.ToString());
    368   return *this;
    369 }
    370 
    371 std::unique_ptr<XlaOpRegistry::OpRegistration> XlaOpRegistrationBuilder::Build(
    372     XlaOpRegistry::Factory factory) {
    373   registration_->factory = factory;
    374   return std::move(registration_);
    375 }
    376 
    377 XlaOpRegistrar::XlaOpRegistrar(
    378     std::unique_ptr<XlaOpRegistry::OpRegistration> registration) {
    379   XlaOpRegistry& registry = XlaOpRegistry::Instance();
    380   mutex_lock lock(registry.mutex_);
    381   auto existing_ops = registry.ops_.equal_range(registration->name);
    382   for (auto existing = existing_ops.first; existing != existing_ops.second;
    383        ++existing) {
    384     if (!XlaOpRegistry::IsCompatible(*existing->second, *registration)) {
    385       LOG(FATAL)
    386           << "XLA op registration " << registration->name
    387           << " is incompatible with existing registration of the same name.";
    388     }
    389   }
    390   registry.ops_.emplace(registration->name, std::move(registration));
    391 }
    392 
    393 XlaBackendRegistrar::XlaBackendRegistrar(
    394     StringPiece name, gtl::ArraySlice<DataType> types,
    395     XlaOpRegistry::BackendOpFilter op_filter) {
    396   XlaOpRegistry& registry = XlaOpRegistry::Instance();
    397   registry.RegisterBackend(name.ToString(), types, op_filter);
    398 }
    399 
    400 }  // namespace tensorflow
    401