Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/framework/op.h"
     17 
     18 #include <algorithm>
     19 #include <memory>
     20 #include <vector>
     21 #include "tensorflow/core/framework/op_kernel.h"
     22 #include "tensorflow/core/lib/core/errors.h"
     23 #include "tensorflow/core/lib/gtl/map_util.h"
     24 #include "tensorflow/core/platform/host_info.h"
     25 #include "tensorflow/core/platform/logging.h"
     26 #include "tensorflow/core/platform/mutex.h"
     27 #include "tensorflow/core/platform/protobuf.h"
     28 #include "tensorflow/core/platform/types.h"
     29 
     30 namespace tensorflow {
     31 
     32 // OpRegistry -----------------------------------------------------------------
     33 
     34 OpRegistryInterface::~OpRegistryInterface() {}
     35 
     36 Status OpRegistryInterface::LookUpOpDef(const string& op_type_name,
     37                                         const OpDef** op_def) const {
     38   *op_def = nullptr;
     39   const OpRegistrationData* op_reg_data = nullptr;
     40   TF_RETURN_IF_ERROR(LookUp(op_type_name, &op_reg_data));
     41   *op_def = &op_reg_data->op_def;
     42   return Status::OK();
     43 }
     44 
     45 OpRegistry::OpRegistry() : initialized_(false) {}
     46 
     47 OpRegistry::~OpRegistry() {
     48   for (const auto& e : registry_) delete e.second;
     49 }
     50 
     51 void OpRegistry::Register(const OpRegistrationDataFactory& op_data_factory) {
     52   mutex_lock lock(mu_);
     53   if (initialized_) {
     54     TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory));
     55   } else {
     56     deferred_.push_back(op_data_factory);
     57   }
     58 }
     59 
     60 Status OpRegistry::LookUp(const string& op_type_name,
     61                           const OpRegistrationData** op_reg_data) const {
     62   *op_reg_data = nullptr;
     63   const OpRegistrationData* res = nullptr;
     64 
     65   bool first_call = false;
     66   bool first_unregistered = false;
     67   {  // Scope for lock.
     68     mutex_lock lock(mu_);
     69     first_call = MustCallDeferred();
     70     res = gtl::FindWithDefault(registry_, op_type_name, nullptr);
     71 
     72     static bool unregistered_before = false;
     73     first_unregistered = !unregistered_before && (res == nullptr);
     74     if (first_unregistered) {
     75       unregistered_before = true;
     76     }
     77     // Note: Can't hold mu_ while calling Export() below.
     78   }
     79   if (first_call) {
     80     TF_QCHECK_OK(ValidateKernelRegistrations(*this));
     81   }
     82   if (res == nullptr) {
     83     if (first_unregistered) {
     84       OpList op_list;
     85       Export(true, &op_list);
     86       if (VLOG_IS_ON(3)) {
     87         LOG(INFO) << "All registered Ops:";
     88         for (const auto& op : op_list.op()) {
     89           LOG(INFO) << SummarizeOpDef(op);
     90         }
     91       }
     92     }
     93     Status status =
     94         errors::NotFound("Op type not registered '", op_type_name,
     95                          "' in binary running on ", port::Hostname(), ". ",
     96                          "Make sure the Op and Kernel are registered in the "
     97                          "binary running in this process.");
     98     VLOG(1) << status.ToString();
     99     return status;
    100   }
    101   *op_reg_data = res;
    102   return Status::OK();
    103 }
    104 
    105 void OpRegistry::GetRegisteredOps(std::vector<OpDef>* op_defs) {
    106   mutex_lock lock(mu_);
    107   MustCallDeferred();
    108   for (const auto& p : registry_) {
    109     op_defs->push_back(p.second->op_def);
    110   }
    111 }
    112 
    113 Status OpRegistry::SetWatcher(const Watcher& watcher) {
    114   mutex_lock lock(mu_);
    115   if (watcher_ && watcher) {
    116     return errors::AlreadyExists(
    117         "Cannot over-write a valid watcher with another.");
    118   }
    119   watcher_ = watcher;
    120   return Status::OK();
    121 }
    122 
    123 void OpRegistry::Export(bool include_internal, OpList* ops) const {
    124   mutex_lock lock(mu_);
    125   MustCallDeferred();
    126 
    127   std::vector<std::pair<string, const OpRegistrationData*>> sorted(
    128       registry_.begin(), registry_.end());
    129   std::sort(sorted.begin(), sorted.end());
    130 
    131   auto out = ops->mutable_op();
    132   out->Clear();
    133   out->Reserve(sorted.size());
    134 
    135   for (const auto& item : sorted) {
    136     if (include_internal || !StringPiece(item.first).starts_with("_")) {
    137       *out->Add() = item.second->op_def;
    138     }
    139   }
    140 }
    141 
    142 void OpRegistry::DeferRegistrations() {
    143   mutex_lock lock(mu_);
    144   initialized_ = false;
    145 }
    146 
    147 void OpRegistry::ClearDeferredRegistrations() {
    148   mutex_lock lock(mu_);
    149   deferred_.clear();
    150 }
    151 
    152 Status OpRegistry::ProcessRegistrations() const {
    153   mutex_lock lock(mu_);
    154   return CallDeferred();
    155 }
    156 
    157 string OpRegistry::DebugString(bool include_internal) const {
    158   OpList op_list;
    159   Export(include_internal, &op_list);
    160   string ret;
    161   for (const auto& op : op_list.op()) {
    162     strings::StrAppend(&ret, SummarizeOpDef(op), "\n");
    163   }
    164   return ret;
    165 }
    166 
    167 bool OpRegistry::MustCallDeferred() const {
    168   if (initialized_) return false;
    169   initialized_ = true;
    170   for (size_t i = 0; i < deferred_.size(); ++i) {
    171     TF_QCHECK_OK(RegisterAlreadyLocked(deferred_[i]));
    172   }
    173   deferred_.clear();
    174   return true;
    175 }
    176 
    177 Status OpRegistry::CallDeferred() const {
    178   if (initialized_) return Status::OK();
    179   initialized_ = true;
    180   for (size_t i = 0; i < deferred_.size(); ++i) {
    181     Status s = RegisterAlreadyLocked(deferred_[i]);
    182     if (!s.ok()) {
    183       return s;
    184     }
    185   }
    186   deferred_.clear();
    187   return Status::OK();
    188 }
    189 
    190 Status OpRegistry::RegisterAlreadyLocked(
    191     const OpRegistrationDataFactory& op_data_factory) const {
    192   std::unique_ptr<OpRegistrationData> op_reg_data(new OpRegistrationData);
    193   Status s = op_data_factory(op_reg_data.get());
    194   if (s.ok()) {
    195     s = ValidateOpDef(op_reg_data->op_def);
    196     if (s.ok() &&
    197         !gtl::InsertIfNotPresent(&registry_, op_reg_data->op_def.name(),
    198                                  op_reg_data.get())) {
    199       s = errors::AlreadyExists("Op with name ", op_reg_data->op_def.name());
    200     }
    201   }
    202   Status watcher_status = s;
    203   if (watcher_) {
    204     watcher_status = watcher_(s, op_reg_data->op_def);
    205   }
    206   if (s.ok()) {
    207     op_reg_data.release();
    208   } else {
    209     op_reg_data.reset();
    210   }
    211   return watcher_status;
    212 }
    213 
    214 // static
    215 OpRegistry* OpRegistry::Global() {
    216   static OpRegistry* global_op_registry = new OpRegistry;
    217   return global_op_registry;
    218 }
    219 
    220 // OpListOpRegistry -----------------------------------------------------------
    221 
    222 OpListOpRegistry::OpListOpRegistry(const OpList* op_list) {
    223   for (const OpDef& op_def : op_list->op()) {
    224     auto* op_reg_data = new OpRegistrationData();
    225     op_reg_data->op_def = op_def;
    226     index_[op_def.name()] = op_reg_data;
    227   }
    228 }
    229 
    230 OpListOpRegistry::~OpListOpRegistry() {
    231   for (const auto& e : index_) delete e.second;
    232 }
    233 
    234 Status OpListOpRegistry::LookUp(const string& op_type_name,
    235                                 const OpRegistrationData** op_reg_data) const {
    236   auto iter = index_.find(op_type_name);
    237   if (iter == index_.end()) {
    238     *op_reg_data = nullptr;
    239     return errors::NotFound("Op type not registered '", op_type_name,
    240                             "' in binary running on ", port::Hostname(), ". ",
    241                             "Make sure the Op and Kernel are registered in the "
    242                             "binary running in this process.");
    243   }
    244   *op_reg_data = iter->second;
    245   return Status::OK();
    246 }
    247 
    248 // Other registration ---------------------------------------------------------
    249 
    250 namespace register_op {
    251 OpDefBuilderReceiver::OpDefBuilderReceiver(
    252     const OpDefBuilderWrapper<true>& wrapper) {
    253   OpRegistry::Global()->Register(
    254       [wrapper](OpRegistrationData* op_reg_data) -> Status {
    255         return wrapper.builder().Finalize(op_reg_data);
    256       });
    257 }
    258 }  // namespace register_op
    259 
    260 }  // namespace tensorflow
    261