1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 17 18 #include <functional> 19 #include <memory> 20 21 #include "tensorflow/compiler/tf2xla/type_util.h" 22 #include "tensorflow/compiler/tf2xla/xla_context.h" 23 #include "tensorflow/compiler/xla/client/client_library.h" 24 #include "tensorflow/core/common_runtime/device_factory.h" 25 #include "tensorflow/core/common_runtime/local_device.h" 26 #include "tensorflow/core/framework/device_base.h" 27 #include "tensorflow/core/framework/kernel_def.pb.h" 28 #include "tensorflow/core/framework/node_def.pb.h" 29 #include "tensorflow/core/framework/op_def_util.h" 30 #include "tensorflow/core/platform/mem.h" 31 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 32 33 namespace tensorflow { 34 35 const char* const DEVICE_CPU_XLA_JIT = "XLA_CPU_JIT"; 36 const char* const DEVICE_GPU_XLA_JIT = "XLA_GPU_JIT"; 37 const char* const DEVICE_XLA_CPU = "XLA_CPU"; 38 const char* const DEVICE_XLA_GPU = "XLA_GPU"; 39 40 static Status LaunchOpHasKernelForDevice(const DeviceType& device_type) { 41 const OpDef* op_def; 42 TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef("_XlaLaunch", &op_def)); 43 NodeDef node_def; 44 node_def.set_name("_XlaLaunch-op"); 45 node_def.set_op("_XlaLaunch"); 46 string kernel_class_name; 47 TF_RETURN_IF_ERROR(FindKernelDef(device_type, node_def, /*KernelDef*/ nullptr, 48 &kernel_class_name)); 49 VLOG(1) << "LaunchOpHasKernelForDevice" 50 << " kernel_class_name: " << kernel_class_name; 51 return Status::OK(); 52 } 53 54 XlaOpRegistry::XlaOpRegistry() = default; 55 XlaOpRegistry::~XlaOpRegistry() = default; 56 57 // TODO(b/64575122) consider adding more sophisticated definitions of 58 // compatibility if needed by future use cases. 59 /* static */ bool XlaOpRegistry::IsCompatible(const OpRegistration& x, 60 const OpRegistration& y) { 61 if (x.name != y.name) return true; 62 // The registrations refer to the same Op: ensures they are compatible and 63 // are restricted to different device whitelists. 64 if (x.compilation_only != y.compilation_only) { 65 LOG(WARNING) << "Registrations of " << x.name 66 << " have incompatible compilation_only settings."; 67 return false; 68 } 69 if (x.allow_resource_types != y.allow_resource_types) { 70 LOG(WARNING) << "Registrations of " << x.name 71 << " have incompatible allow_resource_types settings."; 72 return false; 73 } 74 if (!x.has_device_whitelist || !y.has_device_whitelist) { 75 LOG(WARNING) << "Registrations of " << x.name 76 << " do not both have device whitelists."; 77 return false; 78 } 79 for (const auto& device : x.device_whitelist) { 80 if (y.device_whitelist.count(device) != 0) { 81 LOG(WARNING) << "Multiple registrations of " << x.name << " on device " 82 << device; 83 return false; 84 } 85 } 86 if (x.compile_time_constant_inputs != y.compile_time_constant_inputs) { 87 LOG(WARNING) << "Registrations of " << x.name 88 << " have incompatible compile time constant inputs."; 89 return false; 90 } 91 return true; 92 } 93 94 /* static */ void XlaOpRegistry::RegisterCompilationDevice( 95 const string& device_name, const DeviceRegistration& registration) { 96 XlaOpRegistry& registry = Instance(); 97 mutex_lock lock(registry.mutex_); 98 auto result = 99 registry.compilation_devices_.emplace(device_name, registration); 100 CHECK(result.second || result.first->second.compilation_device_name == 101 registration.compilation_device_name); 102 } 103 104 /* static */ void XlaOpRegistry::RegisterBackend( 105 const string& compilation_device_name, 106 gtl::ArraySlice<DataType> supported_types, BackendOpFilter op_filter) { 107 XlaOpRegistry& registry = Instance(); 108 mutex_lock lock(registry.mutex_); 109 auto result = registry.backends_.emplace(compilation_device_name, Backend()); 110 CHECK(result.second) << "Duplicate XLA backend registration " 111 << compilation_device_name; 112 result.first->second.supported_types.insert(supported_types.begin(), 113 supported_types.end()); 114 result.first->second.op_filter = op_filter; 115 } 116 117 /* static */ bool XlaOpRegistry::GetCompilationDevice( 118 const string& device_name, const DeviceRegistration** registration) { 119 XlaOpRegistry& registry = Instance(); 120 121 // Lazily register the CPU and GPU JIT devices the first time 122 // GetCompilationDevice is called. 123 static void* registration_init = [®istry]() { 124 mutex_lock lock(registry.mutex_); 125 if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_CPU)).ok()) { 126 DeviceRegistration& registration = 127 registry.compilation_devices_[DEVICE_CPU]; 128 registration.compilation_device_name = DEVICE_CPU_XLA_JIT; 129 registration.requires_compilation = false; 130 registration.enable_jit_by_default = false; 131 registration.compile_resource_ops = false; 132 } 133 if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_GPU)).ok()) { 134 DeviceRegistration& registration = 135 registry.compilation_devices_[DEVICE_GPU]; 136 registration.compilation_device_name = DEVICE_GPU_XLA_JIT; 137 registration.requires_compilation = false; 138 registration.enable_jit_by_default = true; 139 registration.compile_resource_ops = false; 140 } 141 return nullptr; 142 }(); 143 (void)registration_init; 144 145 mutex_lock lock(registry.mutex_); 146 auto it = registry.compilation_devices_.find(device_name); 147 if (it == registry.compilation_devices_.end()) return false; 148 *registration = &it->second; 149 return true; 150 } 151 152 void XlaOpRegistry::RegisterCompilationKernels() { 153 XlaOpRegistry& registry = Instance(); 154 mutex_lock lock(registry.mutex_); 155 156 if (registry.jit_kernels_registered_) return; 157 registry.jit_kernels_registered_ = true; 158 159 OpRegistryInterface* op_registry = OpRegistry::Global(); 160 for (const auto& op : registry.ops_) { 161 const string& op_name = op.first; 162 const std::unique_ptr<OpRegistration>& op_registration = op.second; 163 const OpDef* op_def; 164 Status lookup_status = op_registry->LookUpOpDef(op_name, &op_def); 165 if (!lookup_status.ok()) { 166 LOG(ERROR) << lookup_status.error_message(); 167 XLA_LOG_LINES( 168 ERROR, "Ops registered: \n" + 169 dynamic_cast<OpRegistry*>(op_registry)->DebugString(true)); 170 } 171 TF_CHECK_OK(lookup_status); 172 173 std::unordered_set<string> type_attrs; 174 for (const OpDef::AttrDef& attr_def : op_def->attr()) { 175 if (attr_def.type() == "type" || attr_def.type() == "list(type)") { 176 type_attrs.insert(attr_def.name()); 177 } 178 } 179 180 // Checks there are no type constraints referring to unknown attributes. 181 for (const auto& constraint : op_registration->type_constraints) { 182 if (type_attrs.find(constraint.first) == type_attrs.end()) { 183 LOG(FATAL) << "Unknown type attribute " << constraint.first 184 << " in XLA op registration for " << op_name; 185 } 186 } 187 188 for (auto& backend : registry.backends_) { 189 // If the operator has a device whitelist, only register on whitelisted 190 // devices. 191 if (op_registration->has_device_whitelist && 192 op_registration->device_whitelist.find(backend.first) == 193 op_registration->device_whitelist.end()) { 194 continue; 195 } 196 197 std::unique_ptr<KernelDef> kdef(new KernelDef); 198 kdef->set_op(op_registration->name); 199 kdef->set_device_type(backend.first); 200 201 // Constrain each type attribute to the intersection of: 202 // a) the types supported by the backend, and 203 // b) the types allowed by the OpDef, and 204 // c) the type constraints. 205 for (const string& type_attr : type_attrs) { 206 KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint(); 207 attr_constraint->set_name(type_attr); 208 auto* allowed_values = 209 attr_constraint->mutable_allowed_values()->mutable_list(); 210 211 const OpDef::AttrDef& op_def_attr = *FindAttr(type_attr, *op_def); 212 const auto* op_def_allowed_types = 213 op_def_attr.has_allowed_values() 214 ? &op_def_attr.allowed_values().list().type() 215 : nullptr; 216 auto constraint_it = op_registration->type_constraints.find(type_attr); 217 const std::set<DataType>* type_constraints = 218 constraint_it != op_registration->type_constraints.end() 219 ? &constraint_it->second 220 : nullptr; 221 for (DataType dtype : backend.second.supported_types) { 222 // Filter out types that aren't allowed by the OpDef. 223 if (op_def_allowed_types != nullptr && 224 std::find(op_def_allowed_types->begin(), 225 op_def_allowed_types->end(), 226 dtype) == op_def_allowed_types->end()) { 227 continue; 228 } 229 // Filter out types based on the type constraints. 230 if (type_constraints != nullptr && 231 type_constraints->find(dtype) == type_constraints->end()) { 232 continue; 233 } 234 // Passed all the filters, this type is allowed. 235 allowed_values->add_type(dtype); 236 } 237 if (op_registration->allow_resource_types) { 238 allowed_values->add_type(DT_RESOURCE); 239 } 240 } 241 if (backend.second.op_filter != nullptr && 242 !backend.second.op_filter(kdef.get())) { 243 continue; 244 } 245 VLOG(2) << "XLA op registration: device: " << backend.first 246 << " op: " << op_name; 247 registry.kernel_registrars_.emplace_back( 248 new kernel_factory::OpKernelRegistrar( 249 new KernelDef(*kdef), "XlaJitOp", op_registration->factory)); 250 backend.second.kernel_defs.push_back(std::move(kdef)); 251 } 252 } 253 } 254 255 std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels( 256 const string& compilation_device_name, 257 bool include_compilation_only_kernels) { 258 // Ensure compilation kernels registered. 259 RegisterCompilationKernels(); 260 std::vector<const KernelDef*> kernels; 261 XlaOpRegistry& registry = Instance(); 262 mutex_lock lock(registry.mutex_); 263 auto it = registry.backends_.find(compilation_device_name); 264 CHECK(it != registry.backends_.end()) 265 << "Unknown backend " << compilation_device_name; 266 for (const std::unique_ptr<KernelDef>& k : it->second.kernel_defs) { 267 auto op_iter = registry.ops_.find(k->op()); 268 CHECK(op_iter != registry.ops_.end()); 269 // The test in IsCompatible ensures that if there are multiple matching 270 // registrations for this op name, they all have the same value of 271 // compilation_only, so only the first match needs to be tested. 272 if (include_compilation_only_kernels || 273 !op_iter->second->compilation_only) { 274 kernels.push_back(k.get()); 275 } 276 } 277 return kernels; 278 } 279 280 /* static */ const std::unordered_set<string>* 281 XlaOpRegistry::CompileTimeConstantInputs(const string& op) { 282 XlaOpRegistry& registry = Instance(); 283 mutex_lock lock(registry.mutex_); 284 auto it = registry.ops_.find(op); 285 if (it == registry.ops_.end()) { 286 return nullptr; 287 } 288 return &it->second->compile_time_constant_inputs; 289 } 290 291 std::vector<string> XlaOpRegistry::BackendNames() { 292 std::vector<string> names; 293 XlaOpRegistry& registry = Instance(); 294 mutex_lock lock(registry.mutex_); 295 for (const auto& backend_pair : registry.backends_) { 296 names.push_back(backend_pair.first); 297 } 298 return names; 299 } 300 301 bool XlaOpRegistry::IsBackendRegistered(const string& name) { 302 XlaOpRegistry& registry = Instance(); 303 mutex_lock lock(registry.mutex_); 304 return registry.backends_.find(name) != registry.backends_.end(); 305 } 306 307 XlaOpRegistry& XlaOpRegistry::Instance() { 308 static XlaOpRegistry* r = new XlaOpRegistry; 309 return *r; 310 } 311 312 XlaOpRegistrationBuilder::XlaOpRegistrationBuilder(StringPiece name) { 313 registration_.reset(new XlaOpRegistry::OpRegistration); 314 registration_->name = name.ToString(); 315 } 316 317 XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name(StringPiece name) { 318 XlaOpRegistrationBuilder registration(name); 319 return registration; 320 } 321 322 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device( 323 gtl::ArraySlice<StringPiece> devices) { 324 registration_->has_device_whitelist = true; 325 for (StringPiece device : devices) { 326 registration_->device_whitelist.insert(device.ToString()); 327 } 328 return *this; 329 } 330 331 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(StringPiece device) { 332 registration_->has_device_whitelist = true; 333 registration_->device_whitelist.insert(device.ToString()); 334 return *this; 335 } 336 337 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompilationOnly() { 338 registration_->compilation_only = true; 339 return *this; 340 } 341 342 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::AllowResourceTypes() { 343 registration_->allow_resource_types = true; 344 return *this; 345 } 346 347 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( 348 StringPiece attr_name, DataType allowed) { 349 std::set<DataType>& types = 350 registration_->type_constraints[attr_name.ToString()]; 351 types.insert(allowed); 352 return *this; 353 } 354 355 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( 356 StringPiece attr_name, gtl::ArraySlice<DataType> allowed) { 357 std::set<DataType>& types = 358 registration_->type_constraints[attr_name.ToString()]; 359 for (DataType t : allowed) { 360 types.insert(t); 361 } 362 return *this; 363 } 364 365 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstInput( 366 StringPiece input_name) { 367 registration_->compile_time_constant_inputs.insert(input_name.ToString()); 368 return *this; 369 } 370 371 std::unique_ptr<XlaOpRegistry::OpRegistration> XlaOpRegistrationBuilder::Build( 372 XlaOpRegistry::Factory factory) { 373 registration_->factory = factory; 374 return std::move(registration_); 375 } 376 377 XlaOpRegistrar::XlaOpRegistrar( 378 std::unique_ptr<XlaOpRegistry::OpRegistration> registration) { 379 XlaOpRegistry& registry = XlaOpRegistry::Instance(); 380 mutex_lock lock(registry.mutex_); 381 auto existing_ops = registry.ops_.equal_range(registration->name); 382 for (auto existing = existing_ops.first; existing != existing_ops.second; 383 ++existing) { 384 if (!XlaOpRegistry::IsCompatible(*existing->second, *registration)) { 385 LOG(FATAL) 386 << "XLA op registration " << registration->name 387 << " is incompatible with existing registration of the same name."; 388 } 389 } 390 registry.ops_.emplace(registration->name, std::move(registration)); 391 } 392 393 XlaBackendRegistrar::XlaBackendRegistrar( 394 StringPiece name, gtl::ArraySlice<DataType> types, 395 XlaOpRegistry::BackendOpFilter op_filter) { 396 XlaOpRegistry& registry = XlaOpRegistry::Instance(); 397 registry.RegisterBackend(name.ToString(), types, op_filter); 398 } 399 400 } // namespace tensorflow 401