Home | History | Annotate | Download | only in tf2xla
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
     17 
     18 #include <functional>
     19 #include <memory>
     20 
     21 #include "tensorflow/compiler/jit/flags.h"
     22 #include "tensorflow/compiler/jit/xla_cluster_util.h"
     23 #include "tensorflow/compiler/tf2xla/type_util.h"
     24 #include "tensorflow/compiler/tf2xla/xla_context.h"
     25 #include "tensorflow/compiler/xla/client/client_library.h"
     26 #include "tensorflow/core/common_runtime/device_factory.h"
     27 #include "tensorflow/core/common_runtime/local_device.h"
     28 #include "tensorflow/core/framework/device_base.h"
     29 #include "tensorflow/core/framework/kernel_def.pb.h"
     30 #include "tensorflow/core/framework/node_def.pb.h"
     31 #include "tensorflow/core/framework/op_def_util.h"
     32 #include "tensorflow/core/platform/mem.h"
     33 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
     34 
     35 namespace tensorflow {
     36 
// Names of the XLA JIT compilation device types (used to key backend and
// kernel registrations in this file).
const char* const DEVICE_CPU_XLA_JIT = "XLA_CPU_JIT";
const char* const DEVICE_GPU_XLA_JIT = "XLA_GPU_JIT";
// Names of the XLA_* device types.
const char* const DEVICE_XLA_CPU = "XLA_CPU";
const char* const DEVICE_XLA_GPU = "XLA_GPU";
     41 
     42 static Status LaunchOpHasKernelForDevice(const DeviceType& device_type) {
     43   const OpDef* op_def;
     44   TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef("XlaLaunch", &op_def));
     45   NodeDef node_def;
     46   node_def.set_name("_XlaLaunch-op");
     47   node_def.set_op("XlaLaunch");
     48   string kernel_class_name;
     49   TF_RETURN_IF_ERROR(FindKernelDef(device_type, node_def, /*KernelDef*/ nullptr,
     50                                    &kernel_class_name));
     51   VLOG(1) << "LaunchOpHasKernelForDevice"
     52           << " kernel_class_name: " << kernel_class_name;
     53   return Status::OK();
     54 }
     55 
// Defaulted: all members manage their own lifetime.
XlaOpRegistry::XlaOpRegistry() = default;
XlaOpRegistry::~XlaOpRegistry() = default;
     58 
     59 // TODO(b/64575122) consider adding more sophisticated definitions of
     60 // compatibility if needed by future use cases.
     61 /* static */ bool XlaOpRegistry::IsCompatible(const OpRegistration& x,
     62                                               const OpRegistration& y) {
     63   if (x.name != y.name) return true;
     64   // The registrations refer to the same Op: ensures they are compatible and
     65   // are restricted to different device whitelists.
     66   if (x.compilation_only != y.compilation_only) {
     67     LOG(WARNING) << "Registrations of " << x.name
     68                  << " have incompatible compilation_only settings.";
     69     return false;
     70   }
     71   if (x.allow_resource_types != y.allow_resource_types) {
     72     LOG(WARNING) << "Registrations of " << x.name
     73                  << " have incompatible allow_resource_types settings.";
     74     return false;
     75   }
     76   if (x.allow_variant_types != y.allow_variant_types) {
     77     LOG(WARNING) << "Registrations of " << x.name
     78                  << " have incompatible allow_variant_types settings.";
     79     return false;
     80   }
     81   if (!x.has_device_whitelist && !y.has_device_whitelist) {
     82     LOG(WARNING) << "Duplicate registrations of " << x.name
     83                  << "with no device whitelists.";
     84     return false;
     85   }
     86   if (x.has_device_whitelist && y.has_device_whitelist) {
     87     for (const auto& device : x.device_whitelist) {
     88       if (y.device_whitelist.count(device) != 0) {
     89         LOG(WARNING) << "Multiple registrations of " << x.name << " on device "
     90                      << device;
     91         return false;
     92       }
     93     }
     94   }
     95   if (x.compile_time_constant_inputs != y.compile_time_constant_inputs) {
     96     LOG(WARNING) << "Registrations of " << x.name
     97                  << " have incompatible compile time constant inputs.";
     98     return false;
     99   }
    100   if (x.is_metadata_op != y.is_metadata_op) {
    101     LOG(WARNING) << "Registrations of " << x.name
    102                  << " have incompatible values for is_metadata_op.";
    103     return false;
    104   }
    105   return true;
    106 }
    107 
    108 /* static */ void XlaOpRegistry::RegisterCompilationDevice(
    109     const string& device_name, const DeviceRegistration& registration) {
    110   XlaOpRegistry& registry = Instance();
    111   mutex_lock lock(registry.mutex_);
    112   auto result =
    113       registry.compilation_devices_.emplace(device_name, registration);
    114   CHECK(result.second || result.first->second.compilation_device_name ==
    115                              registration.compilation_device_name);
    116 }
    117 
    118 /* static */ void XlaOpRegistry::RegisterBackend(
    119     const string& compilation_device_name,
    120     absl::Span<const DataType> supported_types, BackendOpFilter op_filter) {
    121   XlaOpRegistry& registry = Instance();
    122   mutex_lock lock(registry.mutex_);
    123   auto result = registry.backends_.emplace(compilation_device_name, Backend());
    124   CHECK(result.second) << "Duplicate XLA backend registration "
    125                        << compilation_device_name;
    126   result.first->second.supported_types.insert(supported_types.begin(),
    127                                               supported_types.end());
    128   result.first->second.op_filter = op_filter;
    129 }
    130 
