1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/c/eager/runtime.h" 17 18 #include "tensorflow/core/common_runtime/device_factory.h" 19 #include "tensorflow/core/common_runtime/rendezvous_mgr.h" 20 #include "tensorflow/core/framework/allocator.h" 21 #include "tensorflow/core/framework/node_def.pb.h" 22 #include "tensorflow/core/lib/core/errors.h" 23 #include "tensorflow/core/lib/gtl/map_util.h" 24 #include "tensorflow/core/lib/gtl/stl_util.h" 25 #include "tensorflow/core/platform/fingerprint.h" 26 #include "tensorflow/core/platform/mutex.h" 27 #include "tensorflow/core/public/version.h" 28 #include "tensorflow/core/util/tensor_slice_reader_cache.h" 29 30 namespace tensorflow { 31 namespace { 32 33 mutex g_op_name_to_attr_type_map_lock(LINKER_INITIALIZED); 34 35 std::unordered_map<string, const AttrTypeMap*>* OpNameToAttrTypeMap() { 36 static auto* const m = new std::unordered_map<string, const AttrTypeMap*>; 37 return m; 38 } 39 40 const uint32 kIsList = 1U << 31; 41 42 } // namespace 43 44 Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) { 45 mutex_lock l(g_op_name_to_attr_type_map_lock); 46 *out = gtl::FindPtrOrNull(*OpNameToAttrTypeMap(), op_name); 47 if (*out != nullptr) return Status::OK(); 48 const OpRegistrationData* op_reg_data = nullptr; 49 Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data); 50 if (!s.ok()) return s; 51 std::unique_ptr<AttrTypeMap> m(new AttrTypeMap); 52 // TODO(agarwal): Avoid having to create this "registry" at runtime, 53 // perhaps can be done at op registration time? 54 for (const auto& attr : op_reg_data->op_def.attr()) { 55 string type = attr.type(); 56 const bool is_list = (type.length() > 6 && type.compare(0, 4, "list") == 0); 57 if (is_list) { 58 type = type.substr(5, type.length() - 6); 59 } 60 uint32 t = is_list ? kIsList : 0; 61 if (type == "string") { 62 t |= TF_ATTR_STRING; 63 } else if (type == "int") { 64 t |= TF_ATTR_INT; 65 } else if (type == "float") { 66 t |= TF_ATTR_FLOAT; 67 } else if (type == "bool") { 68 t |= TF_ATTR_BOOL; 69 } else if (type == "type") { 70 t |= TF_ATTR_TYPE; 71 } else if (type == "shape") { 72 t |= TF_ATTR_SHAPE; 73 } else if (type == "tensor") { 74 t |= TF_ATTR_TENSOR; 75 } else if (type == "func") { 76 t |= TF_ATTR_FUNC; 77 } else { 78 return errors::Unimplemented( 79 "TODO(agarwal): Enable support for ops with attributes of type '", 80 type, "'"); 81 } 82 gtl::InsertIfNotPresent(m.get(), attr.name(), t); 83 } 84 *out = m.get(); 85 (*OpNameToAttrTypeMap())[op_name] = m.release(); 86 return Status::OK(); 87 } 88 89 Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, 90 TF_AttrType* out, unsigned char* is_list) { 91 auto* t = gtl::FindOrNull(m, attr_name); 92 if (t == nullptr) { 93 return errors::InvalidArgument("Attribute '", attr_name, 94 "' does not exist for this operation"); 95 } 96 *out = static_cast<TF_AttrType>(*t & ~kIsList); 97 if (*t & kIsList) { 98 *is_list = 1; 99 } else { 100 *is_list = 0; 101 } 102 return Status::OK(); 103 } 104 105 #define DEFINE_SET_ATTR(value_type, value_field) \ 106 template <> \ 107 AttrBuilder& AttrBuilder::Set(StringPiece attr_name, value_type&& value) { \ 108 value_field.push_back(std::make_pair(attr_name, value)); \ 109 return *this; \ 110 } 111 112 DEFINE_SET_ATTR(StringPiece, string_attrs_); 113 DEFINE_SET_ATTR(float, float_attrs_); 114 DEFINE_SET_ATTR(int, int_attrs_); 115 DEFINE_SET_ATTR(bool, bool_attrs_); 116 DEFINE_SET_ATTR(tensorflow::DataType, type_attrs_); 117 118 #undef DEFINE_SET_ATTR 119 120 AttrBuilder& AttrBuilder::NumInputs(int n) { 121 DCHECK(!node_def_finalized_) << "Calling NumInputs after BuildNodeDef."; 122 num_inputs_ = n; 123 return *this; 124 } 125 126 void AttrBuilder::FillAttrValueMap(AttrValueMap* m, 127 bool include_those_in_node_def) const { 128 for (const auto& p : string_attrs_) { 129 SetInAttrValueMap(m, p.first, p.second); 130 } 131 for (const auto& p : int_attrs_) { 132 SetInAttrValueMap(m, p.first, p.second); 133 } 134 for (const auto& p : float_attrs_) { 135 SetInAttrValueMap(m, p.first, p.second); 136 } 137 for (const auto& p : bool_attrs_) { 138 SetInAttrValueMap(m, p.first, p.second); 139 } 140 for (const auto& p : type_attrs_) { 141 SetInAttrValueMap(m, p.first, p.second); 142 } 143 if (include_those_in_node_def && node_def_ != nullptr) { 144 for (AttrValueMap::const_iterator it = node_def_->attr().begin(); 145 it != node_def_->attr().end(); ++it) { 146 m->insert(*it); 147 } 148 } 149 } 150 151 const NodeDef& AttrBuilder::BuildNodeDef() { 152 if (node_def_finalized_) return *node_def_; 153 MayBeInitializeNodeDef(); 154 for (int i = 0; i < num_inputs_; ++i) { 155 node_def_->add_input("dummy_input"); 156 } 157 FillAttrValueMap(node_def_->mutable_attr(), false); 158 node_def_finalized_ = true; 159 return *node_def_; 160 } 161 162 namespace { 163 inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a, 164 const tensorflow::Fprint128& b) { 165 return {tensorflow::FingerprintCat64(a.low64, b.low64), 166 tensorflow::FingerprintCat64(a.low64, b.low64)}; 167 } 168 169 void CombineUnordered(const tensorflow::Fprint128& a, 170 tensorflow::Fprint128* b) { 171 b->low64 += a.low64; 172 b->high64 += a.high64; 173 } 174 175 inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s, 176 const tensorflow::Fprint128& b) { 177 // TODO(agarwal): avoid ToString(). 178 tensorflow::Fprint128 a = tensorflow::Fingerprint128(s.ToString()); 179 return FingerprintCat128(a, b); 180 } 181 182 inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s, uint64 b) { 183 return CacheKeyHelper(s, {b, b}); 184 } 185 186 } // namespace 187 188 tensorflow::Fprint128 AttrBuilder::CacheKey(const string& device) const { 189 tensorflow::Fprint128 f = tensorflow::Fingerprint128(op_name_); 190 f = tensorflow::FingerprintCat128(f, tensorflow::Fingerprint128(device)); 191 if (node_def_ != nullptr) { 192 // Some attributes are directly written to node_def_ instead of being 193 // stored explicitly. 194 string value; 195 for (const auto& attr : node_def_->attr()) { 196 attr.second.SerializeToString(&value); 197 CombineUnordered( 198 CacheKeyHelper(attr.first, tensorflow::Fingerprint128(value)), &f); 199 } 200 // Note that node_def_ may be created but not finalized. This can happen 201 // when the creation was triggered by a call to Set, but BuildNodeDef has 202 // not been called. 203 if (node_def_finalized_) return f; 204 } 205 for (const auto& p : string_attrs_) { 206 // TODO(agarwal): avoid ToString(). 207 CombineUnordered(CacheKeyHelper(p.first, tensorflow::Fingerprint128( 208 p.second.ToString())), 209 &f); 210 } 211 for (const auto& p : int_attrs_) { 212 CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)), 213 &f); 214 } 215 static std::hash<float> float_hasher; 216 for (const auto& p : float_attrs_) { 217 CombineUnordered( 218 CacheKeyHelper(p.first, static_cast<uint64>(float_hasher(p.second))), 219 &f); 220 } 221 for (const auto& p : bool_attrs_) { 222 CombineUnordered(CacheKeyHelper(p.first, p.second ? 1u : 0u), &f); 223 } 224 for (const auto& p : type_attrs_) { 225 CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)), 226 &f); 227 } 228 return f; 229 } 230 231 void AttrBuilder::MayBeInitializeNodeDef() { 232 if (node_def_ == nullptr) { 233 node_def_.reset(new NodeDef()); 234 node_def_->set_name(op_name_); 235 node_def_->set_op(op_name_); 236 } 237 } 238 239 // static 240 Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef, 241 KernelAndDevice* out) { 242 OpKernel* k = nullptr; 243 Status s = CreateOpKernel(device->device_type().c_str(), device, 244 device->GetAllocator(AllocatorAttributes()), 245 nullptr, ndef, TF_GRAPH_DEF_VERSION, &k); 246 out->device_ = device; 247 out->kernel_.reset(k); 248 out->flib_ = nullptr; 249 return s; 250 } 251 252 // static 253 Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib, 254 KernelAndDevice* out) { 255 OpKernel* k = nullptr; 256 Status s = flib->CreateKernel(ndef, &k); 257 out->device_ = flib->device(); 258 out->kernel_.reset(k); 259 out->flib_ = flib; 260 return s; 261 } 262 263 Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors, 264 std::vector<Tensor>* output_tensors, 265 NodeExecStats* stats) { 266 gtl::InlinedVector<TensorValue, 4> inputs; 267 for (Tensor& t : *input_tensors) { 268 inputs.push_back(TensorValue(&t)); 269 } 270 271 std::vector<AllocatorAttributes> out_attrs(kernel_->num_outputs()); 272 for (size_t i = 0; i < out_attrs.size(); ++i) { 273 out_attrs[i].set_on_host(kernel_->output_memory_types()[i] == 274 tensorflow::HOST_MEMORY); 275 } 276 277 OpKernelContext::Params params; 278 params.device = device_; 279 params.frame_iter = FrameAndIter(0, 0); 280 params.inputs = &inputs; 281 params.op_kernel = kernel_.get(); 282 params.resource_manager = device_->resource_manager(); 283 params.output_attr_array = gtl::vector_as_array(&out_attrs); 284 params.function_library = flib_; 285 params.slice_reader_cache = &slice_reader_cache_; 286 params.rendezvous = rendez_; 287 if (stats != nullptr) { 288 params.track_allocations = true; 289 } 290 // TODO(apassos): use a thread pool. 291 std::function<void(std::function<void()>)> runner = 292 [](std::function<void()> f) { f(); }; 293 params.runner = &runner; 294 295 OpKernelContext context(¶ms); 296 device_->Compute(kernel_.get(), &context); 297 if (!context.status().ok()) return context.status(); 298 299 output_tensors->clear(); 300 for (int i = 0; i < context.num_outputs(); ++i) { 301 output_tensors->push_back(Tensor(*context.mutable_output(i))); 302 } 303 if (stats != nullptr) { 304 for (const auto& allocator_pair : context.wrapped_allocators()) { 305 AllocatorMemoryUsed* memory = stats->add_memory(); 306 memory->set_allocator_name(allocator_pair.first->Name()); 307 auto sizes = allocator_pair.second->GetSizes(); 308 memory->set_total_bytes(std::get<0>(sizes)); 309 memory->set_peak_bytes(std::get<1>(sizes)); 310 memory->set_live_bytes(std::get<2>(sizes)); 311 312 AllocatorStats allocator_stats; 313 allocator_pair.first->GetStats(&allocator_stats); 314 memory->set_allocator_bytes_in_use(allocator_stats.bytes_in_use); 315 allocator_pair.second->GetRecordsAndUnRef(); 316 } 317 auto* ms = stats->mutable_memory_stats(); 318 ms->set_temp_memory_size(context.temp_memory_allocated()); 319 for (const auto& alloc_id : context.persistent_alloc_ids()) { 320 ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id); 321 } 322 323 ms->set_persistent_memory_size(context.persistent_memory_allocated()); 324 } 325 return Status::OK(); 326 } 327 328 } // namespace tensorflow 329