Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_
     17 #define TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_
     18 
#include <functional>
#include <string>
#include <type_traits>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <utility>

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
     39 
     40 namespace tensorflow {
     41 
     42 // A ResourceMgr instance keeps track of named and typed resources
     43 // grouped into containers.
     44 //
     45 // Each resource must be represented as a sub-class of ResourceBase,
     46 // which is reference counted explicitly.  Each named resource is
     47 // registered with ResourceMgr under a named "container" name. At any
     48 // time, there is at most one instance of a resource given the container
     49 // name, the resource type and the resource name.
     50 //
     51 // All resources for a given container can be dropped by one call of
     52 // Cleanup().
     53 //
// E.g.,
//   struct MyVar : public ResourceBase {
//     mutex mu;
//     Tensor val;
//     string DebugString() override { return "MyVar"; }
//   };
//
//   ResourceMgr rm;
//
//   // Create a var.
//   MyVar* my_var = new MyVar;
//   my_var->val = Tensor(DT_FLOAT, my_shape);
//   my_var->val.flat<float>().setZero();   // 0 initialized.
//   ctx->SetStatus(rm.Create("my_container", "my_name", my_var));
//
//   // += a variable.
//   MyVar* my_var = nullptr;
//   Status s = rm.Lookup("my_container", "my_name", &my_var);
//   if (s.ok()) {
//     {
//       mutex_lock l(my_var->mu);
//       my_var->val.flat<float>() += grad;
//     }
//     my_var->Unref();   // Only a successful Lookup hands out a ref.
//                        // Or use ScopedUnref().
//   }
//   ctx->SetStatus(s);
     76 class ResourceBase : public core::RefCounted {
     77  public:
     78   // Returns a debug string for *this.
     79   virtual string DebugString() = 0;
     80 
     81   // Returns memory used by this resource.
     82   virtual int64 MemoryUsed() const { return 0; };
     83 };
     84 
     85 // Container used for per-step resources.
     86 class ScopedStepContainer {
     87  public:
     88   // step_id: the unique ID of this step. Doesn't have to be sequential, just
     89   // has to be unique.
     90   // cleanup: callback to delete a container of this name.
     91   ScopedStepContainer(const int64 step_id,
     92                       std::function<void(const string&)> cleanup)
     93       : name_(strings::StrCat("__per_step_", step_id)), cleanup_(cleanup) {}
     94   ~ScopedStepContainer() { cleanup_(name_); }
     95 
     96   const string& name() const { return name_; }
     97 
     98  private:
     99   const string name_;
    100   const std::function<void(const string&)> cleanup_;
    101 };
    102 
    103 class ResourceMgr {
    104  public:
    105   ResourceMgr();
    106   explicit ResourceMgr(const string& default_container);
    107   ~ResourceMgr();
    108 
    109   // Returns the default container name for *this.
    110   const string& default_container() const { return default_container_; }
    111 
    112   // Creates a resource "name" in the "container".  The caller transfers
    113   // the ownership of one ref on "resource" to *this
    114   //
    115   // REQUIRES: std::is_base_of<ResourceBase, T>
    116   // REQUIRES: resource != nullptr.
    117   template <typename T>
    118   Status Create(const string& container, const string& name,
    119                 T* resource) TF_MUST_USE_RESULT;
    120 
    121   // If "container" has a resource "name", returns it in "*resource" and
    122   // the caller takes the ownership of one ref on "*resource".
    123   //
    124   // REQUIRES: std::is_base_of<ResourceBase, T>
    125   // REQUIRES: resource != nullptr
    126   template <typename T>
    127   Status Lookup(const string& container, const string& name,
    128                 T** resource) const TF_MUST_USE_RESULT;
    129 
    130   // If "container" has a resource "name", returns it in
    131   // "*resource". Otherwise, invokes creator() to create the resource.
    132   // The caller takes the ownership of one ref on "*resource".
    133   //
    134   // REQUIRES: std::is_base_of<ResourceBase, T>
    135   // REQUIRES: resource != nullptr
    136   template <typename T>
    137   Status LookupOrCreate(const string& container, const string& name,
    138                         T** resource,
    139                         std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
    140 
    141   // Deletes the resource "name" from the "container".
    142   //
    143   // REQUIRES: std::is_base_of<ResourceBase, T>
    144   template <typename T>
    145   Status Delete(const string& container, const string& name) TF_MUST_USE_RESULT;
    146 
    147   // Deletes the resource pointed by "handle".
    148   Status Delete(const ResourceHandle& handle) TF_MUST_USE_RESULT;
    149 
    150   // Deletes all resources from the "container" and removes the container.
    151   Status Cleanup(const string& container) TF_MUST_USE_RESULT;
    152 
    153   // Deletes all resources in all containers.
    154   void Clear();
    155 
    156   // Returns a text description for all resources.
    157   string DebugString() const;
    158 
    159  private:
    160   typedef std::pair<uint64, string> Key;
    161   struct KeyHash {
    162     std::size_t operator()(const Key& k) const {
    163       return Hash64(k.second.data(), k.second.size(), k.first);
    164     }
    165   };
    166   struct KeyEqual {
    167     bool operator()(const Key& x, const Key& y) const {
    168       return (x.second == y.second) && (x.first == y.first);
    169     }
    170   };
    171   typedef std::unordered_map<Key, ResourceBase*, KeyHash, KeyEqual> Container;
    172 
    173   const string default_container_;
    174   mutable mutex mu_;
    175   std::unordered_map<string, Container*> containers_ GUARDED_BY(mu_);
    176 
    177   Status DoCreate(const string& container, TypeIndex type, const string& name,
    178                   ResourceBase* resource) TF_MUST_USE_RESULT;
    179   Status DoLookup(const string& container, TypeIndex type, const string& name,
    180                   ResourceBase** resource) const TF_MUST_USE_RESULT;
    181   Status DoDelete(const string& container, uint64 type_hash_code,
    182                   const string& resource_name,
    183                   const string& type_name) TF_MUST_USE_RESULT;
    184   Status DoDelete(const string& container, TypeIndex type,
    185                   const string& resource_name) TF_MUST_USE_RESULT;
    186 
    187   // Inserts the type name for 'hash_code' into the hash_code to type name map.
    188   Status InsertDebugTypeName(uint64 hash_code, const string& type_name)
    189       EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
    190 
    191   // Returns the type name for the 'hash_code'.
    192   // Returns "<unknown>" if a resource with such a type was never inserted into
    193   // the container.
    194   const char* DebugTypeName(uint64 hash_code) const
    195       EXCLUSIVE_LOCKS_REQUIRED(mu_);
    196 
    197   // Map from type hash_code to type name.
    198   std::unordered_map<uint64, string> debug_type_names_ GUARDED_BY(mu_);
    199 
    200   TF_DISALLOW_COPY_AND_ASSIGN(ResourceMgr);
    201 };
    202 
    203 // Makes a resource handle with the specified type for a given container /
    204 // name.
    205 ResourceHandle MakeResourceHandle(OpKernelContext* ctx, const string& container,
    206                                   const string& name,
    207                                   const TypeIndex& type_index);
    208 
    209 template <typename T>
    210 ResourceHandle MakeResourceHandle(OpKernelContext* ctx, const string& container,
    211                                   const string& name) {
    212   return MakeResourceHandle(ctx, container, name, MakeTypeIndex<T>());
    213 }
    214 
// Presumably stores a handle for `container` / `name` with type `type_index`
// into output `output_index` of `context` (defined out of line — confirm
// against the implementation).
Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,
                                  const string& container, const string& name,
                                  const TypeIndex& type_index);

// Makes a handle of type T for resource `name` in the container of the
// current step (ctx->step_container()).
template <typename T>
ResourceHandle MakePerStepResourceHandle(OpKernelContext* ctx,
                                         const string& name);

// Returns a resource handle from a numbered op input.
ResourceHandle HandleFromInput(OpKernelContext* ctx, int input);
// Same, but selects the input by name, writes the handle to `*handle`, and
// reports failure through the returned Status.
Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
                       ResourceHandle* handle);
    227 
