Home | History | Annotate | Download | only in eager
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/c/eager/runtime.h"
     17 
     18 #include "tensorflow/core/common_runtime/device_factory.h"
     19 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
     20 #include "tensorflow/core/framework/allocator.h"
     21 #include "tensorflow/core/framework/node_def.pb.h"
     22 #include "tensorflow/core/lib/core/errors.h"
     23 #include "tensorflow/core/lib/gtl/map_util.h"
     24 #include "tensorflow/core/lib/gtl/stl_util.h"
     25 #include "tensorflow/core/platform/fingerprint.h"
     26 #include "tensorflow/core/platform/mutex.h"
     27 #include "tensorflow/core/public/version.h"
     28 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
     29 
     30 namespace tensorflow {
     31 namespace {
     32 
     33 mutex g_op_name_to_attr_type_map_lock(LINKER_INITIALIZED);
     34 
     35 std::unordered_map<string, const AttrTypeMap*>* OpNameToAttrTypeMap() {
     36   static auto* const m = new std::unordered_map<string, const AttrTypeMap*>;
     37   return m;
     38 }
     39 
     40 const uint32 kIsList = 1U << 31;
     41 
     42 }  // namespace
     43 
     44 Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) {
     45   mutex_lock l(g_op_name_to_attr_type_map_lock);
     46   *out = gtl::FindPtrOrNull(*OpNameToAttrTypeMap(), op_name);
     47   if (*out != nullptr) return Status::OK();
     48   const OpRegistrationData* op_reg_data = nullptr;
     49   Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data);
     50   if (!s.ok()) return s;
     51   std::unique_ptr<AttrTypeMap> m(new AttrTypeMap);
     52   // TODO(agarwal): Avoid having to create this "registry" at runtime,
     53   // perhaps can be done at op registration time?
     54   for (const auto& attr : op_reg_data->op_def.attr()) {
     55     string type = attr.type();
     56     const bool is_list = (type.length() > 6 && type.compare(0, 4, "list") == 0);
     57     if (is_list) {
     58       type = type.substr(5, type.length() - 6);
     59     }
     60     uint32 t = is_list ? kIsList : 0;
     61     if (type == "string") {
     62       t |= TF_ATTR_STRING;
     63     } else if (type == "int") {
     64       t |= TF_ATTR_INT;
     65     } else if (type == "float") {
     66       t |= TF_ATTR_FLOAT;
     67     } else if (type == "bool") {
     68       t |= TF_ATTR_BOOL;
     69     } else if (type == "type") {
     70       t |= TF_ATTR_TYPE;
     71     } else if (type == "shape") {
     72       t |= TF_ATTR_SHAPE;
     73     } else if (type == "tensor") {
     74       t |= TF_ATTR_TENSOR;
     75     } else if (type == "func") {
     76       t |= TF_ATTR_FUNC;
     77     } else {
     78       return errors::Unimplemented(
     79           "TODO(agarwal): Enable support for ops with attributes of type '",
     80           type, "'");
     81     }
     82     gtl::InsertIfNotPresent(m.get(), attr.name(), t);
     83   }
     84   *out = m.get();
     85   (*OpNameToAttrTypeMap())[op_name] = m.release();
     86   return Status::OK();
     87 }
     88 
     89 Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name,
     90                       TF_AttrType* out, unsigned char* is_list) {
     91   auto* t = gtl::FindOrNull(m, attr_name);
     92   if (t == nullptr) {
     93     return errors::InvalidArgument("Attribute '", attr_name,
     94                                    "' does not exist for this operation");
     95   }
     96   *out = static_cast<TF_AttrType>(*t & ~kIsList);
     97   if (*t & kIsList) {
     98     *is_list = 1;
     99   } else {
    100     *is_list = 0;
    101   }
    102   return Status::OK();
    103 }
    104 
    105 #define DEFINE_SET_ATTR(value_type, value_field)                             \
    106   template <>                                                                \
    107   AttrBuilder& AttrBuilder::Set(StringPiece attr_name, value_type&& value) { \
    108     value_field.push_back(std::make_pair(attr_name, value));                 \
    109     return *this;                                                            \
    110   }
    111 
    112 DEFINE_SET_ATTR(StringPiece, string_attrs_);
    113 DEFINE_SET_ATTR(float, float_attrs_);
    114 DEFINE_SET_ATTR(int, int_attrs_);
    115 DEFINE_SET_ATTR(bool, bool_attrs_);
    116 DEFINE_SET_ATTR(tensorflow::DataType, type_attrs_);
    117 
    118 #undef DEFINE_SET_ATTR
    119 
    120 AttrBuilder& AttrBuilder::NumInputs(int n) {
    121   DCHECK(!node_def_finalized_) << "Calling NumInputs after BuildNodeDef.";
    122   num_inputs_ = n;
    123   return *this;
    124 }
    125 
    126 void AttrBuilder::FillAttrValueMap(AttrValueMap* m,
    127                                    bool include_those_in_node_def) const {
    128   for (const auto& p : string_attrs_) {
    129     SetInAttrValueMap(m, p.first, p.second);
    130   }
    131   for (const auto& p : int_attrs_) {
    132     SetInAttrValueMap(m, p.first, p.second);
    133   }
    134   for (const auto& p : float_attrs_) {
    135     SetInAttrValueMap(m, p.first, p.second);
    136   }
    137   for (const auto& p : bool_attrs_) {
    138     SetInAttrValueMap(m, p.first, p.second);
    139   }
    140   for (const auto& p : type_attrs_) {
    141     SetInAttrValueMap(m, p.first, p.second);
    142   }
    143   if (include_those_in_node_def && node_def_ != nullptr) {
    144     for (AttrValueMap::const_iterator it = node_def_->attr().begin();
    145          it != node_def_->attr().end(); ++it) {
    146       m->insert(*it);
    147     }
    148   }
    149 }
    150 
    151 const NodeDef& AttrBuilder::BuildNodeDef() {
    152   if (node_def_finalized_) return *node_def_;
    153   MayBeInitializeNodeDef();
    154   for (int i = 0; i < num_inputs_; ++i) {
    155     node_def_->add_input("dummy_input");
    156   }
    157   FillAttrValueMap(node_def_->mutable_attr(), false);
    158   node_def_finalized_ = true;
    159   return *node_def_;
    160 }
    161 
    162 namespace {
    163 inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
    164                                                const tensorflow::Fprint128& b) {
    165   return {tensorflow::FingerprintCat64(a.low64, b.low64),
    166           tensorflow::FingerprintCat64(a.low64, b.low64)};
    167 }
    168 
    169 void CombineUnordered(const tensorflow::Fprint128& a,
    170                       tensorflow::Fprint128* b) {
    171   b->low64 += a.low64;
    172   b->high64 += a.high64;
    173 }
    174 
    175 inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s,
    176                                             const tensorflow::Fprint128& b) {
    177   // TODO(agarwal): avoid ToString().
    178   tensorflow::Fprint128 a = tensorflow::Fingerprint128(s.ToString());
    179   return FingerprintCat128(a, b);
    180 }
    181 
    182 inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s, uint64 b) {
    183   return CacheKeyHelper(s, {b, b});
    184 }
    185 
    186 }  // namespace
    187 
    188 tensorflow::Fprint128 AttrBuilder::CacheKey(const string& device) const {
    189   tensorflow::Fprint128 f = tensorflow::Fingerprint128(op_name_);
    190   f = tensorflow::FingerprintCat128(f, tensorflow::Fingerprint128(device));
    191   if (node_def_ != nullptr) {
    192     // Some attributes are directly written to node_def_ instead of being
    193     // stored explicitly.
    194     string value;
    195     for (const auto& attr : node_def_->attr()) {
    196       attr.second.SerializeToString(&value);
    197       CombineUnordered(
    198           CacheKeyHelper(attr.first, tensorflow::Fingerprint128(value)), &f);
    199     }
    200     // Note that node_def_ may be created but not finalized. This can happen
    201     // when the creation was triggered by a call to Set, but BuildNodeDef has
    202     // not been called.
    203     if (node_def_finalized_) return f;
    204   }
    205   for (const auto& p : string_attrs_) {
    206     // TODO(agarwal): avoid ToString().
    207     CombineUnordered(CacheKeyHelper(p.first, tensorflow::Fingerprint128(
    208                                                  p.second.ToString())),
    209                      &f);
    210   }
    211   for (const auto& p : int_attrs_) {
    212     CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)),
    213                      &f);
    214   }
    215   static std::hash<float> float_hasher;
    216   for (const auto& p : float_attrs_) {
    217     CombineUnordered(
    218         CacheKeyHelper(p.first, static_cast<uint64>(float_hasher(p.second))),
    219         &f);
    220   }
    221   for (const auto& p : bool_attrs_) {
    222     CombineUnordered(CacheKeyHelper(p.first, p.second ? 1u : 0u), &f);
    223   }
    224   for (const auto& p : type_attrs_) {
    225     CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)),
    226                      &f);
    227   }
    228   return f;
    229 }
    230 
    231 void AttrBuilder::MayBeInitializeNodeDef() {
    232   if (node_def_ == nullptr) {
    233     node_def_.reset(new NodeDef());
    234     node_def_->set_name(op_name_);
    235     node_def_->set_op(op_name_);
    236   }
    237 }
    238 
    239 // static
    240 Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef,
    241                                KernelAndDevice* out) {
    242   OpKernel* k = nullptr;
    243   Status s = CreateOpKernel(device->device_type().c_str(), device,
    244                             device->GetAllocator(AllocatorAttributes()),
    245                             nullptr, ndef, TF_GRAPH_DEF_VERSION, &k);
    246   out->device_ = device;
    247   out->kernel_.reset(k);
    248   out->flib_ = nullptr;
    249   return s;
    250 }
    251 
    252 // static
    253 Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
    254                              KernelAndDevice* out) {
    255   OpKernel* k = nullptr;
    256   Status s = flib->CreateKernel(ndef, &k);
    257   out->device_ = flib->device();
    258   out->kernel_.reset(k);
    259   out->flib_ = flib;
    260   return s;
    261 }
    262 
    263 Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors,
    264                             std::vector<Tensor>* output_tensors,
    265                             NodeExecStats* stats) {
    266   gtl::InlinedVector<TensorValue, 4> inputs;
    267   for (Tensor& t : *input_tensors) {
    268     inputs.push_back(TensorValue(&t));
    269   }
    270 
    271   std::vector<AllocatorAttributes> out_attrs(kernel_->num_outputs());
    272   for (size_t i = 0; i < out_attrs.size(); ++i) {
    273     out_attrs[i].set_on_host(kernel_->output_memory_types()[i] ==
    274                              tensorflow::HOST_MEMORY);
    275   }
    276 
    277   OpKernelContext::Params params;
    278   params.device = device_;
    279   params.frame_iter = FrameAndIter(0, 0);
    280   params.inputs = &inputs;
    281   params.op_kernel = kernel_.get();
    282   params.resource_manager = device_->resource_manager();
    283   params.output_attr_array = gtl::vector_as_array(&out_attrs);
    284   params.function_library = flib_;
    285   params.slice_reader_cache = &slice_reader_cache_;
    286   params.rendezvous = rendez_;
    287   if (stats != nullptr) {
    288     params.track_allocations = true;
    289   }
    290   // TODO(apassos): use a thread pool.
    291   std::function<void(std::function<void()>)> runner =
    292       [](std::function<void()> f) { f(); };
    293   params.runner = &runner;
    294 
    295   OpKernelContext context(&params);
    296   device_->Compute(kernel_.get(), &context);
    297   if (!context.status().ok()) return context.status();
    298 
    299   output_tensors->clear();
    300   for (int i = 0; i < context.num_outputs(); ++i) {
    301     output_tensors->push_back(Tensor(*context.mutable_output(i)));
    302   }
    303   if (stats != nullptr) {
    304     for (const auto& allocator_pair : context.wrapped_allocators()) {
    305       AllocatorMemoryUsed* memory = stats->add_memory();
    306       memory->set_allocator_name(allocator_pair.first->Name());
    307       auto sizes = allocator_pair.second->GetSizes();
    308       memory->set_total_bytes(std::get<0>(sizes));
    309       memory->set_peak_bytes(std::get<1>(sizes));
    310       memory->set_live_bytes(std::get<2>(sizes));
    311 
    312       AllocatorStats allocator_stats;
    313       allocator_pair.first->GetStats(&allocator_stats);
    314       memory->set_allocator_bytes_in_use(allocator_stats.bytes_in_use);
    315       allocator_pair.second->GetRecordsAndUnRef();
    316     }
    317     auto* ms = stats->mutable_memory_stats();
    318     ms->set_temp_memory_size(context.temp_memory_allocated());
    319     for (const auto& alloc_id : context.persistent_alloc_ids()) {
    320       ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id);
    321     }
    322 
    323     ms->set_persistent_memory_size(context.persistent_memory_allocated());
    324   }
    325   return Status::OK();
    326 }
    327 
    328 }  // namespace tensorflow
    329