// Looks up the compilation-device registration for `device_name`. Returns
// false if no registration exists; otherwise sets *registration (pointer
// owned by the registry) and returns true.
/* static */ bool XlaOpRegistry::GetCompilationDevice(
    const string& device_name, const DeviceRegistration** registration) {
  XlaOpRegistry& registry = Instance();

  // Lazily register the CPU and GPU JIT devices the first time
  // GetCompilationDevice is called.
  // The immediately-invoked lambda assigned to a function-local static runs
  // exactly once per process.
  static void* registration_init = [&registry]() {
    MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
    bool cpu_global_jit = flags->tf_xla_cpu_global_jit;
    VLOG(2) << "tf_xla_cpu_global_jit = " << cpu_global_jit;

    mutex_lock lock(registry.mutex_);
    // Map DEVICE_CPU to the XLA CPU JIT device only if an XlaLaunch kernel
    // exists for the plain CPU device.
    if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_CPU)).ok()) {
      DeviceRegistration& registration =
          registry.compilation_devices_[DEVICE_CPU];
      registration.compilation_device_name = DEVICE_CPU_XLA_JIT;
      // CPU autoclustering is opt-in per node unless the global-JIT flag
      // (tf_xla_cpu_global_jit) was set.
      registration.autoclustering_policy =
          cpu_global_jit
              ? XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally
              : XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested;
      registration.compile_all_resource_ops = false;
    }
    if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_GPU)).ok()) {
      DeviceRegistration& registration =
          registry.compilation_devices_[DEVICE_GPU];
      registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
      registration.autoclustering_policy =
          XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally;
      registration.compile_all_resource_ops = false;
    }
    return nullptr;
  }();
  (void)registration_init;  // Only used to trigger the one-time init above.

  // Now answer the actual query under the registry lock.
  mutex_lock lock(registry.mutex_);
  auto it = registry.compilation_devices_.find(device_name);
  if (it == registry.compilation_devices_.end()) return false;
  *registration = &it->second;
  return true;
}
    171 
