Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/framework/resource_mgr.h"
     17 
     18 #include "tensorflow/core/framework/device_attributes.pb.h"
     19 #include "tensorflow/core/framework/node_def.pb.h"
     20 #include "tensorflow/core/framework/node_def_util.h"
     21 #include "tensorflow/core/lib/core/errors.h"
     22 #include "tensorflow/core/lib/gtl/map_util.h"
     23 #include "tensorflow/core/lib/strings/scanner.h"
     24 #include "tensorflow/core/lib/strings/str_util.h"
     25 #include "tensorflow/core/lib/strings/stringprintf.h"
     26 #include "tensorflow/core/platform/demangle.h"
     27 
     28 namespace tensorflow {
     29 ResourceHandle MakeResourceHandle(OpKernelContext* ctx, const string& container,
     30                                   const string& name,
     31                                   const TypeIndex& type_index) {
     32   ResourceHandle result;
     33   result.set_device(ctx->device()->attributes().name());
     34   string actual_container;
     35   if (!container.empty()) {
     36     actual_container = container;
     37   } else {
     38     actual_container = ctx->resource_manager()->default_container();
     39   }
     40   result.set_container(actual_container);
     41   result.set_name(name);
     42   result.set_hash_code(type_index.hash_code());
     43   result.set_maybe_type_name(type_index.name());
     44   return result;
     45 }
     46 
     47 Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,
     48                                   const string& container, const string& name,
     49                                   const TypeIndex& type_index) {
     50   Tensor* handle;
     51   TF_RETURN_IF_ERROR(
     52       context->allocate_output(output_index, TensorShape({}), &handle));
     53   handle->scalar<ResourceHandle>()() =
     54       MakeResourceHandle(context, container, name, type_index);
     55   return Status::OK();
     56 }
     57 
     58 namespace internal {
     59 
     60 Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p) {
     61   if (ctx->device()->attributes().name() != p.device()) {
     62     return errors::InvalidArgument(
     63         "Trying to access resource located in device ", p.device(),
     64         " from device ", ctx->device()->attributes().name());
     65   }
     66   return Status::OK();
     67 }
     68 
     69 }  // end namespace internal
     70 
     71 Status ResourceMgr::InsertDebugTypeName(uint64 hash_code,
     72                                         const string& type_name) {
     73   auto iter = debug_type_names_.emplace(hash_code, type_name);
     74   if (iter.first->second != type_name) {
     75     return errors::AlreadyExists("Duplicate hash code found for type ",
     76                                  type_name);
     77   }
     78   return Status::OK();
     79 }
     80 
     81 const char* ResourceMgr::DebugTypeName(uint64 hash_code) const {
     82   auto type_name_iter = debug_type_names_.find(hash_code);
     83   if (type_name_iter == debug_type_names_.end()) {
     84     return "<unknown>";
     85   } else {
     86     return type_name_iter->second.c_str();
     87   }
     88 }
     89 
     90 ResourceMgr::ResourceMgr() : default_container_("localhost") {}
     91 
     92 ResourceMgr::ResourceMgr(const string& default_container)
     93     : default_container_(default_container) {}
     94 
     95 ResourceMgr::~ResourceMgr() { Clear(); }
     96 
     97 void ResourceMgr::Clear() {
     98   mutex_lock l(mu_);
     99   for (const auto& p : containers_) {
    100     for (const auto& q : *p.second) {
    101       q.second->Unref();
    102     }
    103     delete p.second;
    104   }
    105   containers_.clear();
    106 }
    107 
    108 string ResourceMgr::DebugString() const {
    109   mutex_lock l(mu_);
    110   struct Line {
    111     const string* container;
    112     const string type;
    113     const string* resource;
    114     const string detail;
    115   };
    116   std::vector<Line> lines;
    117   for (const auto& p : containers_) {
    118     const string& container = p.first;
    119     for (const auto& q : *p.second) {
    120       const Key& key = q.first;
    121       const char* type = DebugTypeName(key.first);
    122       const string& resource = key.second;
    123       Line l{&container, port::Demangle(type), &resource,
    124              q.second->DebugString()};
    125       lines.push_back(l);
    126     }
    127   }
    128   std::vector<string> text;
    129   text.reserve(lines.size());
    130   for (const Line& line : lines) {
    131     text.push_back(strings::Printf(
    132         "%-20s | %-40s | %-40s | %-s", line.container->c_str(),
    133         line.type.c_str(), line.resource->c_str(), line.detail.c_str()));
    134   }
    135   std::sort(text.begin(), text.end());
    136   return str_util::Join(text, "\n");
    137 }
    138 
    139 Status ResourceMgr::DoCreate(const string& container, TypeIndex type,
    140                              const string& name, ResourceBase* resource) {
    141   {
    142     mutex_lock l(mu_);
    143     Container** b = &containers_[container];
    144     if (*b == nullptr) {
    145       *b = new Container;
    146     }
    147     if ((*b)->insert({{type.hash_code(), name}, resource}).second) {
    148       TF_RETURN_IF_ERROR(InsertDebugTypeName(type.hash_code(), type.name()));
    149       return Status::OK();
    150     }
    151   }
    152   resource->Unref();
    153   return errors::AlreadyExists("Resource ", container, "/", name, "/",
    154                                type.name());
    155 }
    156 
    157 Status ResourceMgr::DoLookup(const string& container, TypeIndex type,
    158                              const string& name,
    159                              ResourceBase** resource) const {
    160   tf_shared_lock l(mu_);
    161   const Container* b = gtl::FindPtrOrNull(containers_, container);
    162   if (b == nullptr) {
    163     return errors::NotFound("Container ", container,
    164                             " does not exist. (Could not find resource: ",
    165                             container, "/", name, ")");
    166   }
    167   auto r = gtl::FindPtrOrNull(*b, {type.hash_code(), name});
    168   if (r == nullptr) {
    169     return errors::NotFound("Resource ", container, "/", name, "/", type.name(),
    170                             " does not exist.");
    171   }
    172   *resource = const_cast<ResourceBase*>(r);
    173   (*resource)->Ref();
    174   return Status::OK();
    175 }
    176 
    177 Status ResourceMgr::DoDelete(const string& container, uint64 type_hash_code,
    178                              const string& resource_name,
    179                              const string& type_name) {
    180   ResourceBase* base = nullptr;
    181   {
    182     mutex_lock l(mu_);
    183     Container* b = gtl::FindPtrOrNull(containers_, container);
    184     if (b == nullptr) {
    185       return errors::NotFound("Container ", container, " does not exist.");
    186     }
    187     auto iter = b->find({type_hash_code, resource_name});
    188     if (iter == b->end()) {
    189       return errors::NotFound("Resource ", container, "/", resource_name, "/",
    190                               type_name, " does not exist.");
    191     }
    192     base = iter->second;
    193     b->erase(iter);
    194   }
    195   CHECK(base != nullptr);
    196   base->Unref();
    197   return Status::OK();
    198 }
    199 
    200 Status ResourceMgr::DoDelete(const string& container, TypeIndex type,
    201                              const string& resource_name) {
    202   return DoDelete(container, type.hash_code(), resource_name, type.name());
    203 }
    204 
    205 Status ResourceMgr::Delete(const ResourceHandle& handle) {
    206   return DoDelete(handle.container(), handle.hash_code(), handle.name(),
    207                   "<unknown>");
    208 }
    209 
    210 Status ResourceMgr::Cleanup(const string& container) {
    211   Container* b = nullptr;
    212   {
    213     mutex_lock l(mu_);
    214     auto iter = containers_.find(container);
    215     if (iter == containers_.end()) {
    216       // Nothing to cleanup, it's OK.
    217       return Status::OK();
    218     }
    219     b = iter->second;
    220     containers_.erase(iter);
    221   }
    222   CHECK(b != nullptr);
    223   for (const auto& p : *b) {
    224     p.second->Unref();
    225   }
    226   delete b;
    227   return Status::OK();
    228 }
    229 
    230 static bool IsValidContainerName(StringPiece s) {
    231   using ::tensorflow::strings::Scanner;
    232   return Scanner(s)
    233       .One(Scanner::LETTER_DIGIT_DOT)
    234       .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH)
    235       .Eos()
    236       .GetResult();
    237 }
    238 
    239 Status ContainerInfo::Init(ResourceMgr* rmgr, const NodeDef& ndef,
    240                            bool use_node_name_as_default) {
    241   CHECK(rmgr);
    242   rmgr_ = rmgr;
    243   string attr_container;
    244   TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "container", &attr_container));
    245   if (!attr_container.empty() && !IsValidContainerName(attr_container)) {
    246     return errors::InvalidArgument("container contains invalid characters: ",
    247                                    attr_container);
    248   }
    249   string attr_shared_name;
    250   TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "shared_name", &attr_shared_name));
    251   if (!attr_shared_name.empty() && (attr_shared_name[0] == '_')) {
    252     return errors::InvalidArgument("shared_name cannot start with '_':",
    253                                    attr_shared_name);
    254   }
    255   if (!attr_container.empty()) {
    256     container_ = attr_container;
    257   } else {
    258     container_ = rmgr_->default_container();
    259   }
    260   if (!attr_shared_name.empty()) {
    261     name_ = attr_shared_name;
    262   } else if (use_node_name_as_default) {
    263     name_ = ndef.name();
    264   } else {
    265     resource_is_private_to_kernel_ = true;
    266     static std::atomic<int64> counter(0);
    267     name_ = strings::StrCat("_", counter.fetch_add(1), "_", ndef.name());
    268   }
    269   return Status::OK();
    270 }
    271 
    272 string ContainerInfo::DebugString() const {
    273   return strings::StrCat("[", container(), ",", name(), ",",
    274                          resource_is_private_to_kernel() ? "private" : "public",
    275                          "]");
    276 }
    277 
    278 ResourceHandle HandleFromInput(OpKernelContext* ctx, int input) {
    279   return ctx->input(input).flat<ResourceHandle>()(0);
    280 }
    281 
    282 Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
    283                        ResourceHandle* handle) {
    284   const Tensor* tensor;
    285   TF_RETURN_IF_ERROR(ctx->input(input, &tensor));
    286   *handle = tensor->flat<ResourceHandle>()(0);
    287   return Status::OK();
    288 }
    289 
    290 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) {
    291   TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p));
    292   return ctx->resource_manager()->Delete(p);
    293 }
    294 
    295 }  //  end namespace tensorflow
    296