1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // This file defines the contexts used during XLA compilation. 17 18 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_ 19 #define TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_ 20 21 #include <vector> 22 23 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" 24 #include "tensorflow/compiler/tf2xla/xla_compiler.h" 25 #include "tensorflow/compiler/xla/client/computation.h" 26 #include "tensorflow/compiler/xla/client/computation_builder.h" 27 #include "tensorflow/compiler/xla/xla_data.pb.h" 28 #include "tensorflow/core/framework/op_kernel.h" 29 #include "tensorflow/core/framework/resource_mgr.h" 30 #include "tensorflow/core/platform/macros.h" 31 32 namespace tensorflow { 33 34 class XlaOpKernelContext; 35 36 // The XlaContext is the data structure that holds the state of an XLA 37 // compilation, that is accessible from OpKernelContexts when compiling a 38 // subgraph of Ops using XLA. 39 class XlaContext : public ResourceBase { 40 public: 41 // Retrieves the XlaContext of the current compilation. 42 static XlaContext& Get(const OpKernelContext* ctx); 43 static XlaContext& Get(const XlaOpKernelContext* ctx); 44 45 // Creates a new XlaContext. 46 XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder, 47 bool allow_cpu_custom_calls, bool resolve_compile_time_constants, 48 const std::function<TensorShape(const TensorShape&, DataType)>* 49 variable_representation_shape_fn); 50 51 // Virtual method defined by ResourceBase. 52 string DebugString() override; 53 54 XlaCompiler* compiler() const { return compiler_; } 55 56 // Returns the ComputationBuilder that Ops use for compiling new 57 // expressions. 58 xla::ComputationBuilder* builder(); 59 60 bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; } 61 62 const std::vector<XlaExpression>& args() const { return args_; } 63 void set_args(std::vector<XlaExpression> args); 64 65 const std::vector<XlaExpression>& retvals() { return retvals_; } 66 67 // This is called by the Retval Op to associate a computed value 68 // with a specific return value of the subgraph. 69 void AddRetval(int retval_index, DataType type, 70 const xla::ComputationDataHandle& handle); 71 72 // As for Retval, but for return values that are compile-time constants. 73 Status AddConstRetval(int retval_index, DataType dtype, 74 const xla::Literal& literal); 75 76 // Creates a resource with resource `kind` and initial value `handle`. `name` 77 // is a descriptive name for use in error messages. See the `XlaResource` 78 // constructor for a description of the remaining arguments. 79 // Fails if the resource already exists. 80 Status CreateResource(XlaResource::Kind kind, int arg_num, string name, 81 DataType type, TensorShape shape, 82 const xla::ComputationDataHandle& handle, 83 int64 tensor_array_size, 84 const std::set<string>& tensor_array_gradients, 85 XlaResource** resource); 86 87 const std::vector<std::unique_ptr<XlaResource>>& resources() { 88 return resources_; 89 } 90 91 // Returns the XLA shape to be used to represent a variable of TF `shape` 92 // and `type`. 93 TensorShape VariableRepresentationShape(const TensorShape& shape, 94 DataType type) const; 95 96 // Get an XLA lambda to compute Max. This is cached in the 97 // XlaContext since it may be used by multiple Ops. There is a 98 // separate specialization of the computation for each DataType. 99 const xla::Computation* GetOrCreateMax(const DataType type); 100 101 // Get an XLA lambda to compute Min. This is cached in the 102 // XlaContext since it may be used by multiple Ops. There is a 103 // separate specialization of the computation for each DataType. 104 const xla::Computation* GetOrCreateMin(const DataType type); 105 106 // Get an XLA lambda to compute Add. This is cached in the 107 // XlaContext since it may be used by multiple Ops. There is a 108 // separate specialization of the computation for each DataType. 109 const xla::Computation* GetOrCreateAdd(const DataType type); 110 111 // Get an XLA lambda to compute Mul. This is cached in the 112 // XlaContext since it may be used by multiple Ops. There is a 113 // separate specialization of the computation for each DataType. 114 const xla::Computation* GetOrCreateMul(const DataType type); 115 116 // The name of the XlaContext resource during symbolic graph execution. 117 static const char kXlaContextResourceName[]; 118 119 private: 120 XlaCompiler* const compiler_; 121 122 // The ComputationBuilder used to construct the subgraph's compiled 123 // representation. 124 xla::ComputationBuilder* builder_; 125 126 // Allow ops to emit CustomCall operations for CPU. 127 const bool allow_cpu_custom_calls_; 128 129 // If true, constant return values are returned as Tensors instead of 130 // run-time computation outputs. 131 const bool resolve_compile_time_constants_; 132 133 // Arguments to the Tensorflow graph, indexed by _Arg index. 134 // Includes both compile-time constant arguments and runtime parameters. 135 std::vector<XlaExpression> args_; 136 137 // Return values of the Tensorflow graph, indexed by _Retval index. 138 std::vector<XlaExpression> retvals_; 139 140 // Holds ownership of resources. The resources are not ordered. 141 std::vector<std::unique_ptr<XlaResource>> resources_; 142 143 // A function that describes how variable shapes should be represented 144 // in XLA. Variable values will be reshaped to this shape. Must be non-null. 145 const std::function<TensorShape(const TensorShape&, DataType)>* 146 variable_representation_shape_fn_; 147 148 // Cache of prebuilt computations indexed by their type. 149 using ComputationMap = std::map<DataType, xla::Computation>; 150 151 // Finds the value for the given type in out map if it already 152 // exists or makes a new value with create function and keeps it the 153 // map. The returned value != nullptr and is owned by the map. 154 const xla::Computation* LookupOrCreate( 155 DataType type, ComputationMap* out, 156 const std::function<xla::Computation()>& create); 157 158 // Cached computation to compute Max of two elements, specialized by type. 159 ComputationMap max_func_; 160 161 // Cached computation to compute Min of two elements, specialized by type. 162 ComputationMap min_func_; 163 164 // Cached computation to compute Sum of two elements, specialized by type. 165 ComputationMap add_func_; 166 167 // Cached computation to compute Mul of two elements, specialized by type. 168 ComputationMap mul_func_; 169 170 // Cached computation to compute Sigmoid of an element, specialized by type. 171 ComputationMap sigmoid_func_; 172 173 TF_DISALLOW_COPY_AND_ASSIGN(XlaContext); 174 }; 175 176 } // namespace tensorflow 177 178 #endif // TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_ 179