1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_C_EAGER_RUNTIME_H_ 17 #define TENSORFLOW_C_EAGER_RUNTIME_H_ 18 19 // Support for eager execution of TensorFlow kernels. 20 21 #include <memory> 22 #include <unordered_map> 23 24 #include "tensorflow/c/c_api.h" 25 #include "tensorflow/core/common_runtime/device.h" 26 #include "tensorflow/core/framework/node_def.pb.h" 27 #include "tensorflow/core/framework/op_kernel.h" 28 #include "tensorflow/core/framework/types.h" 29 #include "tensorflow/core/lib/core/status.h" 30 #include "tensorflow/core/lib/gtl/inlined_vector.h" 31 #include "tensorflow/core/platform/fingerprint.h" 32 #include "tensorflow/core/util/tensor_slice_reader_cache.h" 33 34 namespace tensorflow { 35 36 // Maps attribute name to an encoding of the type of the attribute value. 37 // If the type is not a list type, the value is the same as the TF_AttrType type 38 // of the value. Else, the highest order bit is on, and the rest of the bits 39 // represent the TF_AttrType type of the values in the list. 40 typedef std::unordered_map<string, uint32> AttrTypeMap; 41 42 // Returns the AttrTypeMap for the TensorFlow operation named op_name. 43 Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out); 44 45 // Looks for 'attr_name' in 'm' and sets 'out' and 'is_list'. 46 Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, 47 TF_AttrType* out, unsigned char* is_list); 48 49 // KernelAndDevice::Init needs a NodeDef only to pass the attribute map through. 50 // An AttrBuilder is a convenience class to help with that - providing a smaller 51 // interface than NodeDefBuilder and avoiding expensive (unnecessary?) sanity 52 // checks (like number of inputs matching the OpDef - we only care about 53 // attributes here). 54 // 55 // TODO(ashankar): Take a closer look at checks in NodeDefBuilder and see which 56 // ones make sense to replicate. 57 58 // This is a helper class for creating a NodeDef. Additionally, this class 59 // allows computing a cache key based on fingerprinting the attributes of this 60 // NodeDef. 61 // 62 // Example usage: 63 // AttrBuilder a; 64 // a.NumInputs(2); 65 // a.Set("T", TF_FLOAT); 66 // uint64 cache_key = a.CacheKey("cpu:0"); 67 // const NodeDef& n = a.BuildNodeDef(); 68 // 69 // Note that all calls to Set and NumInputs should happen before calling 70 // BuildNodeDef. Also, calls to NumInputs or Set between multiple invocations 71 // to CacheKey may cause different values to be returned by CacheKey. 72 // 73 // For performance reasons, the class internally delays the actual construction 74 // of the NodeDef till BuildNodeDef is called, or Set is called with certain 75 // uncommon types (see template specializations of Set to see which types 76 // trigger a NodeDef creation). 77 class AttrBuilder { 78 public: 79 explicit AttrBuilder(const char* op) 80 : op_name_(op), 81 num_inputs_(0), 82 node_def_(nullptr), 83 node_def_finalized_(false) {} 84 85 // Needed to work around call to ValidateNodeDef in CreateOpKernel. 86 AttrBuilder& NumInputs(int n); 87 88 template <class T> 89 AttrBuilder& Set(StringPiece attr_name, T&& value) { 90 MayBeInitializeNodeDef(); 91 SetInAttrValueMap(node_def_->mutable_attr(), attr_name, value); 92 return *this; 93 } 94 95 tensorflow::Fprint128 CacheKey(const string& device) const; 96 97 void FillAttrValueMap(AttrValueMap* m) const { FillAttrValueMap(m, true); } 98 const NodeDef& BuildNodeDef(); 99 100 private: 101 template <class T> 102 using AttrVec = tensorflow::gtl::InlinedVector<std::pair<StringPiece, T>, 2>; 103 104 void MayBeInitializeNodeDef(); 105 void FillAttrValueMap(AttrValueMap* m, bool include_those_in_node_def) const; 106 107 template <class T> 108 void SetInAttrValueMap(AttrValueMap* m, StringPiece attr_name, 109 T&& value) const { 110 DCHECK(!node_def_finalized_) 111 << "Calling SetInAttrValueMap after BuildNodeDef."; 112 // Copied from NodeDefBuilder::Attr 113 const AttrValue* found = AttrSlice(m).Find(attr_name); 114 AttrValue attr_value; 115 if (found == nullptr) { 116 SetAttrValue(value, &attr_value); 117 m->insert(AttrValueMap::value_type(attr_name.ToString(), attr_value)); 118 } else { 119 // TODO(ashankar): Do what is done in 120 // NodeDefBuilder::CheckInconsistency(attr_name, *found, attr_value); 121 SetAttrValue(std::forward<T>(value), &attr_value); 122 (*m)[attr_name.ToString()] = attr_value; 123 } 124 } 125 126 AttrVec<StringPiece> string_attrs_; 127 AttrVec<int> int_attrs_; 128 AttrVec<float> float_attrs_; 129 AttrVec<bool> bool_attrs_; 130 AttrVec<tensorflow::DataType> type_attrs_; 131 const string op_name_; 132 int num_inputs_; 133 std::unique_ptr<NodeDef> node_def_; 134 bool node_def_finalized_; 135 }; // namespace tensorflow 136 137 template <> 138 AttrBuilder& AttrBuilder::Set(StringPiece attr_name, StringPiece&& value); 139 template <> 140 AttrBuilder& AttrBuilder::Set(StringPiece attr_name, int&& value); 141 template <> 142 AttrBuilder& AttrBuilder::Set(StringPiece attr_name, float&& value); 143 template <> 144 AttrBuilder& AttrBuilder::Set(StringPiece attr_name, bool&& value); 145 template <> 146 AttrBuilder& AttrBuilder::Set(StringPiece attr_name, 147 tensorflow::DataType&& value); 148 149 // KernelAndDevice encapsulates an instantiated kernel and the device it is on. 150 // 151 // Also see: 152 // https://www.tensorflow.org/code/tensorflow/core/common_runtime/kernel_benchmark_testlib.h 153 // and 154 // https://www.tensorflow.org/code/tensorflow/core/kernels/ops_testutil.h 155 class KernelAndDevice { 156 public: 157 // Populates 'out' with a kernel appropriate for 'ndef'. 158 // 159 // The provided FunctionLibraryRuntime MUST outlive all calls to 160 // Run() on the returned KernelAndDevice. 161 // 162 // TODO(ashankar): Figure out thread-safety concerns around 163 // FunctionLibraryRuntime (in particular, how the underlying 164 // FunctionLibraryDefinition might be mutated by another thread as new 165 // functions are registered with it). Conservatively, thread-safe usage of 166 // the FunctionLibraryRuntime is pushed on to the caller (see locking in 167 // c_api.cc). 168 static Status Init(const NodeDef& ndef, FunctionLibraryRuntime* flib, 169 KernelAndDevice* out); 170 // TODO(ashankar): Remove this 171 static Status InitOp(Device* device, const NodeDef& ndef, 172 KernelAndDevice* out); 173 174 KernelAndDevice(tensorflow::Rendezvous* rendez) 175 : device_(nullptr), flib_(nullptr), rendez_(rendez) {} 176 177 // TODO(ashankar): Handle list-valued inputs. 178 Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs, 179 NodeExecStats* stats); 180 181 const OpKernel* kernel() const { return kernel_.get(); } 182 183 private: 184 std::unique_ptr<OpKernel> kernel_; 185 Device* device_; 186 FunctionLibraryRuntime* flib_; 187 checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_; 188 Rendezvous* rendez_; 189 }; 190 191 } // namespace tensorflow 192 193 #endif // TENSORFLOW_C_EAGER_RUNTIME_H_ 194