/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_

#include <stack>

#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

class XlaContext;

// The XlaCompiler class is responsible for compilation of a self-contained
// subgraph of a TensorFlow computation using the XLA linear algebra runtime.
// It performs a symbolic execution of the graph starting from specific input
// shapes, using a JIT device to convert operators into XLA computations.
//
// XlaCompiler is typically invoked from an `XlaLaunch` operator once the
// shapes of all input parameters to the computation are known. This is
// because the symbolic execution requires known shapes for all operations.
//
// XlaCompiler compiles TensorFlow graphs that receive inputs via _Arg nodes
// and return outputs via _Retval nodes.
//
// The XlaCompiler requires one Argument struct for each _Arg index,
// describing that argument. Arguments can be compile-time constants
// (kind kConstant), run-time parameters (kind kParameter), or resources
// (kind kResource).
//
// Only kParameter and initialized kResource arguments become runtime
// parameters to the generated XLA computation.
//
// The run-time outputs of the XLA computation are arranged in the following
// order:
//   +------------------+-----------------------------------------+
//   |  _Retval values  |  Updated values of kResource arguments  |
//   +------------------+-----------------------------------------+
// _Retval values are ordered by _Retval index, whereas kResource values are
// ordered by the original _Arg position of the variable.
//
// If a shape representation function is provided as part of
// XlaCompiler::Options, kParameter arguments and return values of an
// entry computation will be reshaped in accordance with the shape function.
// Arguments and return values of a non-entry computation are not reshaped.
// Variable resource arguments are passed and returned in reshaped form, even
// for non-entry computations. This feature allows TensorFlow to keep
// on-device tensors with a different shape from their representation inside
// the XLA computation.
//
// In computation outputs, updated kResource values are placed at the end.
// When emitting While loop bodies, we must ensure that the loop body has
// identical input and output signatures. By passing variable values
// at the end of the argument list and using the
// `return_updated_values_for_all_resources` option, we can ensure that the
// input and output values of resources appear at the same positions.
//
// Resources are passed as parameters or returned as resource updates in
// "packed" form.
// kStack resources are packed as (array, size of stack) XLA tuples.
// kTensorArray resources without gradients are packed as the array that
// backs the TensorArray. If gradients are present (`tensor_array_gradients`),
// the packed representation is an (array, gradient0, gradient1, ...) tuple,
// where gradient_k is the value of the k-th gradient in the
// `tensor_array_gradients` ordered set.
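//
// Example: a minimal usage sketch (hypothetical; `client`, `flib_def`, and
// `graph` are assumed to be supplied by the caller):
//
//   XlaCompiler::Options options;
//   options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
//   options.client = client;      // an xla::Client*
//   options.flib_def = flib_def;  // a FunctionLibraryDefinition*
//   XlaCompiler compiler(options);
//
//   std::vector<XlaCompiler::Argument> args(1);
//   args[0].kind = XlaCompiler::Argument::kParameter;
//   args[0].type = DT_FLOAT;
//   args[0].shape = TensorShape({2, 3});
//
//   XlaCompiler::CompilationResult result;
//   TF_RETURN_IF_ERROR(compiler.CompileGraph(
//       XlaCompiler::CompileOptions(), "my_graph", std::move(graph), args,
//       /*user_aliases=*/{}, &result));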
class XlaCompiler {
 public:
  // Describes how to derive the value of each _Arg node in the graph/function
  // being compiled. There must be one Argument for each _Arg index.
  struct Argument {
    enum Kind {
      // Default value; not a valid kind.
      kInvalid,

      // Argument is a compile-time constant. No associated runtime parameter.
      kConstant,

      // Argument is a Variable, TensorArray, or Stack resource. Has an
      // associated runtime parameter iff `initialized` is true.
      kResource,

      // Argument is a run-time parameter.
      kParameter,

      // Argument is an XLA token.
      kToken,
    };

    Kind kind = kInvalid;

    // The type of the argument. If the argument is a resource, this
    // is the type of the variable's value, not DT_RESOURCE.
    DataType type = DT_INVALID;

    // The shape of the argument. For:
    // * a parameter: the shape of the parameter. We allow setting the XLA
    //   shape if known. This helps avoid conversions to and from TensorShape.
    // * a constant: ignored; the shape given by constant_value is used
    //   instead.
    // * an uninitialized resource: ignored. We don't yet know the shape of an
    //   uninitialized resource (otherwise we would have initialized it!)
    // * an initialized variable: the shape of the variable's value.
    // * an initialized TensorArray or Stack resource: the shape of an entry in
    //   the TensorArray/Stack. Note this is the size of a single entry, not
    //   the XLA data structure that represents the complete stack/array.
    absl::variant<TensorShape, xla::Shape> shape;

    // The value of the argument, if it is a compile-time constant. Must be a
    // host-memory tensor.
    Tensor constant_value;

    // The name of this argument, used for debugging.
    string name;

    // For a kResource, what kind of resource is it?
    XlaResource::Kind resource_kind = XlaResource::kInvalid;

    // For a kResource, has this resource been initialized?
    bool initialized = false;

    // For a TensorArray or Stack resource, what is the array's declared size?
    // (Used for lazy initialization.)
    int64 max_array_size = -1;

    // TensorArray resource parameters are passed as (array, gradient array 0,
    // ..., gradient array k), where the gradient arrays are in the same order
    // as `tensor_array_gradients`.
    std::set<string> tensor_array_gradients;

    // Maps dynamic dimension numbers to argument numbers. Empty if there are
    // no dynamic shapes.
    std::map<int32, int32> dynamic_dim_to_arg_num_map;
    bool is_pad_arg = false;

    bool operator==(const Argument& other) const;

    // Returns a human-readable summary of the argument.
    string HumanString() const;

    // Returns the dimension sizes for either TensorShape or xla::Shape.
    std::vector<int64> DimensionSizes() const;

    // Returns the human-readable string for either TensorShape or xla::Shape.
    string ShapeHumanString() const;
  };
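
  // Example: a sketch of an Argument describing an initialized DT_FLOAT
  // variable resource (names are illustrative only):
  //
  //   XlaCompiler::Argument arg;
  //   arg.kind = XlaCompiler::Argument::kResource;
  //   arg.resource_kind = XlaResource::kVariable;
  //   arg.initialized = true;
  //   arg.type = DT_FLOAT;              // type of the variable's value
  //   arg.shape = TensorShape({1024});  // shape of the variable's value
  //   arg.name = "my_variable";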

  // Options pertaining to an individual call to CompileGraph() or
  // CompileFunction().
  struct CompileOptions {
    // If `use_tuple_arg` is true, a single tuple parameter will be used for
    // all arguments; if false, each argument gets its own parameter.
    bool use_tuple_arg = false;

    // If 'return_updated_values_for_all_resources' is true, then updated
    // values of all resource arguments will be included in the
    // 'resource_updates' of the computation, even if the resource was not
    // modified by the computation. Used when compiling loop bodies to ensure
    // the input and output signatures match.
    bool return_updated_values_for_all_resources = false;

    // If 'resolve_compile_time_constants' is true, then outputs of a
    // computation that are known to be compile-time constants will be returned
    // as Tensors at compile-time, rather than as run-time outputs of the
    // computation.
    bool resolve_compile_time_constants = true;

    // If 'always_return_tuple' is true, then the output of a computation will
    // always be a tuple. Otherwise, a single-element output will not be
    // wrapped in a tuple.
    bool always_return_tuple = true;

    // True when compiling the entry computation, false for subcomputations
    // (while, call, etc.)
    bool is_entry_computation = true;

    // True when we should add an XLA token input & output to the
    // graph/function.
    bool add_token_input_output = false;
  };
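
  // Example: a sketch of the options typically used when compiling a While
  // loop body, where the input and output signatures must match (values are
  // illustrative, not prescriptive):
  //
  //   XlaCompiler::CompileOptions opts;
  //   opts.use_tuple_arg = true;
  //   opts.return_updated_values_for_all_resources = true;
  //   opts.resolve_compile_time_constants = false;
  //   opts.is_entry_computation = false;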

  struct OutputDescription {
    // Type and shape of the output. The shape is the unflattened shape.
    // When `type` is DT_RESOURCE, `shape` is the shape of the resource
    // variable's value.
    DataType type;
    TensorShape shape;

    // Constant output value, if known to be constant at JIT compilation time.
    // 'Tensor' is in host memory.
    bool is_constant = false;
    Tensor constant_value;

    // When this output is a resource, i.e. `type == DT_RESOURCE`, this is
    // the index of the input that contains the resource.
    int input_index;
  };

  // Describes a variable write side effect of the computation.
  struct ResourceUpdate {
    // Index of the input that contains the variable resource to write to.
    int input_index;

    // Type and shape of the tensor to be written back.
    // The `shape` field has the same meaning as the Argument::shape field.
    DataType type;
    TensorShape shape;

    // Was the value of the variable modified by the computation?
    // (Always true, unless `return_updated_values_for_all_resources` is true.)
    bool modified;

    // If the resource is a TensorArray, the set of gradients read or written.
    std::set<string> tensor_array_gradients_accessed;
  };

  struct CompilationResult {
    // Vector that maps from the parameters of the XLA computation to their
    // original argument positions. To handle compile-time constant inputs, the
    // parameters to the XLA computation may be a subset of the original
    // arguments. The relative ordering of parameters is maintained.
    std::vector<int> input_mapping;

    // Input shapes of the computation. If we are flattening inputs, these are
    // the flattened shapes.
    std::vector<xla::Shape> xla_input_shapes;

    // Output shape in XLA format. The output shape is always a tuple. If we
    // are flattening outputs, these are the flattened shapes.
    xla::Shape xla_output_shape;

    // TensorFlow shapes of outputs, together with the values of any
    // constant arguments. Vector indexed by TensorFlow _Retval number,
    // containing both constant and non-constant results.
    std::vector<OutputDescription> outputs;

    // TensorFlow shapes and types of sends/recvs from HostCompute Ops to their
    // matching RecvAtHost/SendFromHost Ops in the outer graph.
    tf2xla::HostComputeMetadata host_compute_metadata;

    // Resources whose values were updated by the computation, ordered
    // by return value position (which is the same as the order the resources
    // were passed as arguments). Resource updates follow the non-constant
    // results in the outputs of the XLA computation.
    std::vector<ResourceUpdate> resource_updates;

    // The XLA computation built from the TensorFlow subgraph.
    std::shared_ptr<xla::XlaComputation> computation;
  };
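
  // Example: a sketch of consuming `input_mapping` on the caller side
  // (`feed_xla_parameter` and `arg_tensors` are hypothetical):
  //
  //   for (int i = 0; i < result.input_mapping.size(); ++i) {
  //     // XLA parameter i was derived from original argument
  //     // result.input_mapping[i]; kConstant arguments are omitted.
  //     feed_xla_parameter(i, arg_tensors[result.input_mapping[i]]);
  //   }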

  typedef std::function<xla::StatusOr<xla::Shape>(const TensorShape&, DataType)>
      ShapeRepresentationFn;
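
  // Example: a hypothetical shape representation function that keeps the
  // logical shape unchanged (TensorShapeToXLAShape is declared in
  // tf2xla/shape_util.h, which the caller would need to include):
  //
  //   XlaCompiler::ShapeRepresentationFn identity_fn =
  //       [](const TensorShape& shape, DataType dtype)
  //           -> xla::StatusOr<xla::Shape> {
  //     xla::Shape xla_shape;
  //     TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
  //     return xla_shape;
  //   };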
  struct Options {
    // Name of the compilation device to use. It must be set by the caller.
    // The default empty value is invalid.
    DeviceType device_type = DeviceType("");

    // The device to use during compilation to execute instructions on, for
    // example for auto-tuning.
    // Valid values are defined by `xla::Backend::devices_ordinal_supported()`.
    // -1 indicates the default device should be used.
    int device_ordinal = -1;

    xla::Client* client = nullptr;

    // Function library in which to find function definitions. Must be
    // non-null.
    const FunctionLibraryDefinition* flib_def = nullptr;

    // The graph def version to be compiled.
    int graph_def_version = TF_GRAPH_DEF_VERSION;

    // If 'allow_cpu_custom_calls' is true, kernels may make use of
    // CustomCall() for CPU.
    bool allow_cpu_custom_calls = false;

    // If set, variables are represented to XLA using the shape given by this
    // shape function. Variables are reshaped to this shape on write, and
    // reshaped to their original shape on read.
    ShapeRepresentationFn shape_representation_fn;

    // If not nullptr, populate_resource_manager is called with the
    // compilation device's resource manager when the compilation
    // device is created, and can be used to create metadata objects
    // that can be accessed by XLA op kernels.
    std::function<Status(ResourceMgr*)>* populate_resource_manager = nullptr;

    // If not nullptr, this memory allocator can be used by the compiler for
    // temporary allocations it might want to make during compilation.
    //
    // For example, the compiler may want to try out different algorithms and
    // choose the fastest one, and it might run those algorithms over buffers
    // created using this allocator.
    //
    // The compiler can function correctly without an explicit allocator given
    // here, but on some devices (notably, GPUs), TensorFlow tends to eagerly
    // allocate most or all available memory on the device, leaving none for
    // the compiler to access, unless it can use TensorFlow's allocator.
    xla::DeviceMemoryAllocator* device_allocator = nullptr;
  };

  explicit XlaCompiler(Options options);

  ~XlaCompiler();

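  // Compiles a TensorFlow function given by `fn_name_attrs` into an
  // xla::XlaComputation, with one entry in `args` per _Arg index. Writes the
  // compilation output into `result`.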
  Status CompileFunction(const CompileOptions& options,
                         const NameAttrList& fn_name_attrs,
                         absl::Span<const Argument> args,
                         CompilationResult* result);

  // Compiles a tensorflow::Graph into an xla::XlaComputation.
  // Similar to CompileFunction, but takes a Graph as input rather than a
  // function.
  Status CompileGraph(
      const CompileOptions& options, string const& name,
      std::unique_ptr<Graph> graph, absl::Span<const Argument> args,
      absl::Span<const xla::XlaBuilder::InputOutputAlias> user_aliases,
      CompilationResult* result);

  // Compiles a single Op, given by `node_def`, into an
  // xla::XlaComputation. Similar to CompileFunction but takes a single Op as
  // input.
  Status CompileSingleOp(const CompileOptions& options, const NodeDef& node_def,
                         absl::Span<const Argument> args,
                         absl::Span<const DataType> result_types,
                         CompilationResult* result);

  // Returns the shape of the XLA parameter for an argument 'arg'.
  // See the class comment for more details about the argument passing
  // convention.
  Status XLAShapeForArgument(const Argument& arg, bool is_entry_computation,
                             xla::Shape* xla_shape) const;

  // Retrieves the channel handle associated with `key`. Allocates
  // a new channel handle if none exists.
  // Channel handles can be used to communicate between different
  // computations. Computations that communicate should be compiled with the
  // same XlaCompiler.
  Status GetChannelHandle(const string& key, xla::ChannelHandle* channel);

  // Retrieves the host-to-device channel handle associated with `key`.
  // Allocates a new channel handle if none exists.
  Status GetHostToDeviceChannelHandle(const string& key,
                                      xla::ChannelHandle* channel);

  // Retrieves the device-to-host channel handle associated with `key`.
  // Allocates a new channel handle if none exists.
  Status GetDeviceToHostChannelHandle(const string& key,
                                      xla::ChannelHandle* channel);
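
  // Example: a sketch of pairing two computations through a shared channel
  // key (`Send`/`Recv` refer to the xla::XlaBuilder client library; the
  // wiring shown here is illustrative, not prescriptive):
  //
  //   xla::ChannelHandle channel;
  //   TF_RETURN_IF_ERROR(compiler.GetChannelHandle("transfer", &channel));
  //   // In the producing computation:  xla::Send(data, channel);
  //   // In the consuming computation:  xla::Recv(builder, shape, channel);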

  // Sets the shapes and types for the device-to-host transfer associated with
  // 'key'.
  Status SetDeviceToHostMetadata(const string& key,
                                 absl::Span<const DataType> types,
                                 absl::Span<const TensorShape> shapes);

  // Gets the shapes of the device-to-host transfer associated with 'key'.
  Status GetDeviceToHostShapes(const string& key,
                               std::vector<TensorShape>* shapes) const;

  // Sets the shapes and types for the host-to-device transfer associated with
  // 'key'.
  Status SetHostToDeviceMetadata(const string& key,
                                 absl::Span<const DataType> types,
                                 absl::Span<const TensorShape> shapes);

  // In order to avoid deadlocks from dependencies in host computations, it can
  // be necessary to enforce a partial order on the execution of HostCompute
  // Ops. In particular it may be necessary to constrain the SendToHost for one
  // HostCompute to run before blocking on the RecvAtHost for another
  // HostCompute. The compiler maintains a mapping from 'host_compute_name' to
  // handle, where the handle is an 'output' of the HostCompute Op
  // corresponding to 'host_compute_name'. Another HostCompute Op that needs to
  // be sequenced later can add the handle as an 'input' to enforce the
  // constraints. 'host_compute_name' can be any string the client wishes to
  // use to identify a given HostCompute Op as long as the names are unique
  // within the compilation. See the example sketch below the declarations.
  Status GetHostComputeControlDependency(const string& host_compute_name,
                                         xla::XlaOp* handle);
  Status SetHostComputeControlDependency(const string& host_compute_name,
                                         const xla::XlaOp& handle);
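
  // Example: a sketch of sequencing HostCompute "b" after HostCompute "a"
  // (`a_token`, `b_token`, and the op wiring are illustrative):
  //
  //   xla::XlaOp a_token;
  //   TF_RETURN_IF_ERROR(
  //       compiler.GetHostComputeControlDependency("a", &a_token));
  //   // ... add `a_token` as an input of HostCompute "b", then record
  //   // b's own control output for any later HostCompute:
  //   TF_RETURN_IF_ERROR(
  //       compiler.SetHostComputeControlDependency("b", b_token));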

  const Options& options() const { return options_; }
  xla::Client* client() const { return options_.client; }
  FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }

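  // Manages a stack of <node name, token output> mappings used to thread
  // token dependencies between side-effecting ops: a mapping is pushed for
  // each (possibly nested) CompileGraph() call and popped before it returns;
  // side-effecting ops record their token outputs with SetNodeToken() so that
  // later side-effecting ops can consume them via GetNodeToken().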
  void PushNodeTokenMapping();
  Status PopNodeTokenMapping();
  Status SetNodeToken(const string& node_name, const xla::XlaOp& op);
  xla::StatusOr<xla::XlaOp> GetNodeToken(const string& node_name);

  // Sets `*fbody` to point at the function body registered as `function`.
  Status FindFunctionBody(const NameAttrList& function,
                          const FunctionBody** fbody);

 private:
  // Returns the optimized graph object for this function body.
  std::unique_ptr<Graph> GetGraph(const FunctionBody* fbody);

  // Builds XLA computations for each of the arguments to the computation.
  // `args` are the arguments to the computation.
  Status BuildArguments(const Graph& graph,
                        const std::vector<XlaCompiler::Argument>& args,
                        bool use_tuple_arg, xla::XlaBuilder* builder,
                        XlaContext* context,
                        const std::map<int, int>& arg_cores,
                        std::vector<XlaExpression>* arg_expressions,
                        std::vector<int>* input_to_args,
                        std::vector<xla::Shape>* input_shapes,
                        bool is_entry_computation);

  // The graph compiler needs to know how to get an optimized graph from a
  // function body.
  friend class GraphCompiler;
  friend class XlaCompilerTest;

  Options options_;

  // Status set to non-OK in the constructor if initialization fails.
  Status initialization_status_;

  // Returns the next step sequence number.
  int64 NextStepId();

  // Internal sequence number for steps executed on the compilation device.
  int64 next_step_id_;

  XlaCompilationDevice* device_;  // Owned by device_mgr_
  DeviceMgr device_mgr_;

  // To avoid copying the client's function library, use a local function
  // library and runtime for functions created as part of the functionalize
  // control flow transformation.
  std::unique_ptr<FunctionLibraryDefinition> local_flib_def_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> local_pflr_;

  FunctionLibraryRuntime* local_flib_runtime_;  // owned by local_pflr_.
  FunctionLibraryRuntime* flib_runtime_;        // owned by pflr_.

  struct SignatureHash {
    uint64 operator()(
        const std::pair<string, std::vector<Argument>>& signature) const;
  };

  std::unordered_map<std::pair<string, std::vector<Argument>>,
                     CompilationResult, SignatureHash>
      cache_;

  std::unordered_map<string, xla::ChannelHandle> channels_;

  std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_sends_;
  std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_recvs_;

  std::unordered_map<string, xla::XlaOp> host_compute_control_output_;

  // Stores a <node name, token output> mapping. Side-effecting ops call
  // SetNodeToken() to record their token outputs, so that later
  // side-effecting ops can use GetNodeToken() to retrieve them as token
  // inputs.
  //
  // It's a stack because we need a mapping like this for each level of nested
  // CompileGraph() call. In CompileGraph(), we push a new mapping onto the
  // stack, and pop the mapping before returning.
  std::stack<std::map<string, xla::XlaOp>> node_token_mapping_stack_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_