// Creates the resource pointed to by the given resource handle. After the
// handle passes device/type validation this forwards to ResourceMgr::Create,
// so one ref on `value` is transferred to the resource manager.
template <typename T>
Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value);

// Looks up the resource pointed to by the given resource handle. On success
// the caller owns one ref on *value.
template <typename T>
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value);

// Looks up the resource pointed to by `p`, creating it with `creator` if it
// does not exist yet. On success the caller owns one ref on *value.
template <typename T>
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
                              T** value, std::function<Status(T**)> creator);

// Destroys the resource pointed to by the given resource handle, after
// checking that T matches the type recorded in the handle.
template <typename T>
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);

// Same as above, but uses the hash code of the type directly.
// The type name information will be missing in the debug output when the
// resource is not present in the container.
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
    249 
    250 // Policy helper to decide which container/shared_name to use for a
    251 // stateful kernel that accesses shared resource.
    252 class ContainerInfo {
    253  public:
    254   // Analyze the node attribute of 'ndef' and decides the container and
    255   // resource name the kernel should use for accessing the shared
    256   // resource.
    257   //
    258   // 'ndef' is expected to have node attribute "container" and
    259   // "shared_name". Returns non-OK if they are not provided or they are
    260   // invalid.
    261   //
    262   // The policy is as following:
    263   // * If the attribute "container" is non-empty, it is used as is.
    264   //   Otherwise, uses the resource manager's default container.
    265   // * If the attribute "shared_name" is non-empty, it is used as is.
    266   //   Otherwise, if "use_node_name_as_default" is true, the kernel's
    267   //   node name is used as the resource name. Otherwise, a string
    268   //   unique to this process is used.
    269   Status Init(ResourceMgr* rmgr, const NodeDef& ndef,
    270               bool use_node_name_as_default);
    271   Status Init(ResourceMgr* rmgr, const NodeDef& ndef) {
    272     return Init(rmgr, ndef, false);
    273   }
    274 
    275   // The policy decides that the kernel should access the resource in
    276   // resource_manager(), the resource is in the container() and its
    277   // name is name().  If resource_is_private_to_kernel() is true, the
    278   // kernel should delete the resource when the kernel is deleted.
    279   ResourceMgr* resource_manager() const { return rmgr_; }
    280   const string& container() const { return container_; }
    281   const string& name() const { return name_; }
    282   bool resource_is_private_to_kernel() const {
    283     return resource_is_private_to_kernel_;
    284   }
    285 
    286   // Returns a readable string for *this.
    287   string DebugString() const;
    288 
    289  private:
    290   ResourceMgr* rmgr_ = nullptr;
    291   string container_;
    292   string name_;
    293   bool resource_is_private_to_kernel_ = false;
    294 };
    295 