// Builds and registers TensorFlow KernelDefs for every (op, backend) pair
// recorded in the registry. Idempotent: only the first call does work.
void XlaOpRegistry::RegisterCompilationKernels() {
  XlaOpRegistry& registry = Instance();
  mutex_lock lock(registry.mutex_);

  // One-shot guard: kernels are registered at most once per process.
  if (registry.jit_kernels_registered_) return;
  registry.jit_kernels_registered_ = true;

  OpRegistryInterface* op_registry = OpRegistry::Global();
  // Order of op registration:
  // The goal is to allow the co-existence of backend-specific kernels and
  // generic kernels. To achieve this, we enforce the following order of
  // registrations for one op:
  // 1. Process op registration with device whitelists:
  //      this pass registers backend-specific kernels for this op.
  // 2. Process op registration without device whitelists:
  //      this pass registers the kernels for all the other supported backends.
  for (auto& ops : registry.ops_) {
    const string& op_name = ops.first;
    std::vector<std::unique_ptr<OpRegistration>>& op_registrations = ops.second;
    // Partition the op registration so that the ones with device whitelists
    // precede the one without device whitelist.
    std::partition(op_registrations.begin(), op_registrations.end(),
                   [](const std::unique_ptr<OpRegistration>& op_reg) {
                     return op_reg->has_device_whitelist;
                   });

    // Collect a set of backend registered by ops with device whitelists.
    // The op registration without whitelists will register a generic kernel
    // for all other backends not in this set.
    std::unordered_set<string> whitelisted_backend;
    for (auto& op_registration : op_registrations) {
      if (op_registration->has_device_whitelist) {
        whitelisted_backend.insert(op_registration->device_whitelist.begin(),
                                   op_registration->device_whitelist.end());
      }
    }

    for (auto& op_registration : op_registrations) {
      const OpDef* op_def;
      Status lookup_status = op_registry->LookUpOpDef(op_name, &op_def);
      if (!lookup_status.ok()) {
        // Dump the registered ops before crashing to aid debugging.
        LOG(ERROR) << lookup_status.error_message();
        XLA_LOG_LINES(
            ERROR,
            "Ops registered: \n" +
                dynamic_cast<OpRegistry*>(op_registry)->DebugString(true));
      }
      TF_CHECK_OK(lookup_status);

      // Collect the names of all attributes of the op that carry types
      // ("type" or "list(type)"); constraints below apply to these.
      std::unordered_set<string> type_attrs;
      for (const OpDef::AttrDef& attr_def : op_def->attr()) {
        if (attr_def.type() == "type" || attr_def.type() == "list(type)") {
          type_attrs.insert(attr_def.name());
        }
      }

      // Checks there are no type constraints referring to unknown attributes.
      for (const auto& constraint : op_registration->type_constraints) {
        if (type_attrs.find(constraint.first) == type_attrs.end()) {
          LOG(FATAL) << "Unknown type attribute " << constraint.first
                     << " in XLA op registration for " << op_name;
        }
      }

      for (auto& backend : registry.backends_) {
        // If the operator has a device whitelist, only register on whitelisted
        // devices.
        if (op_registration->has_device_whitelist &&
            op_registration->device_whitelist.find(backend.first) ==
                op_registration->device_whitelist.end()) {
          continue;
        }

        // If the operator does NOT has a device whitelist, skip all devices
        // that has already been registered.
        if (!op_registration->has_device_whitelist &&
            whitelisted_backend.find(backend.first) !=
                whitelisted_backend.end()) {
          continue;
        }

        std::unique_ptr<KernelDef> kdef(new KernelDef);
        kdef->set_op(op_registration->name);
        kdef->set_device_type(backend.first);

        // Constrain each type attribute to the intersection of:
        // a) the types supported by the backend, and
        // b) the types allowed by the OpDef, and
        // c) the type constraints.
        bool unsatisfiable_type_constraint = false;
        for (const string& type_attr : type_attrs) {
          KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint();
          attr_constraint->set_name(type_attr);
          auto* allowed_values =
              attr_constraint->mutable_allowed_values()->mutable_list();

          const OpDef::AttrDef& op_def_attr = *FindAttr(type_attr, *op_def);
          const auto* op_def_allowed_types =
              op_def_attr.has_allowed_values()
                  ? &op_def_attr.allowed_values().list().type()
                  : nullptr;
          auto constraint_it =
              op_registration->type_constraints.find(type_attr);
          const std::set<DataType>* type_constraints =
              constraint_it != op_registration->type_constraints.end()
                  ? &constraint_it->second
                  : nullptr;
          for (DataType dtype : backend.second.supported_types) {
            // Filter out types that aren't allowed by the OpDef.
            if (op_def_allowed_types != nullptr &&
                std::find(op_def_allowed_types->begin(),
                          op_def_allowed_types->end(),
                          dtype) == op_def_allowed_types->end()) {
              continue;
            }
            // Filter out types based on the type constraints.
            if (type_constraints != nullptr &&
                type_constraints->find(dtype) == type_constraints->end()) {
              continue;
            }
            // Passed all the filters, this type is allowed.
            allowed_values->add_type(dtype);
          }
          // DT_RESOURCE / DT_VARIANT are admitted unconditionally when the
          // registration opted in, independent of backend-supported types.
          if (op_registration->allow_resource_types) {
            allowed_values->add_type(DT_RESOURCE);
          }
          if (op_registration->allow_variant_types) {
            allowed_values->add_type(DT_VARIANT);
          }
          // Don't build KernelDefs that have unsatisfiable type constraints.
          if (allowed_values->type().empty()) {
            unsatisfiable_type_constraint = true;
            break;
          }
        }
        if (unsatisfiable_type_constraint) continue;

        // Give the backend a final veto over this kernel.
        if (backend.second.op_filter != nullptr &&
            !backend.second.op_filter(kdef.get())) {
          continue;
        }
        VLOG(2) << "XLA op registration: device: " << backend.first
                << " op: " << op_name;
        registry.kernel_registrars_.emplace_back(
            new kernel_factory::OpKernelRegistrar(
                new KernelDef(*kdef), "XlaJitOp", op_registration->factory));
        backend.second.kernel_defs.push_back(std::move(kdef));
      }
    }
  }
}
    323 
    324 std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
    325     const string& compilation_device_name,
    326     bool include_compilation_only_kernels) {
    327   // Ensure compilation kernels registered.
    328   RegisterCompilationKernels();
    329   std::vector<const KernelDef*> kernels;
    330   XlaOpRegistry& registry = Instance();
    331   mutex_lock lock(registry.mutex_);
    332   auto it = registry.backends_.find(compilation_device_name);
    333   CHECK(it != registry.backends_.end())
    334       << "Unknown backend " << compilation_device_name;
    335   for (const std::unique_ptr<KernelDef>& k : it->second.kernel_defs) {
    336     auto op_iter = registry.ops_.find(k->op());
    337     CHECK(op_iter != registry.ops_.end() && !op_iter->second.empty());
    338     // The test in IsCompatible ensures that if there are multiple matching
    339     // registrations for this op name, they all have the same value of
    340     // compilation_only, so only the first match needs to be tested.
    341     if (include_compilation_only_kernels ||
    342         !op_iter->second.front()->compilation_only) {
    343       kernels.push_back(k.get());
    344     }
    345   }
    346   return kernels;
    347 }
    348 
    349 /*static*/ std::vector<string> XlaOpRegistry::GetAllRegisteredOps() {
    350   std::vector<string> ops;
    351   XlaOpRegistry& registry = Instance();
    352   mutex_lock lock(registry.mutex_);
    353   for (const auto& pair : registry.ops_) {
    354     ops.push_back(pair.first);
    355   }
    356   std::sort(ops.begin(), ops.end());
    357   return ops;
    358 }
    359 
