Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_FRAMEWORK_FUNCTION_H_
     17 #define TENSORFLOW_FRAMEWORK_FUNCTION_H_
     18 
     19 #include <vector>
     20 #include "tensorflow/core/framework/attr_value.pb.h"
     21 #include "tensorflow/core/framework/attr_value_util.h"
     22 #include "tensorflow/core/framework/function.pb.h"
     23 #include "tensorflow/core/framework/node_def_util.h"
     24 #include "tensorflow/core/framework/op.h"
     25 #include "tensorflow/core/framework/selective_registration.h"
     26 #include "tensorflow/core/framework/types.h"
     27 #include "tensorflow/core/lib/gtl/flatmap.h"
     28 #include "tensorflow/core/lib/hash/hash.h"
     29 #include "tensorflow/core/platform/env.h"
     30 #include "tensorflow/core/platform/macros.h"
     31 #include "tensorflow/core/platform/protobuf.h"
     32 
     33 namespace tensorflow {
     34 
     35 class CancellationManager;
     36 class GraphDef;
     37 class OpKernel;
     38 class ProcessFunctionLibraryRuntime;
     39 class ResourceMgr;
     40 class Rendezvous;
     41 class ScopedStepContainer;
     42 class StepStatsCollector;
     43 class Node;
     44 
     45 // FunctionDefHelper::Create is a convenient helper to construct a
     46 // FunctionDef proto.
     47 // E.g.,
     48 //   FunctionDef my_func = FunctionDefHelper::Create(
     49 //     "my_func_name",
     50 //     {"x:T", "y:T" /* one string per argument */},
     51 //     {"z:T" /* one string per return value */},
     52 //     {"T: {float, double}" /* one string per attribute  */},
     53 //     {
     54 //        {{"o"}, "Mul", {"x", "y"}, {{"T", "$T"}}}
     55 //        /* one entry per function node */
     56 //     },
     57 //     /* Mapping between function returns and function node outputs. */
     58 //     {{"z", "o:z"}});
     59 //
     60 // For the old Function::Node approach, use FunctionDefHelper::Define()
     61 // E.g.,
     62 //   FunctionDef my_func = FunctionDefHelper::Define(
     63 //     "my_func_name",
     64 //     {"x:T", "y:T" /* one string per argument */},
     65 //     {"z:T" /* one string per return value */},
     66 //     {"T: {float, double}" /* one string per attribute  */},
     67 //     {
     68 //        {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}}
     69 //        /* one entry per function node */
     70 //     });
     71 class FunctionDefHelper {
     72  public:
     73   // AttrValueWrapper has copy constructors for the type T so that
     74   // it's easy to construct a simple AttrValue proto.
     75   //
     76   // If T is a string type (const char*, string, or StringPiece), and
     77   // it starts with "$", we construct a AttrValue of "placeholder".
     78   //
     79   // E.g.,
     80   //   std::<string, AttrValueWrapper> x = {"T", "$T"}
     81   // is a named attr value placeholder.
     82   struct AttrValueWrapper {
     83     AttrValue proto;
     84 
     85     AttrValueWrapper() {}
     86 
     87     template <typename T>
     88     AttrValueWrapper(T val) {  // NOLINT(runtime/explicit)
     89       SetAttrValue(val, &proto);
     90     }
     91 
     92    private:
     93     void InitFromString(StringPiece val);
     94   };
     95 
     96   // Constructs an AttrValue.func given the "name" and "attrs".
     97   static AttrValueWrapper FunctionRef(
     98       const string& name,
     99       gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs);
    100   static AttrValueWrapper FunctionRef(const string& name) {
    101     return FunctionRef(name, {});
    102   }
    103 
    104   // Node is used to construct FunctionDef.Node using initialization
    105   // lists. E.g.,
    106   //  Node n = {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}};  // z = x * y
    107   struct Node {
    108     // When constructing a NodeDef, the first entry in ret is used as
    109     // the node name, the remaining values are ignored.
    110     std::vector<string> ret;
    111     string op;
    112     std::vector<string> arg;
    113     std::vector<std::pair<string, AttrValueWrapper>> attr;
    114     std::vector<string> dep;
    115 
    116     NodeDef ToNodeDef() const;
    117   };
    118 
    119   // The Create() function uses the new NodeDef field.  `ret_def`
    120   // holds a mapping from the function output names from `out_def` to
    121   // the node outputs from `node_def`.
    122   static FunctionDef Create(const string& function_name,
    123                             gtl::ArraySlice<string> in_def,
    124                             gtl::ArraySlice<string> out_def,
    125                             gtl::ArraySlice<string> attr_def,
    126                             gtl::ArraySlice<Node> node_def,
    127                             gtl::ArraySlice<std::pair<string, string>> ret_def);
    128 
    129   // The two Define() functions use the old FunctionDef::Node field.
    130   // TODO(josh11b): Get rid of these and transition to the one above.
    131   static FunctionDef Define(const string& function_name,
    132                             gtl::ArraySlice<string> arg_def,
    133                             gtl::ArraySlice<string> ret_def,
    134                             gtl::ArraySlice<string> attr_def,
    135                             gtl::ArraySlice<Node> node_def);
    136 
    137   // Defines an anonymous function. I.e., its name is not relevant.
    138   static FunctionDef Define(gtl::ArraySlice<string> arg_def,
    139                             gtl::ArraySlice<string> ret_def,
    140                             gtl::ArraySlice<string> attr_def,
    141                             gtl::ArraySlice<Node> node_def);
    142 
    143   // Helpers to construct a constant scalar.
    144   template <typename T>
    145   static Node Const(const string& name, const T& val) {
    146     Node n = {{name}, "Const"};
    147     const DataType dtype = DataTypeToEnum<T>::value;
    148     n.attr.push_back({"dtype", dtype});
    149     Tensor t(dtype, TensorShape({}));
    150     t.scalar<T>()() = val;
    151     n.attr.push_back({"value", t});
    152     return n;
    153   }
    154 
    155   template <typename T>
    156   static Node Const(const string& name, gtl::ArraySlice<T> vals) {
    157     Node n = {{name}, "Const"};
    158     const DataType dtype = DataTypeToEnum<T>::value;
    159     n.attr.push_back({"dtype", dtype});
    160     int64 num = vals.size();
    161     Tensor t(dtype, TensorShape({num}));
    162     for (size_t i = 0; i < vals.size(); ++i) {
    163       t.flat<T>()(i) = vals[i];
    164     }
    165     n.attr.push_back({"value", t});
    166     return n;
    167   }
    168 };
    169 
    170 template <>
    171 inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(const char* val) {
    172   InitFromString(val);
    173 }
    174 
    175 template <>
    176 inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(
    177     const string& val) {
    178   InitFromString(val);
    179 }
    180 
    181 template <>
    182 inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) {
    183   InitFromString(val);
    184 }
    185 
