/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file defines the contexts used during XLA compilation.

#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_

#include <functional>
#include <map>
#include <memory>
#include <set>
#include <vector>

#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

class XlaOpKernelContext;

// The XlaContext is the data structure that holds the state of an XLA
// compilation. It is accessible from OpKernelContexts when compiling a
// subgraph of Ops using XLA.
class XlaContext : public ResourceBase {
 public:
  // Retrieves the XlaContext of the current compilation.
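  //
  // Illustrative sketch (assumes `ctx` is the XlaOpKernelContext passed to an
  // op's Compile() method; names are for illustration only):
  //   XlaContext& tc = XlaContext::Get(ctx);
  //   xla::ComputationBuilder* b = tc.builder();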
  static XlaContext& Get(const OpKernelContext* ctx);
  static XlaContext& Get(const XlaOpKernelContext* ctx);

  // Creates a new XlaContext.
  XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder,
             bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
             const std::function<TensorShape(const TensorShape&, DataType)>*
                 variable_representation_shape_fn);

  // Virtual method defined by ResourceBase.
  string DebugString() override;

  XlaCompiler* compiler() const { return compiler_; }

  // Returns the ComputationBuilder that Ops use for compiling new
  // expressions.
  xla::ComputationBuilder* builder();

  bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; }

  const std::vector<XlaExpression>& args() const { return args_; }
  void set_args(std::vector<XlaExpression> args);

  const std::vector<XlaExpression>& retvals() { return retvals_; }

  // This is called by the Retval Op to associate a computed value
  // with a specific return value of the subgraph.
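  //
  // Illustrative sketch (hypothetical names: `index` is the _Retval index,
  // `dtype` the value's DataType, and `input` the op's input handle):
  //   XlaContext& tc = XlaContext::Get(ctx);
  //   tc.AddRetval(index, dtype, input);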
  void AddRetval(int retval_index, DataType type,
                 const xla::ComputationDataHandle& handle);

  // As for Retval, but for return values that are compile-time constants.
  Status AddConstRetval(int retval_index, DataType dtype,
                        const xla::Literal& literal);

  // Creates a resource with resource `kind` and initial value `handle`. `name`
  // is a descriptive name for use in error messages. See the `XlaResource`
  // constructor for a description of the remaining arguments.
  // Fails if the resource already exists.
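  //
  // Illustrative sketch (assumed values; `type`, `shape` and `handle` come
  // from the caller, and kVariable is one of the XlaResource kinds):
  //   XlaResource* resource;
  //   Status status = context.CreateResource(
  //       XlaResource::kVariable, /*arg_num=*/-1, /*name=*/"my_variable",
  //       type, shape, handle, /*tensor_array_size=*/-1,
  //       /*tensor_array_gradients=*/{}, &resource);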
  Status CreateResource(XlaResource::Kind kind, int arg_num, string name,
                        DataType type, TensorShape shape,
                        const xla::ComputationDataHandle& handle,
                        int64 tensor_array_size,
                        const std::set<string>& tensor_array_gradients,
                        XlaResource** resource);

  const std::vector<std::unique_ptr<XlaResource>>& resources() {
    return resources_;
  }

  // Returns the shape used to represent a variable of TF `shape` and `type`
  // in XLA.
  TensorShape VariableRepresentationShape(const TensorShape& shape,
                                          DataType type) const;

  // Get an XLA lambda to compute Max. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::Computation* GetOrCreateMax(const DataType type);

  // Get an XLA lambda to compute Min. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::Computation* GetOrCreateMin(const DataType type);

  // Get an XLA lambda to compute Add. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
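  //
  // Illustrative sketch of a typical reduction use (assumes a builder `b`, an
  // input handle `input`, and that XlaHelpers::Zero supplies a scalar zero of
  // the matching type):
  //   const xla::Computation* add = tc.GetOrCreateAdd(DT_FLOAT);
  //   auto sum = b->Reduce(input, XlaHelpers::Zero(b, DT_FLOAT), *add,
  //                        /*dimensions_to_reduce=*/{0});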
  const xla::Computation* GetOrCreateAdd(const DataType type);

  // Get an XLA lambda to compute Mul. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::Computation* GetOrCreateMul(const DataType type);

  // The name of the XlaContext resource during symbolic graph execution.
  static const char kXlaContextResourceName[];

 private:
  XlaCompiler* const compiler_;

  // The ComputationBuilder used to construct the subgraph's compiled
  // representation.
  xla::ComputationBuilder* builder_;

  // Allow ops to emit CustomCall operations for CPU.
  const bool allow_cpu_custom_calls_;

  // If true, constant return values are returned as Tensors instead of
  // run-time computation outputs.
  const bool resolve_compile_time_constants_;

  // Arguments to the TensorFlow graph, indexed by _Arg index.
  // Includes both compile-time constant arguments and runtime parameters.
  std::vector<XlaExpression> args_;

  // Return values of the TensorFlow graph, indexed by _Retval index.
  std::vector<XlaExpression> retvals_;

  // Holds ownership of resources. The resources are not ordered.
  std::vector<std::unique_ptr<XlaResource>> resources_;

  // A function that describes how variable shapes should be represented
  // in XLA. Variable values will be reshaped to this shape. Must be non-null.
  const std::function<TensorShape(const TensorShape&, DataType)>*
      variable_representation_shape_fn_;

  // Cache of prebuilt computations indexed by their type.
  using ComputationMap = std::map<DataType, xla::Computation>;

  // Finds the computation for the given type in `out` if it already exists,
  // or creates a new one with `create` and inserts it into the map. The
  // returned value is never null and is owned by the map.
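  //
  // Illustrative sketch (hypothetical; the real builder lambdas live in the
  // .cc file): GetOrCreateAdd(type) can be expressed as
  //   return LookupOrCreate(type, &add_func_, [&] {
  //     // ... build and return an xla::Computation adding two scalars ...
  //   });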
  const xla::Computation* LookupOrCreate(
      DataType type, ComputationMap* out,
      const std::function<xla::Computation()>& create);

  // Cached computation to compute Max of two elements, specialized by type.
  ComputationMap max_func_;

  // Cached computation to compute Min of two elements, specialized by type.
  ComputationMap min_func_;

  // Cached computation to compute Sum of two elements, specialized by type.
  ComputationMap add_func_;

  // Cached computation to compute Mul of two elements, specialized by type.
  ComputationMap mul_func_;

  // Cached computation to compute Sigmoid of an element, specialized by type.
  ComputationMap sigmoid_func_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaContext);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_