// Fills `result` with the (sorted) indices of the inputs of `node_def` that
// must be compile-time constants. The constant-input names come either from
// the node's kXlaCompileTimeConstantInputsAttr attribute or, failing that,
// from the op's XLA registration. Name-to-index mapping uses `op_def` when
// provided, otherwise `op_kernel`.
/* static */ Status XlaOpRegistry::CompileTimeConstantInputs(
    const NodeDef& node_def, const OpKernel* op_kernel, const OpDef* op_def,
    std::vector<int>* result) {
  result->clear();

  // At least one of op_def / op_kernel must be supplied for input lookup.
  DCHECK(op_def != nullptr || op_kernel != nullptr);

  std::unordered_set<string> compile_time_constant_inputs_from_attr;
  std::vector<string> compile_time_constant_inputs_vect_from_attr;

  const std::unordered_set<string>* compile_time_constant_inputs;

  // Prefer the per-node attribute when it is present.
  if (GetNodeAttr(node_def, kXlaCompileTimeConstantInputsAttr,
                  &compile_time_constant_inputs_vect_from_attr)
          .ok()) {
    absl::c_copy(compile_time_constant_inputs_vect_from_attr,
                 std::inserter(compile_time_constant_inputs_from_attr,
                               compile_time_constant_inputs_from_attr.end()));
    compile_time_constant_inputs = &compile_time_constant_inputs_from_attr;
  } else {
    // Otherwise consult the registry entry for this op, if any.
    const string& op = node_def.op();

    XlaOpRegistry& registry = Instance();
    mutex_lock lock(registry.mutex_);
    auto it = registry.ops_.find(op);
    if (it == registry.ops_.end() || it->second.empty()) {
      // No registration: no constant inputs; result stays empty.
      return Status::OK();
    } else {
      // The test in IsCompatible ensures that if there are multiple matching
      // registrations for this op name, they all have the same value of
      // compile_time_constant_inputs, so only the first match is returned.
      //
      // TODO(sanjoy): This can probably be a std::vector<string>.
      compile_time_constant_inputs =
          &it->second.front()->compile_time_constant_inputs;
    }
  }

  // Translate each constant input name into its index range.
  for (const string& input : *compile_time_constant_inputs) {
    if (op_def) {
      NameRangeMap input_name_ranges;
      TF_RETURN_IF_ERROR(
          NameRangesForNode(node_def, *op_def, &input_name_ranges, nullptr));
      auto name_range = input_name_ranges.find(input);
      if (name_range == input_name_ranges.end()) {
        // Name not present on this node; skip it.
        continue;
      }

      // List inputs expand to a half-open range of indices.
      for (int i = name_range->second.first; i < name_range->second.second;
           i++) {
        result->push_back(i);
      }
    } else {
      int start, stop;
      TF_CHECK_OK(op_kernel->InputRange(input, &start, &stop));
      for (int i = start; i < stop; ++i) {
        result->push_back(i);
      }
    }
  }

  absl::c_sort(*result);
  return Status::OK();
}
    424 
    425 /*static*/ bool XlaOpRegistry::IsMetadataOp(const string& op) {
    426   XlaOpRegistry& registry = Instance();
    427   mutex_lock lock(registry.mutex_);
    428   auto it = registry.ops_.find(op);
    429   if (it == registry.ops_.end() || it->second.empty()) {
    430     return false;
    431   }
    432 
    433   // The test in IsCompatible ensures that if there are multiple matching
    434   // registrations for this op name, they all have the same value of
    435   // is_metadata_op, so only the first match is returned.
    436   return it->second.front()->is_metadata_op;
    437 }
    438 
    439 std::vector<string> XlaOpRegistry::BackendNames() {
    440   std::vector<string> names;
    441   XlaOpRegistry& registry = Instance();
    442   mutex_lock lock(registry.mutex_);
    443   for (const auto& backend_pair : registry.backends_) {
    444     names.push_back(backend_pair.first);
    445   }
    446   return names;
    447 }
    448 
    449 bool XlaOpRegistry::IsBackendRegistered(const string& name) {
    450   XlaOpRegistry& registry = Instance();
    451   mutex_lock lock(registry.mutex_);
    452   return registry.backends_.find(name) != registry.backends_.end();
    453 }
    454 
