Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      7     http://www.apache.org/licenses/LICENSE-2.0
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
#include <functional>
#include <string>
#include <type_traits>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <utility>

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
     40 namespace tensorflow {
     42 // A ResourceMgr instance keeps track of named and typed resources
     43 // grouped into containers.
     44 //
// Each resource must be represented as a subclass of ResourceBase,
// which is reference counted explicitly. Each named resource is
// registered with ResourceMgr under a "container" name. At any
// time, there is at most one instance of a resource for a given
// combination of container name, resource type, and resource name.
     50 //
     51 // All resources for a given container can be dropped by one call of
     52 // Cleanup().
     53 //
// E.g.,
//   struct MyVar : public ResourceBase {
//     mutex mu;
//     Tensor val;
//     string DebugString() override { return "MyVar"; }
//   };
//
//   ResourceMgr rm;
//
//   // Create a var.
//   MyVar* my_var = new MyVar;
//   my_var->val = Tensor(DT_FLOAT, my_shape);
//   my_var->val.flat<float>().setZero();   // 0 initialized.
//   ctx->SetStatus(rm.Create("my_container", "my_name", my_var));
//
//   // += a variable.
//   MyVar* found_var = nullptr;
//   Status s = rm.Lookup("my_container", "my_name", &found_var);
//   if (s.ok()) {
//     found_var->val.flat<float>() += grad;
//     found_var->Unref();   // Or use ScopedUnref().
//   }
//   ctx->SetStatus(s);
     76 class ResourceBase : public core::RefCounted {
     77  public:
     78   // Returns a debug string for *this.
     79   virtual string DebugString() = 0;
     81   // Returns memory used by this resource.
     82   virtual int64 MemoryUsed() const { return 0; };
     83 };
     85 // Container used for per-step resources.
     86 class ScopedStepContainer {
     87  public:
     88   // step_id: the unique ID of this step. Doesn't have to be sequential, just
     89   // has to be unique.
     90   // cleanup: callback to delete a container of this name.
     91   ScopedStepContainer(const int64 step_id,
     92                       std::function<void(const string&)> cleanup)
     93       : name_(strings::StrCat("__per_step_", step_id)), cleanup_(cleanup) {}
     94   ~ScopedStepContainer() { cleanup_(name_); }
     96   const string& name() const { return name_; }
     98  private:
     99   const string name_;
    100   const std::function<void(const string&)> cleanup_;
    101 };
    103 class ResourceMgr {
    104  public:
    105   ResourceMgr();
    106   explicit ResourceMgr(const string& default_container);
    107   ~ResourceMgr();
    109   // Returns the default container name for *this.
    110   const string& default_container() const { return default_container_; }
    112   // Creates a resource "name" in the "container".  The caller transfers
    113   // the ownership of one ref on "resource" to *this
    114   //
    115   // REQUIRES: std::is_base_of<ResourceBase, T>
    116   // REQUIRES: resource != nullptr.
    117   template <typename T>
    118   Status Create(const string& container, const string& name,
    119                 T* resource) TF_MUST_USE_RESULT;
    121   // If "container" has a resource "name", returns it in "*resource" and
    122   // the caller takes the ownership of one ref on "*resource".
    123   //
    124   // REQUIRES: std::is_base_of<ResourceBase, T>
    125   // REQUIRES: resource != nullptr
    126   template <typename T>
    127   Status Lookup(const string& container, const string& name,
    128                 T** resource) const TF_MUST_USE_RESULT;
    130   // If "container" has a resource "name", returns it in
    131   // "*resource". Otherwise, invokes creator() to create the resource.
    132   // The caller takes the ownership of one ref on "*resource".
    133   //
    134   // REQUIRES: std::is_base_of<ResourceBase, T>
    135   // REQUIRES: resource != nullptr
    136   template <typename T>
    137   Status LookupOrCreate(const string& container, const string& name,
    138                         T** resource,
    139                         std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
    141   // Deletes the resource "name" from the "container".
    142   //
    143   // REQUIRES: std::is_base_of<ResourceBase, T>
    144   template <typename T>
    145   Status Delete(const string& container, const string& name) TF_MUST_USE_RESULT;
    147   // Deletes the resource pointed by "handle".
    148   Status Delete(const ResourceHandle& handle) TF_MUST_USE_RESULT;
    150   // Deletes all resources from the "container" and removes the container.
    151   Status Cleanup(const string& container) TF_MUST_USE_RESULT;
    153   // Deletes all resources in all containers.
    154   void Clear();
    156   // Returns a text description for all resources.
    157   string DebugString() const;
    159  private:
    160   typedef std::pair<uint64, string> Key;
    161   struct KeyHash {
    162     std::size_t operator()(const Key& k) const {
    163       return Hash64(k.second.data(), k.second.size(), k.first);
    164     }
    165   };
    166   struct KeyEqual {
    167     bool operator()(const Key& x, const Key& y) const {
    168       return (x.second == y.second) && (x.first == y.first);
    169     }
    170   };
    171   typedef std::unordered_map<Key, ResourceBase*, KeyHash, KeyEqual> Container;
    173   const string default_container_;
    174   mutable mutex mu_;
    175   std::unordered_map<string, Container*> containers_ GUARDED_BY(mu_);
    177   Status DoCreate(const string& container, TypeIndex type, const string& name,
    178                   ResourceBase* resource) TF_MUST_USE_RESULT;
    179   Status DoLookup(const string& container, TypeIndex type, const string& name,
    180                   ResourceBase** resource) const TF_MUST_USE_RESULT;
    181   Status DoDelete(const string& container, uint64 type_hash_code,
    182                   const string& resource_name,
    183                   const string& type_name) TF_MUST_USE_RESULT;
    184   Status DoDelete(const string& container, TypeIndex type,
    185                   const string& resource_name) TF_MUST_USE_RESULT;
    187   // Inserts the type name for 'hash_code' into the hash_code to type name map.
    188   Status InsertDebugTypeName(uint64 hash_code, const string& type_name)
    191   // Returns the type name for the 'hash_code'.
    192   // Returns "<unknown>" if a resource with such a type was never inserted into
    193   // the container.
    194   const char* DebugTypeName(uint64 hash_code) const
    195       EXCLUSIVE_LOCKS_REQUIRED(mu_);
    197   // Map from type hash_code to type name.
    198   std::unordered_map<uint64, string> debug_type_names_ GUARDED_BY(mu_);
    200   TF_DISALLOW_COPY_AND_ASSIGN(ResourceMgr);
    201 };
    203 // Makes a resource handle with the specified type for a given container /
    204 // name.
    205 ResourceHandle MakeResourceHandle(OpKernelContext* ctx, const string& container,
    206                                   const string& name,
    207                                   const TypeIndex& type_index);
    209 template <typename T>
    210 ResourceHandle MakeResourceHandle(OpKernelContext* ctx, const string& container,
    211                                   const string& name) {
    212   return MakeResourceHandle(ctx, container, name, MakeTypeIndex<T>());
    213 }
    215 Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,
    216                                   const string& container, const string& name,
    217                                   const TypeIndex& type_index);
// Makes a handle to resource "name" of type T in the container named by
// ctx->step_container().
template <typename T>
ResourceHandle MakePerStepResourceHandle(OpKernelContext* ctx,
                                         const string& name);

// Returns a resource handle from a numbered op input.
ResourceHandle HandleFromInput(OpKernelContext* ctx, int input);
// Reads the resource handle from the op input named 'input' into '*handle'.
// NOTE(review): presumably returns non-OK when the input does not exist —
// confirm against the implementation.
Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
                       ResourceHandle* handle);

