/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_

#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

class XlaContext;

// The XlaCompiler class is responsible for compilation of a self-contained
// subgraph of a TensorFlow computation using the XLA linear algebra runtime.
// It performs a symbolic execution of the graph starting from specific input
// shapes, using a JIT device to convert operators into XLA computations.
//
// XlaCompiler is typically invoked from an `_XlaLaunch` operator once the
// shapes of all input parameters to the computation are known. This is
// because the symbolic execution requires known shapes for all operations.
//
// XlaCompiler compiles TensorFlow graphs that receive inputs via _Arg nodes
// and return outputs via _Retval nodes.
//
// The XlaCompiler requires one Argument struct for each _Arg index,
// describing that argument. Arguments can be compile-time constants
// (kind kConstant), run-time parameters (kind kParameter), or resources
// (kind kResource).
//
// Only kParameter and initialized kResource arguments become runtime parameters
// to the generated XLA computation. The XLA computation will have run-time
// parameters in the following order:
//   +---------------------+-----------------------------------------+
//   |  kParameter values  |  Initial values of kResource arguments  |
//   +---------------------+-----------------------------------------+
// Within each block, the arguments are arranged by the _Arg index from which
// they were derived.
//
// The run-time outputs of the XLA computation are arranged in the following
// order:
//   +------------------+-----------------------------------------+
//   |  _Retval values  |  Updated values of kResource arguments  |
//   +------------------+-----------------------------------------+
// _Retval values are ordered by _Retval index, whereas kResource values are
// ordered by the original _Arg position of the variable.
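//
// For example (an illustrative sketch, not drawn from any particular graph):
// given the arguments
//   arg0: kParameter, arg1: kConstant,
//   arg2: kResource (initialized), arg3: kParameter
// the runtime parameters of the computation are (arg0, arg3, arg2): the
// kParameter values in _Arg order, followed by the initial value of the
// resource. arg1 is folded into the computation at compile time and has no
// runtime parameter. If the computation has two _Retvals and updates arg2,
// its runtime outputs are (retval0, retval1, updated value of arg2).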
//
// In both inputs and outputs, kResource values are placed at the end. When
// emitting While loop bodies, we must ensure that the loop body has
// identical input and output signatures. By moving variable values
// to the end of the argument list and using the
// `return_updated_values_for_all_resources` option, we can ensure that the
// input and output values of resources appear at the same positions.
//
// Resources are passed as parameters or returned as resource updates in
// "packed" form.
// kStack resources are packed as (array, size of stack) XLA tuples.
// kTensorArray resources without gradients are packed as the array that
// backs the TensorArray. If gradients are present (`tensor_array_gradients`),
// the packed representation is an (array, gradient0, gradient1, ...) tuple,
// where gradient_k is the value of the k-th gradient in the
// `tensor_array_gradients` ordered set.
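//
// For example (an illustrative sketch; the exact XLA shapes depend on the
// element type and `tensor_array_size`): a float Stack with a declared size
// of 10 and element shape [2] might be packed as the tuple
// (f32[10,2], s32[]), where the s32[] scalar tracks the current stack size.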
class XlaCompiler {
 public:
  // Describes how to derive the value of each _Arg node in the graph/function
  // being compiled. There must be one Argument for each _Arg index.
  struct Argument {
    enum Kind {
      // Default value; not a valid kind.
      kInvalid,

      // Argument is a compile-time constant. No associated runtime parameter.
      kConstant,

      // Argument is a Variable, TensorArray, or Stack resource. Has an
      // associated runtime parameter iff `initialized` is true.
      kResource,

      // Argument is a run-time parameter.
      kParameter,
    };

    Kind kind = kInvalid;

    // The type of the argument. If the argument is a resource, this
    // is the type of the variable's value, not DT_RESOURCE.
    DataType type;

    // The shape of the argument. For:
    // * a parameter: the shape of the parameter.
    // * a constant: ignored; the shape given by constant_value is used
    //     instead.
    // * an uninitialized resource: ignored. We don't yet know the shape of an
    //     uninitialized resource (otherwise we would have initialized it!)
    // * an initialized variable: the shape of the variable's value.
    // * an initialized TensorArray or Stack resource: the shape of an entry in
    //   the TensorArray/Stack. Note this is the shape of a single entry, not
    //   of the XLA data structure that represents the complete stack/array.
    TensorShape shape;

    // The value of the argument, if it is a compile-time constant. Must be a
    // host-memory tensor.
    Tensor constant_value;

    // The name of this argument, used for debugging.
    string name;

    // For a kResource, what kind of resource is it?
    XlaResource::Kind resource_kind = XlaResource::kInvalid;

    // For a kResource, has this resource been initialized?
    bool initialized = false;

    // For a TensorArray or Stack resource, what is the array's declared size?
    // (Used for lazy initialization.)
    int64 tensor_array_size = -1;

    // TensorArray resource parameters are passed as (array, gradient array 0,
    // ..., gradient array k), where the gradient arrays are in the same order
    // as `tensor_array_gradients`.
    std::set<string> tensor_array_gradients;

    bool operator==(const Argument& other) const;
  };
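
  // For example (a hypothetical sketch of how a caller might fill in
  // Arguments for a run-time parameter and an initialized variable; the
  // dtypes and shapes here are made up, and XlaResource::kVariable is assumed
  // to be the resource kind for Variables):
  //   XlaCompiler::Argument param;
  //   param.kind = XlaCompiler::Argument::kParameter;
  //   param.type = DT_FLOAT;
  //   param.shape = TensorShape({2, 3});
  //
  //   XlaCompiler::Argument var;
  //   var.kind = XlaCompiler::Argument::kResource;
  //   var.resource_kind = XlaResource::kVariable;
  //   var.initialized = true;
  //   var.type = DT_FLOAT;
  //   var.shape = TensorShape({10});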

  // Options pertaining to an individual call to CompileGraph() or
  // CompileFunction().
  struct CompileOptions {
    // If `use_tuple_arg` is true, a single tuple parameter will be used for all
    // arguments; if false, each argument gets its own parameter.
    bool use_tuple_arg = false;

    // If 'return_updated_values_for_all_resources' is true, then updated
    // values of all resource arguments will be included in the
    // 'resource_updates' of the computation, even if the resource was not
    // modified by the computation. Used when compiling loop bodies to ensure
    // the input and output signatures match.
    bool return_updated_values_for_all_resources = false;

    // If 'resolve_compile_time_constants' is true, then outputs of a
    // computation that are known to be compile-time constants will be returned
    // as Tensors at compile-time, rather than as run-time outputs of the
    // computation.
    bool resolve_compile_time_constants = true;

    // True when compiling the entry computation, false for subcomputations
    // (while, call, etc.)
    bool is_entry_computation = true;
  };

  struct OutputDescription {
    // Type and shape of the output.
    DataType type;
    TensorShape shape;

    // Constant output value, if known to be constant at JIT compilation time.
    // `constant_value` is a host-memory tensor.
    bool is_constant = false;
    Tensor constant_value;
  };

  // Describes a variable write side effect of the computation.
  struct ResourceUpdate {
    // Index of the input that contains the variable resource to write to.
    int input_index;

    // Type and shape of the tensor to be written back.
    // The `shape` field has the same meaning as the Argument::shape field.
    DataType type;
    TensorShape shape;

    // Was the value of the variable modified by the computation?
    // (Always true, unless `return_updated_values_for_all_resources` is true.)
    bool modified;

    // If the resource is a TensorArray, the set of gradients read or written.
    std::set<string> tensor_array_gradients_accessed;
  };

  struct CompilationResult {
    // Vector that maps from the parameters of the XLA computation to their
    // original argument positions. To handle compile-time constant inputs and
    // resources, the parameters to the XLA computation may be a subset of the
    // original arguments, and are not necessarily in the same order.
    std::vector<int> input_mapping;

    // Input shapes of the computation.
    std::vector<xla::Shape> xla_input_shapes;

    // Output shape in XLA format. The output shape is always a tuple.
    xla::Shape xla_output_shape;

    // TensorFlow shapes of outputs, together with the values of any
    // constant arguments. Vector indexed by TensorFlow _Retval number,
    // containing both constant and non-constant results.
    std::vector<OutputDescription> outputs;

    // Resources whose values were updated by the computation, ordered
    // by return value position. Resource updates follow the non-constant
    // results in the outputs of the XLA computation.
    std::vector<ResourceUpdate> resource_updates;

    // The XLA computation built from the TensorFlow subgraph.
    std::shared_ptr<xla::Computation> computation;
  };

  struct Options {
    // Name of the compilation device to use. Needs to be live only during
    // XlaCompiler's constructor.
    const DeviceType* device_type = nullptr;

    xla::Client* client = nullptr;

    // Function library in which to find function definitions. Must be non-null.
    const FunctionLibraryDefinition* flib_def = nullptr;

    // The graph def version to be compiled.
    int graph_def_version = TF_GRAPH_DEF_VERSION;

    // If 'allow_cpu_custom_calls' is true, kernels may make use of CustomCall()
    // for CPU.
    bool allow_cpu_custom_calls = false;

    // If set, variables are represented to XLA with the shape given by this
    // shape function. Variables are reshaped to this shape on write, and
    // reshaped back to their original shape on read.
    std::function<TensorShape(const TensorShape&, DataType)>
        variable_representation_shape_fn;

    // If not nullptr, populate_resource_manager is called with the
    // compilation device's resource manager when the compilation
    // device is created, and can be used to create metadata objects
    // that can be accessed by XLA op kernels.
    std::function<Status(ResourceMgr*)>* populate_resource_manager = nullptr;

    // If not nullptr, this memory allocator can be used by the compiler for
    // temporary allocations it might want to make during compilation.
    //
    // For example, the compiler may want to try out different algorithms and
    // choose the fastest one, and it might run those algorithms over buffers
    // created using this allocator.
    //
    // The compiler can function correctly without an explicit allocator given
    // here, but on some devices (notably, GPUs), TensorFlow tends to eagerly
    // allocate most or all available memory on the device, leaving none for the
    // compiler to access, unless it can use TensorFlow's allocator.
    xla::DeviceMemoryAllocator* device_allocator = nullptr;
  };

  explicit XlaCompiler(Options options);

  ~XlaCompiler();

  // Compiles the TensorFlow function given by `fn_name_attrs` into an
  // xla::Computation.
  Status CompileFunction(const CompileOptions& options,
                         const NameAttrList& fn_name_attrs,
                         std::vector<Argument> args, CompilationResult* result);

  // Compiles a tensorflow::Graph into an xla::Computation.
  // Similar to CompileFunction, but takes a Graph as input rather than a
  // function.
  Status CompileGraph(const CompileOptions& options, string const& name,
                      std::unique_ptr<Graph> graph,
                      const std::vector<Argument>& args,
                      CompilationResult* result);
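
  // Example (an illustrative sketch; graph construction and error propagation
  // are elided, and `device_type`, `client`, `flib_def`, `graph`, and `args`
  // are assumed to already exist):
  //   XlaCompiler::Options options;
  //   options.device_type = &device_type;
  //   options.client = client;
  //   options.flib_def = flib_def;
  //   XlaCompiler compiler(options);
  //
  //   XlaCompiler::CompileOptions compile_options;
  //   XlaCompiler::CompilationResult result;
  //   TF_RETURN_IF_ERROR(compiler.CompileGraph(compile_options, "my_graph",
  //                                            std::move(graph), args,
  //                                            &result));
  //   // result.computation holds the xla::Computation; result.input_mapping
  //   // records which arguments became runtime parameters.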

  // Returns the shape of the XLA parameter for an argument 'arg'.
  // See the class comment for more details about the argument passing
  // convention.
  Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape);

  // Retrieves the channel handle associated with `key`. Allocates
  // a new channel handle if none exists.
  // Channel handles can be used to communicate between different
  // computations. Computations that communicate should be compiled with the
  // same XlaCompiler.
  Status GetChannelHandle(const string& key, xla::ChannelHandle* channel);
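
  // For example (illustrative only; the Send/Recv operations that use the
  // handle are emitted by the communicating computations and are not shown):
  //   xla::ChannelHandle channel;
  //   TF_RETURN_IF_ERROR(
  //       compiler.GetChannelHandle("host_to_device", &channel));
  //   // Both endpoints must request the same key from the same XlaCompiler
  //   // so that they share a single channel.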

  const Options& options() const { return options_; }
  xla::Client* client() const { return options_.client; }
  FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }

 private:
  // Sets the function body `fbody` to the one registered as `function`.
  Status FindFunctionBody(const NameAttrList& function,
                          const FunctionBody** fbody);

  // Returns the optimized graph object in this function body.
  std::unique_ptr<Graph> GetGraph(const FunctionBody* fbody);

  // Builds the XLA representation of each of the arguments `args` to the
  // computation.
  Status BuildArguments(const Graph& graph,
                        const std::vector<XlaCompiler::Argument>& args,
                        bool use_tuple_arg, xla::ComputationBuilder* builder,
                        XlaContext* context, std::vector<int>* arg_cores,
                        std::vector<XlaExpression>* arg_expressions,
                        std::vector<int>* input_mapping,
                        std::vector<xla::Shape>* input_shapes,
                        bool is_entry_computation);

  // Graph compiler needs to know how to get an optimized graph from a function
  // body.
  friend class GraphCompiler;
  friend class XlaCompilerTest;

  Options options_;

  // Status set to non-OK in the constructor if initialization fails.
  Status initialization_status_;

  // Returns the next step sequence number.
  int64 NextStepId();

  // Internal sequence number for steps executed on the compilation device.
  int64 next_step_id_;

  XlaCompilationDevice* device_;  // Owned by device_mgr_
  DeviceMgr device_mgr_;

  // To avoid copying the client's function library, use a local function
  // library and runtime for functions created as part of the functionalize
  // control flow transformation.
  std::unique_ptr<FunctionLibraryDefinition> local_flib_def_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> local_pflr_;

  FunctionLibraryRuntime* local_flib_runtime_;  // owned by local_pflr_.
  FunctionLibraryRuntime* flib_runtime_;        // owned by pflr_.

  struct SignatureHash {
    uint64 operator()(
        const std::pair<string, std::vector<Argument>>& signature) const;
  };

  std::unordered_map<std::pair<string, std::vector<Argument>>,
                     CompilationResult, SignatureHash>
      cache_;

  std::unordered_map<string, xla::ChannelHandle> channels_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_