// Returns the process-wide registry singleton. The instance is heap-allocated
// on first use and never deleted.
XlaOpRegistry& XlaOpRegistry::Instance() {
  static XlaOpRegistry* r = new XlaOpRegistry;
  return *r;
}
    459 
    460 XlaOpRegistrationBuilder::XlaOpRegistrationBuilder(absl::string_view name) {
    461   registration_.reset(new XlaOpRegistry::OpRegistration);
    462   registration_->name = string(name);
    463 }
    464 
    465 XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name(
    466     absl::string_view name) {
    467   XlaOpRegistrationBuilder registration(name);
    468   return registration;
    469 }
    470 
    471 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(
    472     absl::Span<const absl::string_view> devices) {
    473   registration_->has_device_whitelist = true;
    474   for (absl::string_view device : devices) {
    475     registration_->device_whitelist.emplace(device);
    476   }
    477   return *this;
    478 }
    479 
    480 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(
    481     absl::string_view device) {
    482   registration_->has_device_whitelist = true;
    483   registration_->device_whitelist.emplace(device);
    484   return *this;
    485 }
    486 
// Marks the op as compilation-only; DeviceKernels() omits such kernels
// unless include_compilation_only_kernels is requested.
XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompilationOnly() {
  registration_->compilation_only = true;
  return *this;
}