// Creates a resource pointed to by a given resource handle.
template <typename T>
Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value);

// Looks up a resource pointed to by a given resource handle. On success the
// caller owns one ref on '*value'.
template <typename T>
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value);

// Looks up or creates a resource.
template <typename T>
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
                              T** value, std::function<Status(T**)> creator);

// Destroys a resource pointed to by a given resource handle.
template <typename T>
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);

// Same as above, but uses the hash code of the type directly.
// The type name information will be missing in the debug output when the
// resource is not present in the container.
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
    250 // Policy helper to decide which container/shared_name to use for a
    251 // stateful kernel that accesses shared resource.
    252 class ContainerInfo {
    253  public:
    254   // Analyze the node attribute of 'ndef' and decides the container and
    255   // resource name the kernel should use for accessing the shared
    256   // resource.
    257   //
    258   // 'ndef' is expected to have node attribute "container" and
    259   // "shared_name". Returns non-OK if they are not provided or they are
    260   // invalid.
    261   //
    262   // The policy is as following:
    263   // * If the attribute "container" is non-empty, it is used as is.
    264   //   Otherwise, uses the resource manager's default container.
    265   // * If the attribute "shared_name" is non-empty, it is used as is.
    266   //   Otherwise, if "use_node_name_as_default" is true, the kernel's
    267   //   node name is used as the resource name. Otherwise, a string
    268   //   unique to this process is used.
    269   Status Init(ResourceMgr* rmgr, const NodeDef& ndef,
    270               bool use_node_name_as_default);
    271   Status Init(ResourceMgr* rmgr, const NodeDef& ndef) {
    272     return Init(rmgr, ndef, false);
    273   }
    275   // The policy decides that the kernel should access the resource in
    276   // resource_manager(), the resource is in the container() and its
    277   // name is name().  If resource_is_private_to_kernel() is true, the
    278   // kernel should delete the resource when the kernel is deleted.
    279   ResourceMgr* resource_manager() const { return rmgr_; }
    280   const string& container() const { return container_; }
    281   const string& name() const { return name_; }
    282   bool resource_is_private_to_kernel() const {
    283     return resource_is_private_to_kernel_;
    284   }
    286   // Returns a readable string for *this.
    287   string DebugString() const;
    289  private:
    290   ResourceMgr* rmgr_ = nullptr;
    291   string container_;
    292   string name_;
    293   bool resource_is_private_to_kernel_ = false;
    294 };