// Instantiate a function.
//
// "fdef" encodes a TF function with some attrs in fdef.signature.attr
// containing placeholders.  InstantiateFunction binds these
// placeholders and produces an instantiated function, whose nodes are
// stored in "result->nodes". The value to substitute a placeholder is
// given by "attr_values", which is a map from a placeholder name to an
// attr value.
//
// InstantiateFunction calls "get_function" to find signatures of other
// functions and primitive ops.
    197 
    198 // GetFunctionSignature(func name, opdef) returns OK if the func name is found
    199 // and opdef is filled with a pointer to the corresponding signature
    200 // (a OpDef proto). Otherwise, returns an error.
    201 typedef std::function<Status(const string&, const OpDef**)>
    202     GetFunctionSignature;
    203 
    204 struct InstantiationResult {
    205   DataTypeVector arg_types;
    206   DataTypeVector ret_types;
    207   std::vector<NodeDef> nodes;
    208 };
    209 Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
    210                            GetFunctionSignature get_function,
    211                            InstantiationResult* result);
    212 
    213 // Returns a debug string for a function definition.
    214 //
    215 // The returned text is multiple-line. It is intended to be
    216 // human-readable rather than being friendly to parsers. It is _NOT_
    217 // intended to be the canonical string representation of "func_def".
    218 // Particularly, it may not include all information presented in
    219 // "func_def" (e.g., comments, description of the function arguments,
    220 // etc.)
    221 string DebugString(const FunctionDef& func_def);
    222 string DebugString(const GraphDef& instantiated_func_def);
    223 string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes);
    224 
    225 // Returns a debug string for a top level graph (the main program and
    226 // its supporting functions defined in its library).
    227 string DebugStringWhole(const GraphDef& gdef);
    228 
    229 // Returns true if f1 == f2. Compares all fields, including descriptions. Order
    230 // of NodeDefs doesn't matter.
    231 bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2);
    232 
    233 // Return a hash of `fdef` that is consistent with FunctionDefsEqual method.
    234 // In other words, if two fdefs compare equal, their hash values will be the
    235 // same.
    236 uint64 FunctionDefHash(const FunctionDef& fdef);
    237 
    238 class CallFrameInterface {
    239  public:
    240   virtual ~CallFrameInterface() {}
    241 
    242   virtual size_t num_args() const = 0;
    243   virtual size_t num_retvals() const = 0;
    244 
    245   virtual Status GetArg(int index, Tensor* val) const = 0;
    246   virtual Status SetRetval(int index, const Tensor& val) = 0;
    247 };
    248 
    249 // Represents a function call frame. I.e., the data structure used to
    250 // pass arguments to a function and retrieve its results.
    251 //
    252 // Runtime must arrange accesses to one FunctionCallFrame s.t.
    253 //   1. SetArgs() happens before any GetArg();
    254 //   2. GetRetvals happens after all SetRetval();
    255 class FunctionCallFrame : public CallFrameInterface {
    256  public:
    257   FunctionCallFrame(DataTypeSlice arg_types, DataTypeSlice ret_types);
    258   ~FunctionCallFrame();
    259 
    260   // Caller methods.
    261   Status SetArgs(gtl::ArraySlice<Tensor> args);
    262   Status GetRetvals(std::vector<Tensor>* rets) const;
    263   Status ConsumeRetvals(std::vector<Tensor>* rets);
    264 
    265   size_t num_args() const override { return arg_types_.size(); }
    266   size_t num_retvals() const override { return ret_types_.size(); }
    267 
    268   // Callee methods.
    269   Status GetArg(int index, Tensor* val) const override;
    270   Status SetRetval(int index, const Tensor& val) override;
    271 
    272  private:
    273   DataTypeVector arg_types_;
    274   DataTypeVector ret_types_;
    275   gtl::InlinedVector<Tensor, 4> args_;
    276   struct Retval {
    277     bool has_val = false;
    278     Tensor val;
    279   };
    280   gtl::InlinedVector<Retval, 4> rets_;
    281 
    282   TF_DISALLOW_COPY_AND_ASSIGN(FunctionCallFrame);
    283 };
    284 
    285 // Helper to maintain a map between function names in a given
    286 // FunctionDefLibrary and function definitions.
    287 class FunctionLibraryDefinition : public OpRegistryInterface {
    288  public:
    289   explicit FunctionLibraryDefinition(const FunctionLibraryDefinition& lib_def);
    290   FunctionLibraryDefinition(const OpRegistryInterface* default_registry,
    291                             const FunctionDefLibrary& lib_def);
    292   ~FunctionLibraryDefinition() override;
    293 
    294   FunctionLibraryDefinition& operator=(const FunctionLibraryDefinition&) =
    295       delete;
    296 
    297   // Returns nullptr if "func" is not defined in "lib_def". Otherwise,
    298   // returns its definition proto.
    299   const FunctionDef* Find(const string& func) const;
    300 
    301   // Adds function definition 'fdef' to this function library.
    302   // Returns status 'ok' on success, or error otherwise. This is a no-op if
    303   // 'fdef' already exists in this function library.
    304   // If 'fdef' is successfully added to the library, it will be accessible
    305   // from 'LookUp' and included in the proto returned by 'ToProto'.
    306   // This operation is atomic.
    307   Status AddFunctionDef(const FunctionDef& fdef);
    308 
    309   // Adds gradient definition 'grad' to this function library.
    310   // This is a no-op if 'grad' already exists in this function library.
    311   // If 'grad' is successfully added, it will be accessible via 'FindGradient'
    312   // and included in the proto returned by 'ToProto'.
    313   // This operation is atomic.
    314   Status AddGradientDef(const GradientDef& grad);
    315 
    316   // Remove function `func` from the library. Returns non-OK Status unless
    317   // `func` is in the library.
    318   Status RemoveFunction(const string& func);
    319 
    320   // Remove gradient of function `func` from the library. Returns non-OK Status
    321   // unless `func` has a gradient.
    322   Status RemoveGradient(const string& func);
    323 
    324   // Adds the functions and gradients in 'other' to this function library.
    325   // Duplicate functions and gradients are ignored.
    326   // This operation is atomic.
    327   Status AddLibrary(const FunctionLibraryDefinition& other);
    328 
    329   // Adds the functions and gradients in 'lib_def' to this function library.
    330   // Duplicate functions and gradients are ignored.
    331   // This operation is atomic.
    332   Status AddLibrary(const FunctionDefLibrary& lib_def);
    333 
    334   // If the gradient function for 'func' is specified explicitly in
    335   // the library, returns the gradient function name.  Otherwise,
    336   // returns an empty string.
    337   string FindGradient(const string& func) const;
    338 
    339   // OpRegistryInterface method. Useful for constructing a Graph.
    340   //
    341   // If "op" is defined in the library, returns its signature.
    342   // Otherwise, assume "op" is a primitive op and returns its op
    343   // signature and shape inference function.
    344   Status LookUp(const string& op_type_name,
    345                 const OpRegistrationData** op_reg_data) const override;
    346 
    347   static constexpr const char* const kGradientOp = "SymbolicGradient";
    348   static constexpr const char* const kFuncAttr = "f";
    349 
    350   // Given a node def 'ndef', inspects attributes of the callee
    351   // function to derive the attribute 'value' for 'attr'. Returns OK
    352   // iff the attribute is given by the function's definition.
    353   // TODO(irving): Remove; keep only the const Node& version.
    354   template <typename T>
    355   Status GetAttr(const NodeDef& ndef, const string& attr, T* value) const;
    356 
    357   // Given a node, inspects attributes of the callee function to derive the
    358   // attribute 'value' for 'attr'. Returns OK iff the attribute is given by the
    359   // function's definition.
    360   template <typename T>
    361   Status GetAttr(const Node& node, const string& attr, T* value) const;
    362 
    363   // Returns a proto representation of the state of this function library.
    364   FunctionDefLibrary ToProto() const;
    365 
    366   size_t num_functions() const { return function_defs_.size(); }
    367 
    368   const OpRegistryInterface* default_registry() const {
    369     return default_registry_;
    370   }
    371 
    372  private:
    373   // Shape inference for functions is handled separately by ShapeRefiner.
    374 
    375   struct FunctionDefAndOpRegistration {
    376     FunctionDefAndOpRegistration(const FunctionDef& fdef_in);
    377 
    378     FunctionDef fdef;
    379     OpRegistrationData op_registration_data;
    380   };
    381 
    382   // Same as AddFunctionDef/AddGradientDef except these methods set
    383   // `added` to true if the `fdef`/`grad` were actually added to this.
    384   Status AddFunctionDefHelper(const FunctionDef& fdef, bool* added);
    385   Status AddGradientDefHelper(const GradientDef& grad, bool* added);
    386 
    387   const OpRegistryInterface* const default_registry_;
    388   gtl::FlatMap<string, std::unique_ptr<FunctionDefAndOpRegistration>>
    389       function_defs_;
    390   gtl::FlatMap<string, string> func_grad_;
    391 
    392   // Helper function for GetAttr. Returns the FunctionDef* to get the
    393   // attr from.
    394   const FunctionDef* GetAttrImpl(const NodeDef& ndef) const;
    395 
    396   // Remove all functions in `funcs` and all gradients of
    397   // functions in `funcs_with_grads` from this library.
    398   void Remove(const std::vector<string>& funcs,
    399               const std::vector<string>& funcs_with_grads);
    400 };
    401 
    402 // Forward declare. Defined in common_runtime/function.h
    403 struct FunctionBody;
    404 
    405 // Forward declare. Defined in common_runtime/device.h
    406 class Device;
    407 
    408 class FunctionLibraryRuntime {
    409  public:
    410   virtual ~FunctionLibraryRuntime() {}
    411 
    412   // Instantiate a function with the given "attrs".
    413   //
    414   // Returns OK and fills in "handle" if the instantiation succeeds.
    415   // Otherwise returns an error and "handle" is undefined.
    416   struct InstantiateOptions {
    417     // The canonical device name of the device on which the function
    418     // should be instantiated. If empty, the function will be
    419     // instantiated on the local device.
    420     string target;
    421 
    422     // This interface is EXPERIMENTAL and subject to change.
    423     //
    424     // If non-null, the runtime will use `overlay_lib` to resolve
    425     // function(s) named in `function_name` and `attrs`. Otherwise,
    426     // the runtime will use its internal library.
    427     // NOTE(mrry): If provided, all functions defined in `overlay_lib`
    428     // must be self-contained, and cannot refer to functions defined
    429     // in other libraries.
    430     // TODO(mrry): Provide a mechanism for sharing core functions
    431     // between a set of libraries (e.g. by allowing a
    432     // `FunctionLibraryDefinition` to store an `outer_scope` pointer
    433     // and implementing name resolution across libraries).
    434     const FunctionLibraryDefinition* overlay_lib = nullptr;
    435 
    436     // This interface is EXPERIMENTAL and subject to change.
    437     //
    438     // If non-empty, the runtime will use `state_handle` to identify
    439     // cached state related the instantiated function. Two functions
    440     // of the same name and attrs, instantiated with the same
    441     // `state_handle` will have the same handle and share the same
    442     // state (in stateful kernels); and two functions with different
    443     // values for `state_handle` will have independent state.
    444     string state_handle;
    445   };
    446   typedef uint64 Handle;
    447   virtual Status Instantiate(const string& function_name, AttrSlice attrs,
    448                              const InstantiateOptions& options,
    449                              Handle* handle) = 0;
    450   Status Instantiate(const string& function_name, AttrSlice attrs,
    451                      Handle* handle) {
    452     return Instantiate(function_name, attrs, {}, handle);
    453   }
    454 
    455   // Releases state associated with the handle.
    456   virtual Status ReleaseHandle(Handle handle) = 0;
    457 
    458   // Returns the function body for the instantiated function given its
    459   // handle 'h'. Returns nullptr if "h" is not found.
    460   //
    461   // *this keeps the ownership of the returned object, which remains alive
    462   // as long as *this.
    463   virtual const FunctionBody* GetFunctionBody(Handle h) = 0;
    464 
    465   // Asynchronously invokes the instantiated function identified by
    466   // "handle".
    467   //
    468   // If function execution succeeds, "done" is called with OK and
    469   // "*rets" is filled with the function's return values. Otheriwse,
    470   // "done" is called with an error status.
    471   //
    472   // Does not take ownership of "rets".
    473   // In the cross-process scenario, runner isn't used for making the Async
    474   // RPC calls.
    475   struct Options {
    476     // The id of the step that is calling this function.
    477     int64 step_id = 0;
    478     Rendezvous* rendezvous = nullptr;
    479     CancellationManager* cancellation_manager = nullptr;
    480     ScopedStepContainer* step_container = nullptr;
    481     StepStatsCollector* stats_collector = nullptr;
    482 
    483     std::function<void(std::function<void()>)>* runner = nullptr;
    484 
    485     // Parameters for remote function execution.
    486     bool remote_execution = false;
    487     string source_device = "";  // Fully specified device name.
    488 
    489     // Allocator attributes specifying where the args are / rets should be put.
    490     // These should either be {} or match the length of args / retvals. If {},
    491     // the default allocator attributes will be assumed for all args / retvals.
    492     std::vector<AllocatorAttributes> args_alloc_attrs;
    493     std::vector<AllocatorAttributes> rets_alloc_attrs;
    494 
    495     // If true, we create a new IntraProcessRendezvous, else use the existing
    496     // one.
    497     bool create_rendezvous = false;
    498   };
    499   typedef std::function<void(const Status&)> DoneCallback;
    500   virtual void Run(const Options& opts, Handle handle,
    501                    gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
    502                    DoneCallback done) = 0;
    503   virtual void Run(const Options& opts, Handle handle,
    504                    CallFrameInterface* call_frame, DoneCallback done) = 0;
    505 
    506   // Creates a "kernel" for the given node def "ndef".
    507   //
    508   // If succeeds, returns OK and the caller takes the ownership of the
    509   // returned "*kernel". Otherwise, returns an error.
    510   virtual Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) = 0;
    511 
    512   // Returns true iff the function named `function_name` is stateful.
    513   // NOTE(mrry): This method assumes that the runtime is associated with a
    514   // default function library, and looks up `function_name` in that library.
    515   // It does not support overlay libraries.
    516   virtual bool IsStateful(const string& function_name) = 0;
    517 
    518   // Returns the device on which the function executes.
    519   virtual Device* device() = 0;
    520 
    521   // Returns the function library definition that backs this runtime.
    522   // NOTE(mrry): The returned library definition is the default function library
    523   // for this runtime. The runtime may instantiate functions from separate
    524   // overlay libraries, which are not returned by this function.
    525   virtual const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
    526       const = 0;
    527 
    528   // Returns the environment on which the function executes.
    529   virtual Env* env() = 0;
    530 
    531   // Returns a debug string showing the definition of the function of
    532   // 'handle'.
    533   virtual string DebugString(Handle handle) = 0;
    534 
    535   // Returns the graph version number.
    536   virtual int graph_def_version() = 0;
    537 
    538   typedef uint64 LocalHandle;
    539 
    540   virtual Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
    541                        std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
    542                        FunctionLibraryRuntime** out_flr) = 0;
    543 };
    544 
    545 // Returns a canonicalized string for the instantiation of the
    546 // function of the given "name", attributes "attrs", and "options".
    547 //
    548 // The returned string is guaranteed to be stable within one address
    549 // space. But it may be change as the implementation
    550 // evolves. Therefore, it should not be persisted or compared across
    551 // address spaces.
    552 string Canonicalize(const string& funcname, AttrSlice attrs,
    553                     const FunctionLibraryRuntime::InstantiateOptions& options);
    554 inline string Canonicalize(const string& funcname, AttrSlice attrs) {
    555   return Canonicalize(funcname, attrs, {});
    556 }
    557 
    558 const FunctionLibraryRuntime::Handle kInvalidHandle = -1;
    559 const FunctionLibraryRuntime::LocalHandle kInvalidLocalHandle = -1;
    560 typedef std::function<Status(FunctionLibraryRuntime*, const NodeDef&,
    561                              std::unique_ptr<OpKernel>*)>
    562     CustomKernelCreator;
    563 
    564 // Used to instantiate and run functions in a distributed system.
    565 class DistributedFunctionLibraryRuntime {
    566  public:
    567   virtual ~DistributedFunctionLibraryRuntime() {}
    568 
    569   // The _target attr in attrs determines where the function is instantiated.
    570   virtual Status Instantiate(
    571       const string& function_name, const FunctionLibraryDefinition& lib_def,
    572       AttrSlice attrs,
    573       const FunctionLibraryRuntime::InstantiateOptions& options,
    574       FunctionLibraryRuntime::LocalHandle* handle) = 0;
    575 
    576   // opts.runner isn't used for execution.
    577   virtual void Run(const FunctionLibraryRuntime::Options& opts,
    578                    FunctionLibraryRuntime::LocalHandle handle,
    579                    gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
    580                    FunctionLibraryRuntime::DoneCallback done) = 0;
    581 };
    582 
    583 // Extracts the actual type from "attr_values" based on its definition
    584 // "arg_def".
    585 //
    586 // If "arg_def" is a N*T type, *is_type_list is set to false, and
    587 // *dtypes is set to be a vector of size N and each element is T.
    588 //
    589 // If "arg_def" is a list(type), *is_type_list is set to true, and
    590 // *dtypes is set to be a vector of types specified in attrs for
    591 // arg_def.
    592 //
    593 // Otherwise (arg_def is a simple type T), *is_type_list is set to
    594 // false, and *dtypes is set to a single element vector, whose only
    595 // element is T.
    596 Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
    597                   bool* is_type_list, DataTypeVector* dtypes);
    598 
    599 // To register a gradient function for a builtin op, one should use
    600 //   REGISTER_OP_GRADIENT(<op_name>, <c++ grad factory>);
    601 //
