Home | History | Annotate | Download | only in eager
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_C_EAGER_RUNTIME_H_
     17 #define TENSORFLOW_C_EAGER_RUNTIME_H_
     18 
     19 // Support for eager execution of TensorFlow kernels.
     20 
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/c/c_api.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
     33 
     34 namespace tensorflow {
     35 
     36 // Maps attribute name to an encoding of the type of the attribute value.
     37 // If the type is not a list type, the value is the same as the TF_AttrType type
     38 // of the value. Else, the highest order bit is on, and the rest of the bits
     39 // represent the TF_AttrType type of the values in the list.
     40 typedef std::unordered_map<string, uint32> AttrTypeMap;
     41 
     42 // Returns the AttrTypeMap for the TensorFlow operation named op_name.
     43 Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out);
     44 
     45 // Looks for 'attr_name' in 'm' and sets 'out' and 'is_list'.
     46 Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name,
     47                       TF_AttrType* out, unsigned char* is_list);
     48 
     49 // KernelAndDevice::Init needs a NodeDef only to pass the attribute map through.
     50 // An AttrBuilder is a convenience class to help with that - providing a smaller
     51 // interface than NodeDefBuilder and avoiding expensive (unnecessary?) sanity
     52 // checks (like number of inputs matching the OpDef - we only care about
     53 // attributes here).
     54 //
     55 // TODO(ashankar): Take a closer look at checks in NodeDefBuilder and see which
     56 // ones make sense to replicate.
     57 
     58 // This is a helper class for creating a NodeDef. Additionally, this class
     59 // allows computing a cache key based on fingerprinting the attributes of this
     60 // NodeDef.
     61 //
     62 // Example usage:
     63 // AttrBuilder a;
     64 // a.NumInputs(2);
     65 // a.Set("T", TF_FLOAT);
     66 // uint64 cache_key = a.CacheKey("cpu:0");
     67 // const NodeDef& n = a.BuildNodeDef();
     68 //
     69 // Note that all calls to Set and NumInputs should happen before calling
     70 // BuildNodeDef. Also, calls to NumInputs or Set between multiple invocations
     71 // to CacheKey may cause different values to be returned by CacheKey.
     72 //
     73 // For performance reasons, the class internally delays the actual construction
     74 // of the NodeDef till BuildNodeDef is called, or Set is called with certain
     75 // uncommon types (see template specializations of Set to see which types
     76 // trigger a NodeDef creation).
     77 class AttrBuilder {
     78  public:
     79   explicit AttrBuilder(const char* op)
     80       : op_name_(op),
     81         num_inputs_(0),
     82         node_def_(nullptr),
     83         node_def_finalized_(false) {}
     84 
     85   // Needed to work around call to ValidateNodeDef in CreateOpKernel.
     86   AttrBuilder& NumInputs(int n);
     87 
     88   template <class T>
     89   AttrBuilder& Set(StringPiece attr_name, T&& value) {
     90     MayBeInitializeNodeDef();
     91     SetInAttrValueMap(node_def_->mutable_attr(), attr_name, value);
     92     return *this;
     93   }
     94 
     95   tensorflow::Fprint128 CacheKey(const string& device) const;
     96 
     97   void FillAttrValueMap(AttrValueMap* m) const { FillAttrValueMap(m, true); }
     98   const NodeDef& BuildNodeDef();
     99 
    100  private:
    101   template <class T>
    102   using AttrVec = tensorflow::gtl::InlinedVector<std::pair<StringPiece, T>, 2>;
    103 
    104   void MayBeInitializeNodeDef();
    105   void FillAttrValueMap(AttrValueMap* m, bool include_those_in_node_def) const;
    106 
    107   template <class T>
    108   void SetInAttrValueMap(AttrValueMap* m, StringPiece attr_name,
    109                          T&& value) const {
    110     DCHECK(!node_def_finalized_)
    111         << "Calling SetInAttrValueMap after BuildNodeDef.";
    112     // Copied from NodeDefBuilder::Attr
    113     const AttrValue* found = AttrSlice(m).Find(attr_name);
    114     AttrValue attr_value;
    115     if (found == nullptr) {
    116       SetAttrValue(value, &attr_value);
    117       m->insert(AttrValueMap::value_type(attr_name.ToString(), attr_value));
    118     } else {
    119       // TODO(ashankar): Do what is done in
    120       // NodeDefBuilder::CheckInconsistency(attr_name, *found, attr_value);
    121       SetAttrValue(std::forward<T>(value), &attr_value);
    122       (*m)[attr_name.ToString()] = attr_value;
    123     }
    124   }
    125 
    126   AttrVec<StringPiece> string_attrs_;
    127   AttrVec<int> int_attrs_;
    128   AttrVec<float> float_attrs_;
    129   AttrVec<bool> bool_attrs_;
    130   AttrVec<tensorflow::DataType> type_attrs_;
    131   const string op_name_;
    132   int num_inputs_;
    133   std::unique_ptr<NodeDef> node_def_;
    134   bool node_def_finalized_;
    135 };  // namespace tensorflow
    136 
    137 template <>
    138 AttrBuilder& AttrBuilder::Set(StringPiece attr_name, StringPiece&& value);
    139 template <>
    140 AttrBuilder& AttrBuilder::Set(StringPiece attr_name, int&& value);
    141 template <>
    142 AttrBuilder& AttrBuilder::Set(StringPiece attr_name, float&& value);
    143 template <>
    144 AttrBuilder& AttrBuilder::Set(StringPiece attr_name, bool&& value);
    145 template <>
    146 AttrBuilder& AttrBuilder::Set(StringPiece attr_name,
    147                               tensorflow::DataType&& value);
    148 
    149 // KernelAndDevice encapsulates an instantiated kernel and the device it is on.
    150 //
    151 // Also see:
    152 // https://www.tensorflow.org/code/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
    153 // and
    154 // https://www.tensorflow.org/code/tensorflow/core/kernels/ops_testutil.h
    155 class KernelAndDevice {
    156  public:
    157   // Populates 'out' with a kernel appropriate for 'ndef'.
    158   //
    159   // The provided FunctionLibraryRuntime MUST outlive all calls to
    160   // Run() on the returned KernelAndDevice.
    161   //
    162   // TODO(ashankar): Figure out thread-safety concerns around
    163   // FunctionLibraryRuntime (in particular, how the underlying
    164   // FunctionLibraryDefinition might be mutated by another thread as new
    165   // functions are registered with it).  Conservatively, thread-safe usage of
    166   // the FunctionLibraryRuntime is pushed on to the caller (see locking in
    167   // c_api.cc).
    168   static Status Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
    169                      KernelAndDevice* out);
    170   // TODO(ashankar): Remove this
    171   static Status InitOp(Device* device, const NodeDef& ndef,
    172                        KernelAndDevice* out);
    173 
    174   KernelAndDevice(tensorflow::Rendezvous* rendez)
    175       : device_(nullptr), flib_(nullptr), rendez_(rendez) {}
    176 
    177   // TODO(ashankar): Handle list-valued inputs.
    178   Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs,
    179              NodeExecStats* stats);
    180 
    181   const OpKernel* kernel() const { return kernel_.get(); }
    182 
    183  private:
    184   std::unique_ptr<OpKernel> kernel_;
    185   Device* device_;
    186   FunctionLibraryRuntime* flib_;
    187   checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
    188   Rendezvous* rendez_;
    189 };
    190 
    191 }  // namespace tensorflow
    192 
    193 #endif  // TENSORFLOW_C_EAGER_RUNTIME_H_
    194