// Helper for kernels to obtain 'resource' from the
// ctx->resource_manager().
//
// "input_name" names the kernel input that identifies the resource. If that
// input has dtype DT_RESOURCE it holds a resource handle; otherwise it is a
// ref input holding a string tensor with two elements, which specify the
// container and resource name.
//
// Returns OK if the resource is found and transfers one ref of
// *resource to the caller. Otherwise, returns an error.
template <typename T>
Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name,
                              T** resource);
    309 // Utility op kernel to check if a handle to resource type T is initialized.
    310 template <typename T>
    311 class IsResourceInitialized : public OpKernel {
    312  public:
    313   explicit IsResourceInitialized(OpKernelConstruction* c) : OpKernel(c) {}
    315   void Compute(OpKernelContext* ctx) override;
    316 };
// Registers an op which produces just a resource handle to a resource of the
// specified type. The generated op is named "<Type>HandleOp", takes optional
// "container" and "shared_name" string attrs (default ''), and has a single
// scalar "resource" output. The caller supplies the trailing semicolon.
// TODO(apassos): figure out how to get non-cpu-allocated tensors to work
// through constant folding so this doesn't have to be marked as stateful.
#define REGISTER_RESOURCE_HANDLE_OP(Type)                   \
  REGISTER_OP(#Type "HandleOp")                             \
      .Attr("container: string = ''")                       \
      .Attr("shared_name: string = ''")                     \
      .Output("resource: resource")                         \
      .SetIsStateful()                                      \
      .SetShapeFn(tensorflow::shape_inference::ScalarShape) \
      .Doc("Creates a handle to a " #Type)
    331 // Utility op kernel to produce a handle to a resource of type T.
    332 template <typename T>
    333 class ResourceHandleOp : public OpKernel {
    334  public:
    335   explicit ResourceHandleOp(OpKernelConstruction* context);
    337   void Compute(OpKernelContext* ctx) override;
    339  private:
    340   string container_;
    341   string name_;
    342 };
// Registers a CPU kernel (ResourceHandleOp<Type>) for the "<Type>HandleOp" op
// created by REGISTER_RESOURCE_HANDLE_OP(Type).
#define REGISTER_RESOURCE_HANDLE_KERNEL(Type)                        \
  REGISTER_KERNEL_BUILDER(Name(#Type "HandleOp").Device(DEVICE_CPU), \
                          ResourceHandleOp<Type>)
    350 // Implementation details below.
    352 template <typename T>
    353 void CheckDeriveFromResourceBase() {
    354   static_assert(std::is_base_of<ResourceBase, T>::value,
    355                 "T must derive from ResourceBase");
    356 }
    358 template <typename T>
    359 Status ResourceMgr::Create(const string& container, const string& name,
    360                            T* resource) {
    361   CheckDeriveFromResourceBase<T>();
    362   CHECK(resource != nullptr);
    363   return DoCreate(container, MakeTypeIndex<T>(), name, resource);
    364 }
    366 template <typename T>
    367 Status ResourceMgr::Lookup(const string& container, const string& name,
    368                            T** resource) const {
    369   CheckDeriveFromResourceBase<T>();
    370   ResourceBase* found = nullptr;
    371   Status s = DoLookup(container, MakeTypeIndex<T>(), name, &found);
    372   if (s.ok()) {
    373     // It's safe to down cast 'found' to T* since
    374     // typeid(T).hash_code() is part of the map key.
    375     *resource = static_cast<T*>(found);
    376   }
    377   return s;
    378 }
    380 template <typename T>
    381 Status ResourceMgr::LookupOrCreate(const string& container, const string& name,
    382                                    T** resource,
    383                                    std::function<Status(T**)> creator) {
    384   Status s;
    385   *resource = nullptr;
    386   while (*resource == nullptr) {
    387     s = Lookup(container, name, resource);
    388     if (s.ok()) break;
    389     s = creator(resource);
    390     if (!s.ok()) break;
    391     s = Create(container, name, *resource);
    392     if (s.ok()) {
    393       (*resource)->Ref();
    394       break;
    395     }
    396     // Rare event. Concurrent racy creation. Redo the lookup.
    397     *resource = nullptr;
    398   }
    399   return s;
    400 }
    402 template <typename T>
    403 Status ResourceMgr::Delete(const string& container, const string& name) {
    404   CheckDeriveFromResourceBase<T>();
    405   return DoDelete(container, MakeTypeIndex<T>(), name);
    406 }
    408 template <typename T>
    409 Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name,
    410                               T** resource) {
    411   DataType dtype;
    412   TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &dtype));
    413   if (dtype == DT_RESOURCE) {
    414     const Tensor* handle;
    415     TF_RETURN_IF_ERROR(ctx->input(input_name, &handle));
    416     return LookupResource(ctx, handle->scalar<ResourceHandle>()(), resource);
    417   }
    418   string container;
    419   string shared_name;
    420   {
    421     mutex* mu;
    422     TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu));
    423     mutex_lock l(*mu);
    424     Tensor tensor;
    425     TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true));
    426     if (tensor.NumElements() != 2) {
    427       return errors::InvalidArgument(
    428           "Resource handle must have 2 elements, but had shape: ",
    429           tensor.shape().DebugString());
    430     }
    431     container = tensor.flat<string>()(0);
    432     shared_name = tensor.flat<string>()(1);
    433   }
    434   return ctx->resource_manager()->Lookup(container, shared_name, resource);
    435 }
    437 template <typename T>
    438 ResourceHandle MakePerStepResourceHandle(OpKernelContext* ctx,
    439                                          const string& name) {
    440   return MakeResourceHandle<T>(ctx, ctx->step_container()->name(), name);
    441 }
    443 namespace internal {
    445 Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p);
    447 template <typename T>
    448 Status ValidateDeviceAndType(OpKernelContext* ctx, const ResourceHandle& p) {
    449   TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p));
    450   auto type_index = MakeTypeIndex<T>();
    451   if (type_index.hash_code() != p.hash_code()) {
    452     return errors::InvalidArgument(
    453         "Trying to access resource using the wrong type. Expected ",
    454         p.maybe_type_name(), " got ", type_index.name());
    455   }
    456   return Status::OK();
    457 }
    459 }  // namespace internal
    461 template <typename T>
    462 Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value) {
    463   TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
    464   return ctx->resource_manager()->Create(p.container(), p.name(), value);
    465 }
    467 template <typename T>
    468 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
    469                       T** value) {
    470   TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
    471   return ctx->resource_manager()->Lookup(p.container(), p.name(), value);
    472 }
    474 template <typename T>
    475 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
    476                               T** value, std::function<Status(T**)> creator) {
    477   TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
    478   return ctx->resource_manager()->LookupOrCreate(p.container(), p.name(), value,
    479                                                  creator);
    480 }
    482 template <typename T>
    483 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) {
    484   TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
    485   return ctx->resource_manager()->Delete<T>(p.container(), p.name());
    486 }
    488 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
    490 template <typename T>
    491 void IsResourceInitialized<T>::Compute(OpKernelContext* ctx) {
    492   Tensor* output;
    493   OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &output));
    494   T* object;
    495   bool found;
    496   if (LookupResource(ctx, HandleFromInput(ctx, 0), &object).ok()) {
    497     found = true;
    498     object->Unref();
    499   } else {
    500     found = false;
    501   }
    503   output->flat<bool>()(0) = found;
    504 }
    506 template <typename T>
    507 ResourceHandleOp<T>::ResourceHandleOp(OpKernelConstruction* context)
    508     : OpKernel(context) {
    509   OP_REQUIRES_OK(context, context->GetAttr("container", &container_));
    510   OP_REQUIRES_OK(context, context->GetAttr("shared_name", &name_));
    511 }
    513 template <typename T>
    514 void ResourceHandleOp<T>::Compute(OpKernelContext* ctx) {
    515   Tensor* output = nullptr;
    516   OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
    517   output->scalar<ResourceHandle>()() =
    518       MakeResourceHandle<T>(ctx, container_, name_);
    519 }
    521 }  //  end namespace tensorflow