1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_ 17 #define TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_ 18 19 #include <string> 20 #include <typeindex> 21 #include <typeinfo> 22 #include <unordered_map> 23 24 #include "tensorflow/core/framework/common_shape_fns.h" 25 #include "tensorflow/core/framework/op_kernel.h" 26 #include "tensorflow/core/framework/resource_handle.h" 27 #include "tensorflow/core/framework/tensor.h" 28 #include "tensorflow/core/framework/tensor_shape.h" 29 #include "tensorflow/core/framework/tensor_types.h" 30 #include "tensorflow/core/framework/type_index.h" 31 #include "tensorflow/core/lib/core/errors.h" 32 #include "tensorflow/core/lib/core/refcount.h" 33 #include "tensorflow/core/lib/core/status.h" 34 #include "tensorflow/core/lib/hash/hash.h" 35 #include "tensorflow/core/platform/logging.h" 36 #include "tensorflow/core/platform/macros.h" 37 #include "tensorflow/core/platform/mutex.h" 38 #include "tensorflow/core/platform/thread_annotations.h" 39 40 namespace tensorflow { 41 42 // A ResourceMgr instance keeps track of named and typed resources 43 // grouped into containers. 44 // 45 // Each resource must be represented as a sub-class of ResourceBase, 46 // which is reference counted explicitly. Each named resource is 47 // registered with ResourceMgr under a named "container" name. At any 48 // time, there is at most one instance of a resource given the container 49 // name, the resource type and the resource name. 50 // 51 // All resources for a given container can be dropped by one call of 52 // Cleanup(). 53 // 54 // E.g., 55 // struct MyVar : public ResourceBase { 56 // mutex mu; 57 // Tensor val; 58 // } 59 // 60 // ResourceMgr rm; 61 // 62 // // Create a var. 63 // MyVar* my_var = new MyVar; 64 // my_var.val = Tensor(DT_FLOAT, my_shape); 65 // my_var.val.flat<float>().setZeros(); // 0 initialized. 66 // ctx->SetStatus(rm.Create("my_container", "my_name", my_var)); 67 // 68 // // += a variable. 69 // MyVar* my_var = nullptr; 70 // Status s = rm.Lookup("my_container", "my_name", &my_var); 71 // if (s.ok()) { 72 // my_var->val.flat<float>() += grad; 73 // } 74 // my_var->Unref(); // Or use ScopedUnref(). 75 // ctx->SetStatus(s); 76 class ResourceBase : public core::RefCounted { 77 public: 78 // Returns a debug string for *this. 79 virtual string DebugString() = 0; 80 81 // Returns memory used by this resource. 82 virtual int64 MemoryUsed() const { return 0; }; 83 }; 84 85 // Container used for per-step resources. 86 class ScopedStepContainer { 87 public: 88 // step_id: the unique ID of this step. Doesn't have to be sequential, just 89 // has to be unique. 90 // cleanup: callback to delete a container of this name. 91 ScopedStepContainer(const int64 step_id, 92 std::function<void(const string&)> cleanup) 93 : name_(strings::StrCat("__per_step_", step_id)), cleanup_(cleanup) {} 94 ~ScopedStepContainer() { cleanup_(name_); } 95 96 const string& name() const { return name_; } 97 98 private: 99 const string name_; 100 const std::function<void(const string&)> cleanup_; 101 }; 102 103 class ResourceMgr { 104 public: 105 ResourceMgr(); 106 explicit ResourceMgr(const string& default_container); 107 ~ResourceMgr(); 108 109 // Returns the default container name for *this. 110 const string& default_container() const { return default_container_; } 111 112 // Creates a resource "name" in the "container". The caller transfers 113 // the ownership of one ref on "resource" to *this 114 // 115 // REQUIRES: std::is_base_of<ResourceBase, T> 116 // REQUIRES: resource != nullptr. 117 template <typename T> 118 Status Create(const string& container, const string& name, 119 T* resource) TF_MUST_USE_RESULT; 120 121 // If "container" has a resource "name", returns it in "*resource" and 122 // the caller takes the ownership of one ref on "*resource". 123 // 124 // REQUIRES: std::is_base_of<ResourceBase, T> 125 // REQUIRES: resource != nullptr 126 template <typename T> 127 Status Lookup(const string& container, const string& name, 128 T** resource) const TF_MUST_USE_RESULT; 129 130 // If "container" has a resource "name", returns it in 131 // "*resource". Otherwise, invokes creator() to create the resource. 132 // The caller takes the ownership of one ref on "*resource". 133 // 134 // REQUIRES: std::is_base_of<ResourceBase, T> 135 // REQUIRES: resource != nullptr 136 template <typename T> 137 Status LookupOrCreate(const string& container, const string& name, 138 T** resource, 139 std::function<Status(T**)> creator) TF_MUST_USE_RESULT; 140 141 // Deletes the resource "name" from the "container". 142 // 143 // REQUIRES: std::is_base_of<ResourceBase, T> 144 template <typename T> 145 Status Delete(const string& container, const string& name) TF_MUST_USE_RESULT; 146 147 // Deletes the resource pointed by "handle". 148 Status Delete(const ResourceHandle& handle) TF_MUST_USE_RESULT; 149 150 // Deletes all resources from the "container" and removes the container. 151 Status Cleanup(const string& container) TF_MUST_USE_RESULT; 152 153 // Deletes all resources in all containers. 154 void Clear(); 155 156 // Returns a text description for all resources. 157 string DebugString() const; 158 159 private: 160 typedef std::pair<uint64, string> Key; 161 struct KeyHash { 162 std::size_t operator()(const Key& k) const { 163 return Hash64(k.second.data(), k.second.size(), k.first); 164 } 165 }; 166 struct KeyEqual { 167 bool operator()(const Key& x, const Key& y) const { 168 return (x.second == y.second) && (x.first == y.first); 169 } 170 }; 171 typedef std::unordered_map<Key, ResourceBase*, KeyHash, KeyEqual> Container; 172 173 const string default_container_; 174 mutable mutex mu_; 175 std::unordered_map<string, Container*> containers_ GUARDED_BY(mu_); 176 177 Status DoCreate(const string& container, TypeIndex type, const string& name, 178 ResourceBase* resource) TF_MUST_USE_RESULT; 179 Status DoLookup(const string& container, TypeIndex type, const string& name, 180 ResourceBase** resource) const TF_MUST_USE_RESULT; 181 Status DoDelete(const string& container, uint64 type_hash_code, 182 const string& resource_name, 183 const string& type_name) TF_MUST_USE_RESULT; 184 Status DoDelete(const string& container, TypeIndex type, 185 const string& resource_name) TF_MUST_USE_RESULT; 186 187 // Inserts the type name for 'hash_code' into the hash_code to type name map. 188 Status InsertDebugTypeName(uint64 hash_code, const string& type_name) 189 EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT; 190 191 // Returns the type name for the 'hash_code'. 192 // Returns "<unknown>" if a resource with such a type was never inserted into 193 // the container. 194 const char* DebugTypeName(uint64 hash_code) const 195 EXCLUSIVE_LOCKS_REQUIRED(mu_); 196 197 // Map from type hash_code to type name. 198 std::unordered_map<uint64, string> debug_type_names_ GUARDED_BY(mu_); 199 200 TF_DISALLOW_COPY_AND_ASSIGN(ResourceMgr); 201 }; 202 203 // Makes a resource handle with the specified type for a given container / 204 // name. 205 ResourceHandle MakeResourceHandle(OpKernelContext* ctx, const string& container, 206 const string& name, 207 const TypeIndex& type_index); 208 209 template <typename T> 210 ResourceHandle MakeResourceHandle(OpKernelContext* ctx, const string& container, 211 const string& name) { 212 return MakeResourceHandle(ctx, container, name, MakeTypeIndex<T>()); 213 } 214 215 Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index, 216 const string& container, const string& name, 217 const TypeIndex& type_index); 218 219 template <typename T> 220 ResourceHandle MakePerStepResourceHandle(OpKernelContext* ctx, 221 const string& name); 222 223 // Returns a resource handle from a numbered op input. 224 ResourceHandle HandleFromInput(OpKernelContext* ctx, int input); 225 Status HandleFromInput(OpKernelContext* ctx, StringPiece input, 226 ResourceHandle* handle); 227 228 // Create a resource pointed by a given resource handle. 229 template <typename T> 230 Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value); 231 232 // Looks up a resource pointed by a given resource handle. 233 template <typename T> 234 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value); 235 236 // Looks up or creates a resource. 237 template <typename T> 238 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p, 239 T** value, std::function<Status(T**)> creator); 240 241 // Destroys a resource pointed by a given resource handle. 242 template <typename T> 243 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p); 244 245 // Same as above, but uses the hash code of the type directly. 246 // The type name information will be missing in the debug output when the 247 // resource is not present in the container. 248 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p); 249 250 // Policy helper to decide which container/shared_name to use for a 251 // stateful kernel that accesses shared resource. 252 class ContainerInfo { 253 public: 254 // Analyze the node attribute of 'ndef' and decides the container and 255 // resource name the kernel should use for accessing the shared 256 // resource. 257 // 258 // 'ndef' is expected to have node attribute "container" and 259 // "shared_name". Returns non-OK if they are not provided or they are 260 // invalid. 261 // 262 // The policy is as following: 263 // * If the attribute "container" is non-empty, it is used as is. 264 // Otherwise, uses the resource manager's default container. 265 // * If the attribute "shared_name" is non-empty, it is used as is. 266 // Otherwise, if "use_node_name_as_default" is true, the kernel's 267 // node name is used as the resource name. Otherwise, a string 268 // unique to this process is used. 269 Status Init(ResourceMgr* rmgr, const NodeDef& ndef, 270 bool use_node_name_as_default); 271 Status Init(ResourceMgr* rmgr, const NodeDef& ndef) { 272 return Init(rmgr, ndef, false); 273 } 274 275 // The policy decides that the kernel should access the resource in 276 // resource_manager(), the resource is in the container() and its 277 // name is name(). If resource_is_private_to_kernel() is true, the 278 // kernel should delete the resource when the kernel is deleted. 279 ResourceMgr* resource_manager() const { return rmgr_; } 280 const string& container() const { return container_; } 281 const string& name() const { return name_; } 282 bool resource_is_private_to_kernel() const { 283 return resource_is_private_to_kernel_; 284 } 285 286 // Returns a readable string for *this. 287 string DebugString() const; 288 289 private: 290 ResourceMgr* rmgr_ = nullptr; 291 string container_; 292 string name_; 293 bool resource_is_private_to_kernel_ = false; 294 }; 295 296 // Helper for kernels to obtain 'resource' from the 297 // ctx->resource_manager(). 298 // 299 // "input_name" specifies the kernel's ref input which gives a string 300 // tensor with two elements, which specifies the container and 301 // resource name. 302 // 303 // Returns OK if the resource is found and transfers one ref of 304 // *resource to the caller. Otherwise, returns an error. 305 template <typename T> 306 Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name, 307 T** resource); 308 309 // Utility op kernel to check if a handle to resource type T is initialized. 310 template <typename T> 311 class IsResourceInitialized : public OpKernel { 312 public: 313 explicit IsResourceInitialized(OpKernelConstruction* c) : OpKernel(c) {} 314 315 void Compute(OpKernelContext* ctx) override; 316 }; 317 318 // Registers an op which produces just a resource handle to a resource of the 319 // specified type. The type will be a part of the generated op name. 320 // TODO(apassos): figure out how to get non-cpu-allocated tensors to work 321 // through constant folding so this doesn't have to be marked as stateful. 322 #define REGISTER_RESOURCE_HANDLE_OP(Type) \ 323 REGISTER_OP(#Type "HandleOp") \ 324 .Attr("container: string = ''") \ 325 .Attr("shared_name: string = ''") \ 326 .Output("resource: resource") \ 327 .SetIsStateful() \ 328 .SetShapeFn(tensorflow::shape_inference::ScalarShape) \ 329 .Doc("Creates a handle to a " #Type) 330 331 // Utility op kernel to produce a handle to a resource of type T. 332 template <typename T> 333 class ResourceHandleOp : public OpKernel { 334 public: 335 explicit ResourceHandleOp(OpKernelConstruction* context); 336 337 void Compute(OpKernelContext* ctx) override; 338 339 private: 340 string container_; 341 string name_; 342 }; 343 344 // Registers a kernel for an op which produces a handle to a resource of the 345 // specified type. 346 #define REGISTER_RESOURCE_HANDLE_KERNEL(Type) \ 347 REGISTER_KERNEL_BUILDER(Name(#Type "HandleOp").Device(DEVICE_CPU), \ 348 ResourceHandleOp<Type>) 349 350 // Implementation details below. 351 352 template <typename T> 353 void CheckDeriveFromResourceBase() { 354 static_assert(std::is_base_of<ResourceBase, T>::value, 355 "T must derive from ResourceBase"); 356 } 357 358 template <typename T> 359 Status ResourceMgr::Create(const string& container, const string& name, 360 T* resource) { 361 CheckDeriveFromResourceBase<T>(); 362 CHECK(resource != nullptr); 363 return DoCreate(container, MakeTypeIndex<T>(), name, resource); 364 } 365 366 template <typename T> 367 Status ResourceMgr::Lookup(const string& container, const string& name, 368 T** resource) const { 369 CheckDeriveFromResourceBase<T>(); 370 ResourceBase* found = nullptr; 371 Status s = DoLookup(container, MakeTypeIndex<T>(), name, &found); 372 if (s.ok()) { 373 // It's safe to down cast 'found' to T* since 374 // typeid(T).hash_code() is part of the map key. 375 *resource = static_cast<T*>(found); 376 } 377 return s; 378 } 379 380 template <typename T> 381 Status ResourceMgr::LookupOrCreate(const string& container, const string& name, 382 T** resource, 383 std::function<Status(T**)> creator) { 384 Status s; 385 *resource = nullptr; 386 while (*resource == nullptr) { 387 s = Lookup(container, name, resource); 388 if (s.ok()) break; 389 s = creator(resource); 390 if (!s.ok()) break; 391 s = Create(container, name, *resource); 392 if (s.ok()) { 393 (*resource)->Ref(); 394 break; 395 } 396 // Rare event. Concurrent racy creation. Redo the lookup. 397 *resource = nullptr; 398 } 399 return s; 400 } 401 402 template <typename T> 403 Status ResourceMgr::Delete(const string& container, const string& name) { 404 CheckDeriveFromResourceBase<T>(); 405 return DoDelete(container, MakeTypeIndex<T>(), name); 406 } 407 408 template <typename T> 409 Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name, 410 T** resource) { 411 DataType dtype; 412 TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &dtype)); 413 if (dtype == DT_RESOURCE) { 414 const Tensor* handle; 415 TF_RETURN_IF_ERROR(ctx->input(input_name, &handle)); 416 return LookupResource(ctx, handle->scalar<ResourceHandle>()(), resource); 417 } 418 string container; 419 string shared_name; 420 { 421 mutex* mu; 422 TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu)); 423 mutex_lock l(*mu); 424 Tensor tensor; 425 TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true)); 426 if (tensor.NumElements() != 2) { 427 return errors::InvalidArgument( 428 "Resource handle must have 2 elements, but had shape: ", 429 tensor.shape().DebugString()); 430 } 431 container = tensor.flat<string>()(0); 432 shared_name = tensor.flat<string>()(1); 433 } 434 return ctx->resource_manager()->Lookup(container, shared_name, resource); 435 } 436 437 template <typename T> 438 ResourceHandle MakePerStepResourceHandle(OpKernelContext* ctx, 439 const string& name) { 440 return MakeResourceHandle<T>(ctx, ctx->step_container()->name(), name); 441 } 442 443 namespace internal { 444 445 Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p); 446 447 template <typename T> 448 Status ValidateDeviceAndType(OpKernelContext* ctx, const ResourceHandle& p) { 449 TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p)); 450 auto type_index = MakeTypeIndex<T>(); 451 if (type_index.hash_code() != p.hash_code()) { 452 return errors::InvalidArgument( 453 "Trying to access resource using the wrong type. Expected ", 454 p.maybe_type_name(), " got ", type_index.name()); 455 } 456 return Status::OK(); 457 } 458 459 } // namespace internal 460 461 template <typename T> 462 Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value) { 463 TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p)); 464 return ctx->resource_manager()->Create(p.container(), p.name(), value); 465 } 466 467 template <typename T> 468 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, 469 T** value) { 470 TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p)); 471 return ctx->resource_manager()->Lookup(p.container(), p.name(), value); 472 } 473 474 template <typename T> 475 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p, 476 T** value, std::function<Status(T**)> creator) { 477 TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p)); 478 return ctx->resource_manager()->LookupOrCreate(p.container(), p.name(), value, 479 creator); 480 } 481 482 template <typename T> 483 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) { 484 TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p)); 485 return ctx->resource_manager()->Delete<T>(p.container(), p.name()); 486 } 487 488 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p); 489 490 template <typename T> 491 void IsResourceInitialized<T>::Compute(OpKernelContext* ctx) { 492 Tensor* output; 493 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &output)); 494 T* object; 495 bool found; 496 if (LookupResource(ctx, HandleFromInput(ctx, 0), &object).ok()) { 497 found = true; 498 object->Unref(); 499 } else { 500 found = false; 501 } 502 503 output->flat<bool>()(0) = found; 504 } 505 506 template <typename T> 507 ResourceHandleOp<T>::ResourceHandleOp(OpKernelConstruction* context) 508 : OpKernel(context) { 509 OP_REQUIRES_OK(context, context->GetAttr("container", &container_)); 510 OP_REQUIRES_OK(context, context->GetAttr("shared_name", &name_)); 511 } 512 513 template <typename T> 514 void ResourceHandleOp<T>::Compute(OpKernelContext* ctx) { 515 Tensor* output = nullptr; 516 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); 517 output->scalar<ResourceHandle>()() = 518 MakeResourceHandle<T>(ctx, container_, name_); 519 } 520 521 } // end namespace tensorflow 522 523 #endif // TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_ 524