// Additionally admits DT_RESOURCE in the op's type-attribute constraints
// (see RegisterCompilationKernels).
XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::AllowResourceTypes() {
  registration_->allow_resource_types = true;
  return *this;
}

// Additionally admits DT_VARIANT in the op's type-attribute constraints.
XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::AllowVariantTypes() {
  registration_->allow_variant_types = true;
  return *this;
}
    501 
    502 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint(
    503     absl::string_view attr_name, DataType allowed) {
    504   std::set<DataType>& types =
    505       registration_->type_constraints[string(attr_name)];
    506   types.insert(allowed);
    507   return *this;
    508 }
    509 
    510 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint(
    511     absl::string_view attr_name, absl::Span<const DataType> allowed) {
    512   std::set<DataType>& types =
    513       registration_->type_constraints[string(attr_name)];
    514   for (DataType t : allowed) {
    515     types.insert(t);
    516   }
    517   return *this;
    518 }
    519 
// Declares that input `input_name` must be a compile-time constant
// (surfaced via XlaOpRegistry::CompileTimeConstantInputs).
XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstantInput(
    absl::string_view input_name) {
  registration_->compile_time_constant_inputs.emplace(input_name);
  return *this;
}

// Marks the op as a metadata op (queried via XlaOpRegistry::IsMetadataOp).
XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::IsMetadataOp() {
  registration_->is_metadata_op = true;
  return *this;
}
    530 
    531 std::unique_ptr<XlaOpRegistry::OpRegistration> XlaOpRegistrationBuilder::Build(
    532     XlaOpRegistry::Factory factory) {
    533   registration_->factory = factory;
    534   return std::move(registration_);
    535 }
    536 
    537 XlaOpRegistrar::XlaOpRegistrar(
    538     std::unique_ptr<XlaOpRegistry::OpRegistration> registration) {
    539   XlaOpRegistry& registry = XlaOpRegistry::Instance();
    540   mutex_lock lock(registry.mutex_);
    541   auto& existing_ops = registry.ops_[registration->name];
    542   for (auto& existing : existing_ops) {
    543     if (!XlaOpRegistry::IsCompatible(*existing, *registration)) {
    544       LOG(FATAL)
    545           << "XLA op registration " << registration->name
    546           << " is incompatible with existing registration of the same name.";
    547     }
    548   }
    549   existing_ops.emplace_back(std::move(registration));
    550 }
    551 
    552 XlaBackendRegistrar::XlaBackendRegistrar(
    553     absl::string_view name, absl::Span<const DataType> types,
    554     XlaOpRegistry::BackendOpFilter op_filter) {
    555   XlaOpRegistry& registry = XlaOpRegistry::Instance();
    556   registry.RegisterBackend(string(name), types, op_filter);
    557 }
    558 
    559 }  // namespace tensorflow
    560