1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/jit/xla_device.h" 17 18 #include <stdlib.h> 19 #include <unordered_set> 20 21 #include "tensorflow/compiler/jit/defs.h" 22 #include "tensorflow/compiler/jit/xla_device_context.h" 23 #include "tensorflow/compiler/jit/xla_device_ops.h" 24 #include "tensorflow/compiler/tf2xla/dump_graph.h" 25 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 26 #include "tensorflow/compiler/xla/client/client_library.h" 27 #include "tensorflow/core/common_runtime/device.h" 28 #include "tensorflow/core/common_runtime/device_factory.h" 29 #include "tensorflow/core/common_runtime/dma_helper.h" 30 #include "tensorflow/core/common_runtime/function.h" 31 #include "tensorflow/core/common_runtime/renamed_device.h" 32 #include "tensorflow/core/framework/allocator.h" 33 #include "tensorflow/core/framework/device_base.h" 34 #include "tensorflow/core/framework/function.h" 35 #include "tensorflow/core/framework/kernel_def.pb.h" 36 #include "tensorflow/core/framework/node_def_builder.h" 37 #include "tensorflow/core/framework/op_kernel.h" 38 #include "tensorflow/core/framework/tensor.h" 39 #include "tensorflow/core/framework/tensor.pb.h" 40 #include "tensorflow/core/framework/types.h" 41 #include "tensorflow/core/graph/graph_constructor.h" 42 #include "tensorflow/core/lib/core/notification.h" 43 #include "tensorflow/core/lib/core/status.h" 44 #include "tensorflow/core/platform/logging.h" 45 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 46 #include "tensorflow/core/platform/tracing.h" 47 #include "tensorflow/core/public/session_options.h" 48 #include "tensorflow/core/public/version.h" 49 #include "tensorflow/core/util/device_name_utils.h" 50 #include "tensorflow/core/util/stream_executor_util.h" 51 52 namespace se = ::perftools::gputools; 53 54 namespace tensorflow { 55 56 // Caches a XlaDeviceAllocator per <backend, device ordinal> pair. A 57 // XlaDeviceAllocator is created on demand and is associated with a 58 // XlaDevice. It outlives the device itself (for instance, the buffer 59 // backing a tensor holds a pointer to the allocator for book-keeping, 60 // and this buffer can outlast the device). 61 class XlaDeviceAllocatorState { 62 public: 63 // Creates or returns a cached XlaDeviceAllocator for a given 64 // backend and device_ordinal. 65 static XlaDeviceAllocator* GetOrCreateXlaDeviceAllocator( 66 const xla::Backend* backend, int device_ordinal); 67 68 private: 69 // Returns the singleton instance of XlaDeviceAllocatorState. 70 static XlaDeviceAllocatorState& Singleton(); 71 XlaDeviceAllocatorState(); 72 ~XlaDeviceAllocatorState(); 73 74 mutex allocator_mutex_; // Guards the singleton allocator state. 75 std::unordered_map<std::pair<const xla::Backend*, int>, 76 std::unique_ptr<XlaDeviceAllocator>, 77 hash<std::pair<const xla::Backend*, int>>> 78 allocators_ GUARDED_BY(allocator_mutex_); 79 80 TF_DISALLOW_COPY_AND_ASSIGN(XlaDeviceAllocatorState); 81 }; 82 83 /* static */ XlaDeviceAllocatorState& XlaDeviceAllocatorState::Singleton() { 84 static auto a = new XlaDeviceAllocatorState; 85 return *a; 86 } 87 88 XlaDeviceAllocatorState::XlaDeviceAllocatorState() = default; 89 XlaDeviceAllocatorState::~XlaDeviceAllocatorState() = default; 90 91 XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( 92 const xla::Backend* backend, int device_ordinal) { 93 XlaDeviceAllocatorState& state = Singleton(); 94 mutex_lock lock(state.allocator_mutex_); 95 96 auto it = state.allocators_.find({backend, device_ordinal}); 97 if (it != state.allocators_.end()) { 98 return it->second.get(); 99 } 100 101 std::unique_ptr<XlaDeviceAllocator> alloc = 102 xla::MakeUnique<XlaDeviceAllocator>(backend, device_ordinal); 103 XlaDeviceAllocator* alloc_ptr = alloc.get(); 104 state.allocators_[{backend, device_ordinal}] = std::move(alloc); 105 return alloc_ptr; 106 } 107 108 /* static */ Status XlaDevice::Create( 109 const string& platform_name, const string& device_name, int device_ordinal, 110 const string& jit_device_name, const SessionOptions& options, 111 const string& name_prefix, bool register_device_for_compilation, 112 std::unique_ptr<XlaDevice>* device) { 113 VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":" 114 << device_ordinal; 115 116 if (register_device_for_compilation) { 117 // These are no-ops if they have already been done previously for 118 // this device_name/compilation_device_name pair. 119 XlaOpRegistry::DeviceRegistration registration; 120 registration.compilation_device_name = jit_device_name; 121 registration.requires_compilation = true; 122 registration.enable_jit_by_default = false; 123 registration.compile_resource_ops = true; 124 XlaOpRegistry::RegisterCompilationDevice(device_name, registration); 125 } 126 127 auto platform = se::MultiPlatformManager::PlatformWithName(platform_name); 128 if (!platform.ok()) { 129 return StreamExecutorUtil::ConvertStatus(platform.status()); 130 } 131 132 const DeviceAttributes attrs = Device::BuildDeviceAttributes( 133 strings::StrCat(name_prefix, "/device:", device_name, ":", 134 device_ordinal), 135 DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(), 136 strings::StrCat("device: ", device_name, " device")); 137 138 device->reset(new XlaDevice(options, attrs, device_ordinal, 139 DeviceType(jit_device_name), 140 platform.ValueOrDie())); 141 return Status::OK(); 142 } 143 144 XlaDevice::Metadata::Metadata(int device_ordinal, se::Platform* platform, 145 const DeviceType& device_type) 146 : device_ordinal_(device_ordinal), 147 device_type_(device_type), 148 platform_(platform) {} 149 150 int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; } 151 152 se::Platform* XlaDevice::Metadata::platform() const { return platform_; } 153 154 xla::LocalClient* XlaDevice::Metadata::client() const { 155 auto client = xla::ClientLibrary::GetOrCreateLocalClient(platform_); 156 return client.ValueOrDie(); 157 } 158 159 const DeviceType& XlaDevice::Metadata::jit_device_type() const { 160 return device_type_; 161 } 162 163 /* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx, 164 const Metadata** metadata) { 165 XlaDevice* xla_device = 166 dynamic_cast<XlaDevice*>(ctx->device()->UnderlyingDevice()); 167 if (xla_device == nullptr) { 168 return errors::Internal( 169 "Cannot get XLA metadata from non-XLA device \"", ctx->device()->name(), 170 "\". GetMetadata must only be called on an XLA device. Either an " 171 "internal bug has been triggered, or an XLA-specific op has been " 172 "placed on the wrong device."); 173 } 174 *metadata = &(xla_device->xla_metadata_); 175 return Status::OK(); 176 } 177 178 XlaDevice::XlaDevice(const SessionOptions& options, 179 const DeviceAttributes& attrs, int device_ordinal, 180 const DeviceType& jit_device_name, se::Platform* platform) 181 : LocalDevice(options, attrs), 182 xla_metadata_(device_ordinal, platform, jit_device_name), 183 device_ordinal_(device_ordinal), 184 jit_device_name_(jit_device_name), 185 xla_allocator_(nullptr), 186 platform_(platform) {} 187 188 XlaDevice::~XlaDevice() {} 189 190 xla::LocalClient* XlaDevice::client() const { 191 // We lazily create the client because the platform commits to the 192 // details of the host hardware when the client is created, so we 193 // don't want to do it until we get a chance to hook the platform up 194 // to a simulator. 195 196 // For now GetOrCreateLocalClient always returns success when passed 197 // a non-null platform. If that changes we may have to plumb in some 198 // way to pass Status back. 199 return xla::ClientLibrary::GetOrCreateLocalClient(platform_).ValueOrDie(); 200 } 201 202 Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) { 203 if (attr.on_host()) { 204 return cpu_allocator(); 205 } 206 207 if (xla_allocator_ == nullptr) { 208 xla::Backend* backend = client()->mutable_backend(); 209 xla_allocator_ = XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( 210 backend, device_ordinal_); 211 } 212 return xla_allocator_; 213 } 214 215 xla::StatusOr<se::Stream*> XlaDevice::GetStream() { 216 if (!stream_) { 217 xla::Backend* backend = client()->mutable_backend(); 218 TF_ASSIGN_OR_RETURN(stream_, backend->BorrowStream(device_ordinal_)); 219 } 220 return stream_.get(); 221 } 222 223 Status XlaDevice::FillContextMap(const Graph* graph, 224 DeviceContextMap* device_context_map) { 225 VLOG(1) << "XlaDevice::FillContextMap"; 226 device_context_map->resize(graph->num_node_ids()); 227 TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); 228 auto ctx = new XlaDeviceContext(stream); 229 for (Node* n : graph->nodes()) { 230 VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name(); 231 ctx->Ref(); 232 (*device_context_map)[n->id()] = ctx; 233 } 234 ctx->Unref(); 235 return Status::OK(); 236 } 237 238 void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { 239 VLOG(1) << "XlaDevice::Compute " << op_kernel->name() << ":" 240 << op_kernel->type_string(); 241 // When TraceMe profiling is off (which is the default), the 242 // following TraceMe constructor is simply a conditional test of 243 // false value. Measurements show that its overhead is negligible. 244 port::Tracing::TraceMe trace_me(op_kernel->name(), op_kernel->type_string(), 245 op_kernel->IsExpensive()); 246 op_kernel->Compute(context); 247 } 248 249 void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, 250 AsyncOpKernel::DoneCallback done) { 251 VLOG(1) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":" 252 << op_kernel->type_string(); 253 port::Tracing::TraceMe trace_me(op_kernel->name(), op_kernel->type_string(), 254 op_kernel->IsExpensive()); 255 op_kernel->ComputeAsync(context, done); 256 } 257 258 Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, 259 const AllocatorAttributes alloc_attrs, 260 Tensor* tensor) { 261 VLOG(1) << "XlaDevice::MakeTensorFromProto"; 262 263 Tensor parsed(tensor_proto.dtype()); 264 if (!parsed.FromProto(cpu_allocator(), tensor_proto)) { 265 return errors::InvalidArgument("Cannot parse tensor from proto: ", 266 tensor_proto.DebugString()); 267 } 268 269 Status status; 270 if (alloc_attrs.on_host()) { 271 *tensor = parsed; 272 } else { 273 Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape()); 274 Notification n; 275 TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); 276 XlaTransferManager manager(stream); 277 manager.CopyCPUTensorToDevice(&parsed, this, ©, 278 [&n, &status](const Status& s) { 279 status = s; 280 n.Notify(); 281 }); 282 n.WaitForNotification(); 283 *tensor = copy; 284 } 285 VLOG(2) << "Allocated tensor at " << DMAHelper::base(tensor); 286 return status; 287 } 288 289 XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, 290 const char* jit_device) { 291 XlaOpRegistry::RegisterCompilationKernels(); 292 XlaDeviceOpRegistrations* registrations = new XlaDeviceOpRegistrations; 293 auto dummy_factory = [](OpKernelConstruction* context) -> OpKernel* { 294 return new XlaDeviceDummyOp(context); 295 }; 296 for (const KernelDef* jit_def : XlaOpRegistry::DeviceKernels( 297 jit_device, 298 /*include_compilation_only_kernels=*/false)) { 299 KernelDef* def = new KernelDef(*jit_def); 300 def->set_device_type(device); 301 registrations->op_kernel_registrars.emplace_back( 302 new kernel_factory::OpKernelRegistrar(def, "XlaDeviceDummyOp", 303 dummy_factory)); 304 } 305 return registrations; 306 } 307 308 } // namespace tensorflow 309