/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_FRAMEWORK_FUNCTION_H_
#define TENSORFLOW_FRAMEWORK_FUNCTION_H_

#include <vector>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"

namespace tensorflow {

// Forward declarations. The declarations in this header name these types only
// through pointers, references, or smart pointers, so their full definitions
// are not needed here.
// NOTE(review): ResourceMgr is not referenced anywhere else in this header;
// confirm no includer relies on this declaration before removing it.
class CancellationManager;
class GraphDef;
class OpKernel;
class ProcessFunctionLibraryRuntime;
class ResourceMgr;
class Rendezvous;
class ScopedStepContainer;
class StepStatsCollector;
class Node;

// FunctionDefHelper::Create is a convenient helper to construct a
// FunctionDef proto.
47 // E.g., 48 // FunctionDef my_func = FunctionDefHelper::Create( 49 // "my_func_name", 50 // {"x:T", "y:T" /* one string per argument */}, 51 // {"z:T" /* one string per return value */}, 52 // {"T: {float, double}" /* one string per attribute */}, 53 // { 54 // {{"o"}, "Mul", {"x", "y"}, {{"T", "$T"}}} 55 // /* one entry per function node */ 56 // }, 57 // /* Mapping between function returns and function node outputs. */ 58 // {{"z", "o:z"}}); 59 // 60 // For the old Function::Node approach, use FunctionDefHelper::Define() 61 // E.g., 62 // FunctionDef my_func = FunctionDefHelper::Define( 63 // "my_func_name", 64 // {"x:T", "y:T" /* one string per argument */}, 65 // {"z:T" /* one string per return value */}, 66 // {"T: {float, double}" /* one string per attribute */}, 67 // { 68 // {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}} 69 // /* one entry per function node */ 70 // }); 71 class FunctionDefHelper { 72 public: 73 // AttrValueWrapper has copy constructors for the type T so that 74 // it's easy to construct a simple AttrValue proto. 75 // 76 // If T is a string type (const char*, string, or StringPiece), and 77 // it starts with "$", we construct a AttrValue of "placeholder". 78 // 79 // E.g., 80 // std::<string, AttrValueWrapper> x = {"T", "$T"} 81 // is a named attr value placeholder. 82 struct AttrValueWrapper { 83 AttrValue proto; 84 85 AttrValueWrapper() {} 86 87 template <typename T> 88 AttrValueWrapper(T val) { // NOLINT(runtime/explicit) 89 SetAttrValue(val, &proto); 90 } 91 92 private: 93 void InitFromString(StringPiece val); 94 }; 95 96 // Constructs an AttrValue.func given the "name" and "attrs". 97 static AttrValueWrapper FunctionRef( 98 const string& name, 99 gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs); 100 static AttrValueWrapper FunctionRef(const string& name) { 101 return FunctionRef(name, {}); 102 } 103 104 // Node is used to construct FunctionDef.Node using initialization 105 // lists. 
E.g., 106 // Node n = {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}}; // z = x * y 107 struct Node { 108 // When constructing a NodeDef, the first entry in ret is used as 109 // the node name, the remaining values are ignored. 110 std::vector<string> ret; 111 string op; 112 std::vector<string> arg; 113 std::vector<std::pair<string, AttrValueWrapper>> attr; 114 std::vector<string> dep; 115 116 NodeDef ToNodeDef() const; 117 }; 118 119 // The Create() function uses the new NodeDef field. `ret_def` 120 // holds a mapping from the function output names from `out_def` to 121 // the node outputs from `node_def`. 122 static FunctionDef Create(const string& function_name, 123 gtl::ArraySlice<string> in_def, 124 gtl::ArraySlice<string> out_def, 125 gtl::ArraySlice<string> attr_def, 126 gtl::ArraySlice<Node> node_def, 127 gtl::ArraySlice<std::pair<string, string>> ret_def); 128 129 // The two Define() functions use the old FunctionDef::Node field. 130 // TODO(josh11b): Get rid of these and transition to the one above. 131 static FunctionDef Define(const string& function_name, 132 gtl::ArraySlice<string> arg_def, 133 gtl::ArraySlice<string> ret_def, 134 gtl::ArraySlice<string> attr_def, 135 gtl::ArraySlice<Node> node_def); 136 137 // Defines an anonymous function. I.e., its name is not relevant. 138 static FunctionDef Define(gtl::ArraySlice<string> arg_def, 139 gtl::ArraySlice<string> ret_def, 140 gtl::ArraySlice<string> attr_def, 141 gtl::ArraySlice<Node> node_def); 142 143 // Helpers to construct a constant scalar. 
144 template <typename T> 145 static Node Const(const string& name, const T& val) { 146 Node n = {{name}, "Const"}; 147 const DataType dtype = DataTypeToEnum<T>::value; 148 n.attr.push_back({"dtype", dtype}); 149 Tensor t(dtype, TensorShape({})); 150 t.scalar<T>()() = val; 151 n.attr.push_back({"value", t}); 152 return n; 153 } 154 155 template <typename T> 156 static Node Const(const string& name, gtl::ArraySlice<T> vals) { 157 Node n = {{name}, "Const"}; 158 const DataType dtype = DataTypeToEnum<T>::value; 159 n.attr.push_back({"dtype", dtype}); 160 int64 num = vals.size(); 161 Tensor t(dtype, TensorShape({num})); 162 for (size_t i = 0; i < vals.size(); ++i) { 163 t.flat<T>()(i) = vals[i]; 164 } 165 n.attr.push_back({"value", t}); 166 return n; 167 } 168 }; 169 170 template <> 171 inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(const char* val) { 172 InitFromString(val); 173 } 174 175 template <> 176 inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper( 177 const string& val) { 178 InitFromString(val); 179 } 180 181 template <> 182 inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) { 183 InitFromString(val); 184 } 185 186 // Instantiate a function. 187 // 188 // "fdef" encodes a TF function with some attrs in fdef.signature.attr 189 // containing placeholders. InstantiateFunction binds these 190 // placeholders and produces an instantiated function encoded in 191 // "result.gdef". The value to substitute a placeholder is given by 192 // "attr_values", which is a map from a placeholder name to an attr 193 // value. 194 // 195 // InstantiateFunction calls "get_function" to find signatures of other 196 // functions and primitive ops. 197 198 // GetFunctionSignature(func name, opdef) returns OK if the func name is found 199 // and opdef is filled with a pointer to the corresponding signature 200 // (a OpDef proto). Otherwise, returns an error. 
201 typedef std::function<Status(const string&, const OpDef**)> 202 GetFunctionSignature; 203 204 struct InstantiationResult { 205 DataTypeVector arg_types; 206 DataTypeVector ret_types; 207 std::vector<NodeDef> nodes; 208 }; 209 Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, 210 GetFunctionSignature get_function, 211 InstantiationResult* result); 212 213 // Returns a debug string for a function definition. 214 // 215 // The returned text is multiple-line. It is intended to be 216 // human-readable rather than being friendly to parsers. It is _NOT_ 217 // intended to be the canonical string representation of "func_def". 218 // Particularly, it may not include all information presented in 219 // "func_def" (e.g., comments, description of the function arguments, 220 // etc.) 221 string DebugString(const FunctionDef& func_def); 222 string DebugString(const GraphDef& instantiated_func_def); 223 string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes); 224 225 // Returns a debug string for a top level graph (the main program and 226 // its supporting functions defined in its library). 227 string DebugStringWhole(const GraphDef& gdef); 228 229 // Returns true if f1 == f2. Compares all fields, including descriptions. Order 230 // of NodeDefs doesn't matter. 231 bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2); 232 233 // Return a hash of `fdef` that is consistent with FunctionDefsEqual method. 234 // In other words, if two fdefs compare equal, their hash values will be the 235 // same. 236 uint64 FunctionDefHash(const FunctionDef& fdef); 237 238 class CallFrameInterface { 239 public: 240 virtual ~CallFrameInterface() {} 241 242 virtual size_t num_args() const = 0; 243 virtual size_t num_retvals() const = 0; 244 245 virtual Status GetArg(int index, Tensor* val) const = 0; 246 virtual Status SetRetval(int index, const Tensor& val) = 0; 247 }; 248 249 // Represents a function call frame. 
I.e., the data structure used to 250 // pass arguments to a function and retrieve its results. 251 // 252 // Runtime must arrange accesses to one FunctionCallFrame s.t. 253 // 1. SetArgs() happens before any GetArg(); 254 // 2. GetRetvals happens after all SetRetval(); 255 class FunctionCallFrame : public CallFrameInterface { 256 public: 257 FunctionCallFrame(DataTypeSlice arg_types, DataTypeSlice ret_types); 258 ~FunctionCallFrame(); 259 260 // Caller methods. 261 Status SetArgs(gtl::ArraySlice<Tensor> args); 262 Status GetRetvals(std::vector<Tensor>* rets) const; 263 Status ConsumeRetvals(std::vector<Tensor>* rets); 264 265 size_t num_args() const override { return arg_types_.size(); } 266 size_t num_retvals() const override { return ret_types_.size(); } 267 268 // Callee methods. 269 Status GetArg(int index, Tensor* val) const override; 270 Status SetRetval(int index, const Tensor& val) override; 271 272 private: 273 DataTypeVector arg_types_; 274 DataTypeVector ret_types_; 275 gtl::InlinedVector<Tensor, 4> args_; 276 struct Retval { 277 bool has_val = false; 278 Tensor val; 279 }; 280 gtl::InlinedVector<Retval, 4> rets_; 281 282 TF_DISALLOW_COPY_AND_ASSIGN(FunctionCallFrame); 283 }; 284 285 // Helper to maintain a map between function names in a given 286 // FunctionDefLibrary and function definitions. 287 class FunctionLibraryDefinition : public OpRegistryInterface { 288 public: 289 explicit FunctionLibraryDefinition(const FunctionLibraryDefinition& lib_def); 290 FunctionLibraryDefinition(const OpRegistryInterface* default_registry, 291 const FunctionDefLibrary& lib_def); 292 ~FunctionLibraryDefinition() override; 293 294 FunctionLibraryDefinition& operator=(const FunctionLibraryDefinition&) = 295 delete; 296 297 // Returns nullptr if "func" is not defined in "lib_def". Otherwise, 298 // returns its definition proto. 299 const FunctionDef* Find(const string& func) const; 300 301 // Adds function definition 'fdef' to this function library. 
302 // Returns status 'ok' on success, or error otherwise. This is a no-op if 303 // 'fdef' already exists in this function library. 304 // If 'fdef' is successfully added to the library, it will be accessible 305 // from 'LookUp' and included in the proto returned by 'ToProto'. 306 // This operation is atomic. 307 Status AddFunctionDef(const FunctionDef& fdef); 308 309 // Adds gradient definition 'grad' to this function library. 310 // This is a no-op if 'grad' already exists in this function library. 311 // If 'grad' is successfully added, it will be accessible via 'FindGradient' 312 // and included in the proto returned by 'ToProto'. 313 // This operation is atomic. 314 Status AddGradientDef(const GradientDef& grad); 315 316 // Remove function `func` from the library. Returns non-OK Status unless 317 // `func` is in the library. 318 Status RemoveFunction(const string& func); 319 320 // Remove gradient of function `func` from the library. Returns non-OK Status 321 // unless `func` has a gradient. 322 Status RemoveGradient(const string& func); 323 324 // Adds the functions and gradients in 'other' to this function library. 325 // Duplicate functions and gradients are ignored. 326 // This operation is atomic. 327 Status AddLibrary(const FunctionLibraryDefinition& other); 328 329 // Adds the functions and gradients in 'lib_def' to this function library. 330 // Duplicate functions and gradients are ignored. 331 // This operation is atomic. 332 Status AddLibrary(const FunctionDefLibrary& lib_def); 333 334 // If the gradient function for 'func' is specified explicitly in 335 // the library, returns the gradient function name. Otherwise, 336 // returns an empty string. 337 string FindGradient(const string& func) const; 338 339 // OpRegistryInterface method. Useful for constructing a Graph. 340 // 341 // If "op" is defined in the library, returns its signature. 
342 // Otherwise, assume "op" is a primitive op and returns its op 343 // signature and shape inference function. 344 Status LookUp(const string& op_type_name, 345 const OpRegistrationData** op_reg_data) const override; 346 347 static constexpr const char* const kGradientOp = "SymbolicGradient"; 348 static constexpr const char* const kFuncAttr = "f"; 349 350 // Given a node def 'ndef', inspects attributes of the callee 351 // function to derive the attribute 'value' for 'attr'. Returns OK 352 // iff the attribute is given by the function's definition. 353 // TODO(irving): Remove; keep only the const Node& version. 354 template <typename T> 355 Status GetAttr(const NodeDef& ndef, const string& attr, T* value) const; 356 357 // Given a node, inspects attributes of the callee function to derive the 358 // attribute 'value' for 'attr'. Returns OK iff the attribute is given by the 359 // function's definition. 360 template <typename T> 361 Status GetAttr(const Node& node, const string& attr, T* value) const; 362 363 // Returns a proto representation of the state of this function library. 364 FunctionDefLibrary ToProto() const; 365 366 size_t num_functions() const { return function_defs_.size(); } 367 368 const OpRegistryInterface* default_registry() const { 369 return default_registry_; 370 } 371 372 private: 373 // Shape inference for functions is handled separately by ShapeRefiner. 374 375 struct FunctionDefAndOpRegistration { 376 FunctionDefAndOpRegistration(const FunctionDef& fdef_in); 377 378 FunctionDef fdef; 379 OpRegistrationData op_registration_data; 380 }; 381 382 // Same as AddFunctionDef/AddGradientDef except these methods set 383 // `added` to true if the `fdef`/`grad` were actually added to this. 
384 Status AddFunctionDefHelper(const FunctionDef& fdef, bool* added); 385 Status AddGradientDefHelper(const GradientDef& grad, bool* added); 386 387 const OpRegistryInterface* const default_registry_; 388 gtl::FlatMap<string, std::unique_ptr<FunctionDefAndOpRegistration>> 389 function_defs_; 390 gtl::FlatMap<string, string> func_grad_; 391 392 // Helper function for GetAttr. Returns the FunctionDef* to get the 393 // attr from. 394 const FunctionDef* GetAttrImpl(const NodeDef& ndef) const; 395 396 // Remove all functions in `funcs` and all gradients of 397 // functions in `funcs_with_grads` from this library. 398 void Remove(const std::vector<string>& funcs, 399 const std::vector<string>& funcs_with_grads); 400 }; 401 402 // Forward declare. Defined in common_runtime/function.h 403 struct FunctionBody; 404 405 // Forward declare. Defined in common_runtime/device.h 406 class Device; 407 408 class FunctionLibraryRuntime { 409 public: 410 virtual ~FunctionLibraryRuntime() {} 411 412 // Instantiate a function with the given "attrs". 413 // 414 // Returns OK and fills in "handle" if the instantiation succeeds. 415 // Otherwise returns an error and "handle" is undefined. 416 struct InstantiateOptions { 417 // The canonical device name of the device on which the function 418 // should be instantiated. If empty, the function will be 419 // instantiated on the local device. 420 string target; 421 422 // This interface is EXPERIMENTAL and subject to change. 423 // 424 // If non-null, the runtime will use `overlay_lib` to resolve 425 // function(s) named in `function_name` and `attrs`. Otherwise, 426 // the runtime will use its internal library. 427 // NOTE(mrry): If provided, all functions defined in `overlay_lib` 428 // must be self-contained, and cannot refer to functions defined 429 // in other libraries. 430 // TODO(mrry): Provide a mechanism for sharing core functions 431 // between a set of libraries (e.g. 
by allowing a 432 // `FunctionLibraryDefinition` to store an `outer_scope` pointer 433 // and implementing name resolution across libraries). 434 const FunctionLibraryDefinition* overlay_lib = nullptr; 435 436 // This interface is EXPERIMENTAL and subject to change. 437 // 438 // If non-empty, the runtime will use `state_handle` to identify 439 // cached state related the instantiated function. Two functions 440 // of the same name and attrs, instantiated with the same 441 // `state_handle` will have the same handle and share the same 442 // state (in stateful kernels); and two functions with different 443 // values for `state_handle` will have independent state. 444 string state_handle; 445 }; 446 typedef uint64 Handle; 447 virtual Status Instantiate(const string& function_name, AttrSlice attrs, 448 const InstantiateOptions& options, 449 Handle* handle) = 0; 450 Status Instantiate(const string& function_name, AttrSlice attrs, 451 Handle* handle) { 452 return Instantiate(function_name, attrs, {}, handle); 453 } 454 455 // Releases state associated with the handle. 456 virtual Status ReleaseHandle(Handle handle) = 0; 457 458 // Returns the function body for the instantiated function given its 459 // handle 'h'. Returns nullptr if "h" is not found. 460 // 461 // *this keeps the ownership of the returned object, which remains alive 462 // as long as *this. 463 virtual const FunctionBody* GetFunctionBody(Handle h) = 0; 464 465 // Asynchronously invokes the instantiated function identified by 466 // "handle". 467 // 468 // If function execution succeeds, "done" is called with OK and 469 // "*rets" is filled with the function's return values. Otheriwse, 470 // "done" is called with an error status. 471 // 472 // Does not take ownership of "rets". 473 // In the cross-process scenario, runner isn't used for making the Async 474 // RPC calls. 475 struct Options { 476 // The id of the step that is calling this function. 
477 int64 step_id = 0; 478 Rendezvous* rendezvous = nullptr; 479 CancellationManager* cancellation_manager = nullptr; 480 ScopedStepContainer* step_container = nullptr; 481 StepStatsCollector* stats_collector = nullptr; 482 483 std::function<void(std::function<void()>)>* runner = nullptr; 484 485 // Parameters for remote function execution. 486 bool remote_execution = false; 487 string source_device = ""; // Fully specified device name. 488 489 // Allocator attributes specifying where the args are / rets should be put. 490 // These should either be {} or match the length of args / retvals. If {}, 491 // the default allocator attributes will be assumed for all args / retvals. 492 std::vector<AllocatorAttributes> args_alloc_attrs; 493 std::vector<AllocatorAttributes> rets_alloc_attrs; 494 495 // If true, we create a new IntraProcessRendezvous, else use the existing 496 // one. 497 bool create_rendezvous = false; 498 }; 499 typedef std::function<void(const Status&)> DoneCallback; 500 virtual void Run(const Options& opts, Handle handle, 501 gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets, 502 DoneCallback done) = 0; 503 virtual void Run(const Options& opts, Handle handle, 504 CallFrameInterface* call_frame, DoneCallback done) = 0; 505 506 // Creates a "kernel" for the given node def "ndef". 507 // 508 // If succeeds, returns OK and the caller takes the ownership of the 509 // returned "*kernel". Otherwise, returns an error. 510 virtual Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) = 0; 511 512 // Returns true iff the function named `function_name` is stateful. 513 // NOTE(mrry): This method assumes that the runtime is associated with a 514 // default function library, and looks up `function_name` in that library. 515 // It does not support overlay libraries. 516 virtual bool IsStateful(const string& function_name) = 0; 517 518 // Returns the device on which the function executes. 
519 virtual Device* device() = 0; 520 521 // Returns the function library definition that backs this runtime. 522 // NOTE(mrry): The returned library definition is the default function library 523 // for this runtime. The runtime may instantiate functions from separate 524 // overlay libraries, which are not returned by this function. 525 virtual const FunctionLibraryDefinition* GetFunctionLibraryDefinition() 526 const = 0; 527 528 // Returns the environment on which the function executes. 529 virtual Env* env() = 0; 530 531 // Returns a debug string showing the definition of the function of 532 // 'handle'. 533 virtual string DebugString(Handle handle) = 0; 534 535 // Returns the graph version number. 536 virtual int graph_def_version() = 0; 537 538 typedef uint64 LocalHandle; 539 540 virtual Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def, 541 std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr, 542 FunctionLibraryRuntime** out_flr) = 0; 543 }; 544 545 // Returns a canonicalized string for the instantiation of the 546 // function of the given "name", attributes "attrs", and "options". 547 // 548 // The returned string is guaranteed to be stable within one address 549 // space. But it may be change as the implementation 550 // evolves. Therefore, it should not be persisted or compared across 551 // address spaces. 552 string Canonicalize(const string& funcname, AttrSlice attrs, 553 const FunctionLibraryRuntime::InstantiateOptions& options); 554 inline string Canonicalize(const string& funcname, AttrSlice attrs) { 555 return Canonicalize(funcname, attrs, {}); 556 } 557 558 const FunctionLibraryRuntime::Handle kInvalidHandle = -1; 559 const FunctionLibraryRuntime::LocalHandle kInvalidLocalHandle = -1; 560 typedef std::function<Status(FunctionLibraryRuntime*, const NodeDef&, 561 std::unique_ptr<OpKernel>*)> 562 CustomKernelCreator; 563 564 // Used to instantiate and run functions in a distributed system. 
565 class DistributedFunctionLibraryRuntime { 566 public: 567 virtual ~DistributedFunctionLibraryRuntime() {} 568 569 // The _target attr in attrs determines where the function is instantiated. 570 virtual Status Instantiate( 571 const string& function_name, const FunctionLibraryDefinition& lib_def, 572 AttrSlice attrs, 573 const FunctionLibraryRuntime::InstantiateOptions& options, 574 FunctionLibraryRuntime::LocalHandle* handle) = 0; 575 576 // opts.runner isn't used for execution. 577 virtual void Run(const FunctionLibraryRuntime::Options& opts, 578 FunctionLibraryRuntime::LocalHandle handle, 579 gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets, 580 FunctionLibraryRuntime::DoneCallback done) = 0; 581 }; 582 583 // Extracts the actual type from "attr_values" based on its definition 584 // "arg_def". 585 // 586 // If "arg_def" is a N*T type, *is_type_list is set to false, and 587 // *dtypes is set to be a vector of size N and each element is T. 588 // 589 // If "arg_def" is a list(type), *is_type_list is set to true, and 590 // *dtypes is set to be a vector of types specified in attrs for 591 // arg_def. 592 // 593 // Otherwise (arg_def is a simple type T), *is_type_list is set to 594 // false, and *dtypes is set to a single element vector, whose only 595 // element is T. 596 Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def, 597 bool* is_type_list, DataTypeVector* dtypes); 598 599 // To register a gradient function for a builtin op, one should use 600 // REGISTER_OP_GRADIENT(<op_name>, <c++ grad factory>); 601 // 602 // Typically, the c++ grad factory is a plan function that can be 603 // converted into ::tensorflow::gradient::Creator, which is 604 // std::function<Status(const AttrSlice&, FunctionDef*)>. 605 // 606 // A ::tensorflow::gradient::Creator should populate in FunctionDef* with a 607 // definition of a brain function which compute the gradient for the 608 // <op_name> when the <op_name> is instantiated with the given attrs. 
609 // 610 // E.g., 611 // 612 // Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) { 613 // bool transpose_a; 614 // TF_RETURN_IF_ERROR(attrs.Get("transpose_a", &transpose_a)); 615 // bool transpose_b; 616 // TF_RETURN_IF_ERROR(attrs.Get("transpose_b", &transpose_b)); 617 // DataType dtype; 618 // TF_RETURN_IF_ERROR(attrs.Get("dtype", &dtype)); 619 // if (!transpose_a && !transpose_b) { 620 // *g = FunctionDefHelper::Define( 621 // "MatMulGrad", 622 // {"x:T ", "y:T", "dz:T"}, // Inputs to this function 623 // {"dx:T", "dy:T"}, // Outputs from this function 624 // {"T: {float, double}"}, // Attributes needed by this function 625 // { 626 // {{"x_t"}, "Transpose", {"x"}, {{"T", "$T"}}}, 627 // {{"y_t"}, "Transpose", {"y"}, {{"T", "$T"}}}, 628 // {{"dx"}, "MatMul", {"dz", "y_t"}, {{"T", "$T"}}}, 629 // {{"dy"}, "MatMul", {"x_", "dz"}, {{"T", "$T"}}}, 630 // }); 631 // } else { 632 // ... ... 633 // } 634 // return Status::OK(); 635 // } 636 // 637 // NOTE: $T is substituted with the type variable "T" when the 638 // gradient function MatMul is instantiated. 639 // 640 // TODO(zhifengc): Better documentation somewhere. 641 642 // Macros to define a gradient function factory for a primitive 643 // operation. 644 #define REGISTER_OP_GRADIENT(name, fn) \ 645 REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, fn) 646 647 #define REGISTER_OP_NO_GRADIENT(name) \ 648 REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, nullptr) 649 650 #define REGISTER_OP_GRADIENT_UNIQ_HELPER(ctr, name, fn) \ 651 REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn) 652 653 #define REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn) \ 654 static bool unused_grad_##ctr = SHOULD_REGISTER_OP_GRADIENT && \ 655 ::tensorflow::gradient::RegisterOp(name, fn) 656 657 namespace gradient { 658 // Register a gradient creator for the "op". 
659 typedef std::function<Status(const AttrSlice& attrs, FunctionDef*)> Creator; 660 bool RegisterOp(const string& op, Creator func); 661 662 // Returns OK the gradient creator for the "op" is found (may be 663 // nullptr if REGISTER_OP_NO_GRADIENT is used. 664 Status GetOpGradientCreator(const string& op, Creator* creator); 665 }; // namespace gradient 666 667 // Declare explicit instantiations of GetAttr 668 #define GET_ATTR(T) \ 669 extern template Status FunctionLibraryDefinition::GetAttr( \ 670 const Node&, const string&, T*) const; \ 671 extern template Status FunctionLibraryDefinition::GetAttr( \ 672 const NodeDef&, const string&, T*) const; 673 GET_ATTR(string) 674 GET_ATTR(bool) 675 #undef GET_ATTR 676 677 } // end namespace tensorflow 678 679 #endif // TENSORFLOW_FRAMEWORK_FUNCTION_H_ 680