// Typically, the c++ grad factory is a plain function that can be
// converted into ::tensorflow::gradient::Creator, which is
//   std::function<Status(const AttrSlice&, FunctionDef*)>.
    605 //
// A ::tensorflow::gradient::Creator should populate the FunctionDef* with a
// definition of a brain function which computes the gradient for the
// <op_name> when the <op_name> is instantiated with the given attrs.
    609 //
    610 // E.g.,
    611 //
    612 // Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
    613 //   bool transpose_a;
    614 //   TF_RETURN_IF_ERROR(attrs.Get("transpose_a", &transpose_a));
    615 //   bool transpose_b;
    616 //   TF_RETURN_IF_ERROR(attrs.Get("transpose_b", &transpose_b));
    617 //   DataType dtype;
    618 //   TF_RETURN_IF_ERROR(attrs.Get("dtype", &dtype));
    619 //   if (!transpose_a && !transpose_b) {
    620 //     *g = FunctionDefHelper::Define(
    621 //       "MatMulGrad",
//       {"x:T", "y:T", "dz:T"},     // Inputs to this function
    623 //       {"dx:T", "dy:T"},           // Outputs from this function
    624 //       {"T: {float, double}"},     // Attributes needed by this function
    625 //       {
    626 //         {{"x_t"}, "Transpose", {"x"}, {{"T", "$T"}}},
    627 //         {{"y_t"}, "Transpose", {"y"}, {{"T", "$T"}}},
    628 //         {{"dx"}, "MatMul", {"dz", "y_t"}, {{"T", "$T"}}},