// Helper for kernels to obtain 'resource' from the
// ctx->resource_manager().
//
// "input_name" names the kernel input that identifies the resource. Two
// forms are accepted:
//  * a DT_RESOURCE input holding a ResourceHandle, resolved through
//    LookupResource();
//  * a ref input holding a string tensor with exactly two elements: the
//    container and the resource name.
//
// Returns OK if the resource is found and transfers one ref of
// *resource to the caller. Otherwise, returns an error.
template <typename T>
Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name,
                              T** resource);
    308 
    309 // Utility op kernel to check if a handle to resource type T is initialized.
    310 template <typename T>
    311 class IsResourceInitialized : public OpKernel {
    312  public:
    313   explicit IsResourceInitialized(OpKernelConstruction* c) : OpKernel(c) {}
    314 
    315   void Compute(OpKernelContext* ctx) override;
    316 };
    317 
// Registers an op named "<Type>HandleOp" which produces just a resource
// handle to a resource of the specified type. The op has optional
// "container" and "shared_name" string attrs (both default to '') and a
// single scalar "resource" output.
// TODO(apassos): figure out how to get non-cpu-allocated tensors to work
// through constant folding so this doesn't have to be marked as stateful.
#define REGISTER_RESOURCE_HANDLE_OP(Type)                   \
  REGISTER_OP(#Type "HandleOp")                             \
      .Attr("container: string = ''")                       \
      .Attr("shared_name: string = ''")                     \
      .Output("resource: resource")                         \
      .SetIsStateful()                                      \
      .SetShapeFn(tensorflow::shape_inference::ScalarShape) \
      .Doc("Creates a handle to a " #Type)
    330 
    331 // Utility op kernel to produce a handle to a resource of type T.
    332 template <typename T>
    333 class ResourceHandleOp : public OpKernel {
    334  public:
    335   explicit ResourceHandleOp(OpKernelConstruction* context);
    336 
    337   void Compute(OpKernelContext* ctx) override;
    338 
    339  private:
    340   string container_;
    341   string name_;
    342 };
    343 
// Registers ResourceHandleOp<Type> as the CPU kernel for the op named
// "<Type>HandleOp" (the op that REGISTER_RESOURCE_HANDLE_OP(Type) declares).
#define REGISTER_RESOURCE_HANDLE_KERNEL(Type)                        \
  REGISTER_KERNEL_BUILDER(Name(#Type "HandleOp").Device(DEVICE_CPU), \
                          ResourceHandleOp<Type>)
    349 
    350 // Implementation details below.
    351 
    352 template <typename T>
    353 void CheckDeriveFromResourceBase() {
    354   static_assert(std::is_base_of<ResourceBase, T>::value,
    355                 "T must derive from ResourceBase");
    356 }
    357 
    358 template <typename T>
    359 Status ResourceMgr::Create(const string& container, const string& name,
    360                            T* resource) {
    361   CheckDeriveFromResourceBase<T>();
    362   CHECK(resource != nullptr);
    363   return DoCreate(container, MakeTypeIndex<T>(), name, resource);
    364 }
    365 
    366 template <typename T>
    367 Status ResourceMgr::Lookup(const string& container, const string& name,
    368                            T** resource) const {
    369   CheckDeriveFromResourceBase<T>();
    370   ResourceBase* found = nullptr;
    371   Status s = DoLookup(container, MakeTypeIndex<T>(), name, &found);
    372   if (s.ok()) {
    373     // It's safe to down cast 'found' to T* since
    374     // typeid(T).hash_code() is part of the map key.
    375     *resource = static_cast<T*>(found);
    376   }
    377   return s;
    378 }
    379 
    380 template <typename T>
    381 Status ResourceMgr::LookupOrCreate(const string& container, const string& name,
    382                                    T** resource,
    383                                    std::function<Status(T**)> creator) {
    384   Status s;
    385   *resource = nullptr;
    386   while (*resource == nullptr) {
    387     s = Lookup(container, name, resource);
    388     if (s.ok()) break;
    389     s = creator(resource);
    390     if (!s.ok()) break;
    391     s = Create(container, name, *resource);
    392     if (s.ok()) {
    393       (*resource)->Ref();
    394       break;
    395     }
    396     // Rare event. Concurrent racy creation. Redo the lookup.
    397     *resource = nullptr;
    398   }
    399   return s;
    400 }
    401 
    402 template <typename T>
    403 Status ResourceMgr::Delete(const string& container, const string& name) {
    404   CheckDeriveFromResourceBase<T>();
    405   return DoDelete(container, MakeTypeIndex<T>(), name);
    406 }
    407 
    408 template <typename T>
    409 Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name,
    410                               T** resource) {
    411   DataType dtype;
    412   TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &dtype));
    413   if (dtype == DT_RESOURCE) {
    414     const Tensor* handle;
    415     TF_RETURN_IF_ERROR(ctx->input(input_name, &handle));
    416     return LookupResource(ctx, handle->scalar<ResourceHandle>()(), resource);
    417   }
    418   string container;
    419   string shared_name;
    420   {
    421     mutex* mu;
    422     TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu));
    423     mutex_lock l(*mu);
    424     Tensor tensor;
    425     TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true));
    426     if (tensor.NumElements() != 2) {
    427       return errors::InvalidArgument(
    428           "Resource handle must have 2 elements, but had shape: ",
    429           tensor.shape().DebugString());
    430     }
    431     container = tensor.flat<string>()(0);
    432     shared_name = tensor.flat<string>()(1);
    433   }
    434   return ctx->resource_manager()->Lookup(container, shared_name, resource);
    435 }
    436 
    437 template <typename T>
    438 ResourceHandle MakePerStepResourceHandle(OpKernelContext* ctx,
    439                                          const string& name) {
    440   return MakeResourceHandle<T>(ctx, ctx->step_container()->name(), name);
    441 }
    442 
    443 namespace internal {
    444 
    445 Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p);
    446 
    447 template <typename T>
    448 Status ValidateDeviceAndType(OpKernelContext* ctx, const ResourceHandle& p) {
    449   TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p));
    450   auto type_index = MakeTypeIndex<T>();
    451   if (type_index.hash_code() != p.hash_code()) {
    452     return errors::InvalidArgument(
    453         "Trying to access resource using the wrong type. Expected ",
    454         p.maybe_type_name(), " got ", type_index.name());
    455   }
    456   return Status::OK();
    457 }
    458 
    459 }  // namespace internal
    460 
    461 template <typename T>
    462 Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value) {
    463   TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
    464   return ctx->resource_manager()->Create(p.container(), p.name(), value);
    465 }
    466 
    467 template <typename T>
    468 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
    469                       T** value) {
    470   TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
    471   return ctx->resource_manager()->Lookup(p.container(), p.name(), value);
    472 }
    473 
    474 template <typename T>
    475 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
    476                               T** value, std::function<Status(T**)> creator) {
    477   TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
    478   return ctx->resource_manager()->LookupOrCreate(p.container(), p.name(), value,
    479                                                  creator);
    480 }
    481 
    482 template <typename T>
    483 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) {
    484   TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
    485   return ctx->resource_manager()->Delete<T>(p.container(), p.name());
    486 }
    487 
    488 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
    489 
    490 template <typename T>
    491 void IsResourceInitialized<T>::Compute(OpKernelContext* ctx) {
    492   Tensor* output;
    493   OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &output));
    494   T* object;
    495   bool found;
    496   if (LookupResource(ctx, HandleFromInput(ctx, 0), &object).ok()) {
    497     found = true;
    498     object->Unref();
    499   } else {
    500     found = false;
    501   }
    502 
    503   output->flat<bool>()(0) = found;
    504 }
    505 
    506 template <typename T>
    507 ResourceHandleOp<T>::ResourceHandleOp(OpKernelConstruction* context)
    508     : OpKernel(context) {
    509   OP_REQUIRES_OK(context, context->GetAttr("container", &container_));
    510   OP_REQUIRES_OK(context, context->GetAttr("shared_name", &name_));
    511 }
    512 
    513 template <typename T>
    514 void ResourceHandleOp<T>::Compute(OpKernelContext* ctx) {
    515   Tensor* output = nullptr;
    516   OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
    517   output->scalar<ResourceHandle>()() =
    518       MakeResourceHandle<T>(ctx, container_, name_);
    519 }
    520 
    521 }  //  end namespace tensorflow
    522 
    523 #endif  // TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_
    524