1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/op.h" 17 18 #include <algorithm> 19 #include <memory> 20 #include <vector> 21 #include "tensorflow/core/framework/op_kernel.h" 22 #include "tensorflow/core/lib/core/errors.h" 23 #include "tensorflow/core/lib/gtl/map_util.h" 24 #include "tensorflow/core/platform/host_info.h" 25 #include "tensorflow/core/platform/logging.h" 26 #include "tensorflow/core/platform/mutex.h" 27 #include "tensorflow/core/platform/protobuf.h" 28 #include "tensorflow/core/platform/types.h" 29 30 namespace tensorflow { 31 32 // OpRegistry ----------------------------------------------------------------- 33 34 OpRegistryInterface::~OpRegistryInterface() {} 35 36 Status OpRegistryInterface::LookUpOpDef(const string& op_type_name, 37 const OpDef** op_def) const { 38 *op_def = nullptr; 39 const OpRegistrationData* op_reg_data = nullptr; 40 TF_RETURN_IF_ERROR(LookUp(op_type_name, &op_reg_data)); 41 *op_def = &op_reg_data->op_def; 42 return Status::OK(); 43 } 44 45 OpRegistry::OpRegistry() : initialized_(false) {} 46 47 OpRegistry::~OpRegistry() { 48 for (const auto& e : registry_) delete e.second; 49 } 50 51 void OpRegistry::Register(const OpRegistrationDataFactory& op_data_factory) { 52 mutex_lock lock(mu_); 53 if (initialized_) { 54 TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory)); 55 } else { 56 deferred_.push_back(op_data_factory); 57 } 58 } 59 60 Status OpRegistry::LookUp(const string& op_type_name, 61 const OpRegistrationData** op_reg_data) const { 62 *op_reg_data = nullptr; 63 const OpRegistrationData* res = nullptr; 64 65 bool first_call = false; 66 bool first_unregistered = false; 67 { // Scope for lock. 68 mutex_lock lock(mu_); 69 first_call = MustCallDeferred(); 70 res = gtl::FindWithDefault(registry_, op_type_name, nullptr); 71 72 static bool unregistered_before = false; 73 first_unregistered = !unregistered_before && (res == nullptr); 74 if (first_unregistered) { 75 unregistered_before = true; 76 } 77 // Note: Can't hold mu_ while calling Export() below. 78 } 79 if (first_call) { 80 TF_QCHECK_OK(ValidateKernelRegistrations(*this)); 81 } 82 if (res == nullptr) { 83 if (first_unregistered) { 84 OpList op_list; 85 Export(true, &op_list); 86 if (VLOG_IS_ON(3)) { 87 LOG(INFO) << "All registered Ops:"; 88 for (const auto& op : op_list.op()) { 89 LOG(INFO) << SummarizeOpDef(op); 90 } 91 } 92 } 93 Status status = 94 errors::NotFound("Op type not registered '", op_type_name, 95 "' in binary running on ", port::Hostname(), ". ", 96 "Make sure the Op and Kernel are registered in the " 97 "binary running in this process."); 98 VLOG(1) << status.ToString(); 99 return status; 100 } 101 *op_reg_data = res; 102 return Status::OK(); 103 } 104 105 void OpRegistry::GetRegisteredOps(std::vector<OpDef>* op_defs) { 106 mutex_lock lock(mu_); 107 MustCallDeferred(); 108 for (const auto& p : registry_) { 109 op_defs->push_back(p.second->op_def); 110 } 111 } 112 113 Status OpRegistry::SetWatcher(const Watcher& watcher) { 114 mutex_lock lock(mu_); 115 if (watcher_ && watcher) { 116 return errors::AlreadyExists( 117 "Cannot over-write a valid watcher with another."); 118 } 119 watcher_ = watcher; 120 return Status::OK(); 121 } 122 123 void OpRegistry::Export(bool include_internal, OpList* ops) const { 124 mutex_lock lock(mu_); 125 MustCallDeferred(); 126 127 std::vector<std::pair<string, const OpRegistrationData*>> sorted( 128 registry_.begin(), registry_.end()); 129 std::sort(sorted.begin(), sorted.end()); 130 131 auto out = ops->mutable_op(); 132 out->Clear(); 133 out->Reserve(sorted.size()); 134 135 for (const auto& item : sorted) { 136 if (include_internal || !StringPiece(item.first).starts_with("_")) { 137 *out->Add() = item.second->op_def; 138 } 139 } 140 } 141 142 void OpRegistry::DeferRegistrations() { 143 mutex_lock lock(mu_); 144 initialized_ = false; 145 } 146 147 void OpRegistry::ClearDeferredRegistrations() { 148 mutex_lock lock(mu_); 149 deferred_.clear(); 150 } 151 152 Status OpRegistry::ProcessRegistrations() const { 153 mutex_lock lock(mu_); 154 return CallDeferred(); 155 } 156 157 string OpRegistry::DebugString(bool include_internal) const { 158 OpList op_list; 159 Export(include_internal, &op_list); 160 string ret; 161 for (const auto& op : op_list.op()) { 162 strings::StrAppend(&ret, SummarizeOpDef(op), "\n"); 163 } 164 return ret; 165 } 166 167 bool OpRegistry::MustCallDeferred() const { 168 if (initialized_) return false; 169 initialized_ = true; 170 for (size_t i = 0; i < deferred_.size(); ++i) { 171 TF_QCHECK_OK(RegisterAlreadyLocked(deferred_[i])); 172 } 173 deferred_.clear(); 174 return true; 175 } 176 177 Status OpRegistry::CallDeferred() const { 178 if (initialized_) return Status::OK(); 179 initialized_ = true; 180 for (size_t i = 0; i < deferred_.size(); ++i) { 181 Status s = RegisterAlreadyLocked(deferred_[i]); 182 if (!s.ok()) { 183 return s; 184 } 185 } 186 deferred_.clear(); 187 return Status::OK(); 188 } 189 190 Status OpRegistry::RegisterAlreadyLocked( 191 const OpRegistrationDataFactory& op_data_factory) const { 192 std::unique_ptr<OpRegistrationData> op_reg_data(new OpRegistrationData); 193 Status s = op_data_factory(op_reg_data.get()); 194 if (s.ok()) { 195 s = ValidateOpDef(op_reg_data->op_def); 196 if (s.ok() && 197 !gtl::InsertIfNotPresent(®istry_, op_reg_data->op_def.name(), 198 op_reg_data.get())) { 199 s = errors::AlreadyExists("Op with name ", op_reg_data->op_def.name()); 200 } 201 } 202 Status watcher_status = s; 203 if (watcher_) { 204 watcher_status = watcher_(s, op_reg_data->op_def); 205 } 206 if (s.ok()) { 207 op_reg_data.release(); 208 } else { 209 op_reg_data.reset(); 210 } 211 return watcher_status; 212 } 213 214 // static 215 OpRegistry* OpRegistry::Global() { 216 static OpRegistry* global_op_registry = new OpRegistry; 217 return global_op_registry; 218 } 219 220 // OpListOpRegistry ----------------------------------------------------------- 221 222 OpListOpRegistry::OpListOpRegistry(const OpList* op_list) { 223 for (const OpDef& op_def : op_list->op()) { 224 auto* op_reg_data = new OpRegistrationData(); 225 op_reg_data->op_def = op_def; 226 index_[op_def.name()] = op_reg_data; 227 } 228 } 229 230 OpListOpRegistry::~OpListOpRegistry() { 231 for (const auto& e : index_) delete e.second; 232 } 233 234 Status OpListOpRegistry::LookUp(const string& op_type_name, 235 const OpRegistrationData** op_reg_data) const { 236 auto iter = index_.find(op_type_name); 237 if (iter == index_.end()) { 238 *op_reg_data = nullptr; 239 return errors::NotFound("Op type not registered '", op_type_name, 240 "' in binary running on ", port::Hostname(), ". ", 241 "Make sure the Op and Kernel are registered in the " 242 "binary running in this process."); 243 } 244 *op_reg_data = iter->second; 245 return Status::OK(); 246 } 247 248 // Other registration --------------------------------------------------------- 249 250 namespace register_op { 251 OpDefBuilderReceiver::OpDefBuilderReceiver( 252 const OpDefBuilderWrapper<true>& wrapper) { 253 OpRegistry::Global()->Register( 254 [wrapper](OpRegistrationData* op_reg_data) -> Status { 255 return wrapper.builder().Finalize(op_reg_data); 256 }); 257 } 258 } // namespace register_op 259 260 } // namespace tensorflow 261