//         {{"dy"}, "MatMul", {"x_t", "dz"}, {{"T", "$T"}}},
    630 //       });
    631 //   } else {
    632 //     ... ...
    633 //   }
    634 //   return Status::OK();
    635 // }
    636 //
// NOTE: $T is substituted with the type variable "T" when the
// gradient function MatMulGrad is instantiated.
    639 //
    640 // TODO(zhifengc): Better documentation somewhere.
    641 
    642 // Macros to define a gradient function factory for a primitive
    643 // operation.
    644 #define REGISTER_OP_GRADIENT(name, fn) \
    645   REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, fn)
    646 
    647 #define REGISTER_OP_NO_GRADIENT(name) \
    648   REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, nullptr)
    649 
    650 #define REGISTER_OP_GRADIENT_UNIQ_HELPER(ctr, name, fn) \
    651   REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn)
    652 
    653 #define REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn)                 \
    654   static bool unused_grad_##ctr = SHOULD_REGISTER_OP_GRADIENT && \
    655                                   ::tensorflow::gradient::RegisterOp(name, fn)
    656 
    657 namespace gradient {
    658 // Register a gradient creator for the "op".
    659 typedef std::function<Status(const AttrSlice& attrs, FunctionDef*)> Creator;
    660 bool RegisterOp(const string& op, Creator func);
    661 
    662 // Returns OK the gradient creator for the "op" is found (may be
    663 // nullptr if REGISTER_OP_NO_GRADIENT is used.
    664 Status GetOpGradientCreator(const string& op, Creator* creator);
    665 };  // namespace gradient
    666 
    667 // Declare explicit instantiations of GetAttr
    668 #define GET_ATTR(T)                                          \
    669   extern template Status FunctionLibraryDefinition::GetAttr( \
    670       const Node&, const string&, T*) const;                 \
    671   extern template Status FunctionLibraryDefinition::GetAttr( \
    672       const NodeDef&, const string&, T*) const;
    673 GET_ATTR(string)
    674 GET_ATTR(bool)
    675 #undef GET_ATTR
    676 
    677 }  // end namespace tensorflow
    678 
    679 #endif  // TENSORFLOW_FRAMEWORK_FUNCTION_H_
    680