/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/xla_context.h"

#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

// Name under which the per-step XlaContext is registered with the device's
// resource manager; used by XlaContext::Get() to find the context again.
const char XlaContext::kXlaContextResourceName[] = "_xla_context";

// Looks up the context associated with the current step. It is stored
// in a resource container managed by the device.
43 /* static */ XlaContext& XlaContext::Get(const OpKernelContext* ctx) { 44 // When an Op kernel wants to use an XLA JIT context, the 45 // per-step context is looked up in the resource manager. The 46 // JIT will prepopulate the JITContext. 47 XlaContext* context; 48 TF_CHECK_OK(ctx->resource_manager()->Lookup( 49 ctx->step_container()->name(), kXlaContextResourceName, &context)); 50 // The resource manager handed us a fresh reference to 'context', but retains 51 // a reference itself so the context won't be freed. The resource manager will 52 // outlive the JIT compilation. 53 context->Unref(); 54 return *context; 55 } 56 57 /* static */ XlaContext& XlaContext::Get(const XlaOpKernelContext* ctx) { 58 return Get(ctx->op_kernel_context()); 59 } 60 61 void XlaContext::set_args(std::vector<XlaExpression> args) { 62 args_ = std::move(args); 63 } 64 65 XlaContext::XlaContext( 66 XlaCompiler* compiler, xla::ComputationBuilder* builder, 67 bool allow_cpu_custom_calls, bool resolve_compile_time_constants, 68 const std::function<TensorShape(const TensorShape&, DataType)>* 69 variable_representation_shape_fn) 70 : compiler_(compiler), 71 builder_(builder), 72 allow_cpu_custom_calls_(allow_cpu_custom_calls), 73 resolve_compile_time_constants_(resolve_compile_time_constants), 74 variable_representation_shape_fn_(variable_representation_shape_fn) {} 75 76 string XlaContext::DebugString() { return "TLA JIT context"; } 77 78 // This is called by the Retval Op to associate a computed value 79 // with a specific return value of the subgraph. 80 void XlaContext::AddRetval(int retval_index, DataType type, 81 const xla::ComputationDataHandle& handle) { 82 VLOG(1) << "Added retval index " << retval_index << " to XLA computation"; 83 // Add the return value to the list being built up. 
84 if (retvals_.size() <= retval_index) { 85 retvals_.resize(retval_index + 1); 86 } 87 retvals_[retval_index].set_handle(handle); 88 } 89 90 Status XlaContext::AddConstRetval(int retval_index, DataType dtype, 91 const xla::Literal& literal) { 92 VLOG(1) << "Adding retval index " << retval_index 93 << " with non-data-dependent tensor to XLA computation"; 94 if (retvals_.size() <= retval_index) { 95 retvals_.resize(retval_index + 1); 96 } 97 if (resolve_compile_time_constants_) { 98 Tensor value; 99 TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value)); 100 retvals_[retval_index].set_constant_value(std::move(value)); 101 } else { 102 retvals_[retval_index].set_handle(builder_->ConstantLiteral(literal)); 103 } 104 return Status::OK(); 105 } 106 107 xla::ComputationBuilder* XlaContext::builder() { return builder_; } 108 109 Status XlaContext::CreateResource( 110 XlaResource::Kind kind, int arg_num, string name, DataType type, 111 TensorShape shape, const xla::ComputationDataHandle& handle, 112 int64 tensor_array_size, const std::set<string>& tensor_array_gradients, 113 XlaResource** resource) { 114 resources_.emplace_back( 115 new XlaResource(kind, arg_num, std::move(name), type, std::move(shape), 116 handle, tensor_array_size, tensor_array_gradients)); 117 *resource = resources_.back().get(); 118 return Status::OK(); 119 } 120 121 TensorShape XlaContext::VariableRepresentationShape(const TensorShape& shape, 122 DataType type) const { 123 return (*variable_representation_shape_fn_)(shape, type); 124 } 125 126 const xla::Computation* XlaContext::GetOrCreateMax(const DataType type) { 127 return LookupOrCreate(type, &max_func_, [this, type] { 128 const string type_string = DataTypeString(type); 129 VLOG(1) << "Building Max() for " << type_string; 130 xla::ComputationBuilder b(builder()->client(), "max<" + type_string + ">"); 131 xla::PrimitiveType xla_type; 132 TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); 133 auto x = b.Parameter(0, 
xla::ShapeUtil::MakeShape(xla_type, {}), "x"); 134 auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); 135 b.Max(x, y); 136 return b.Build().ConsumeValueOrDie(); 137 }); 138 } 139 140 const xla::Computation* XlaContext::GetOrCreateMin(const DataType type) { 141 return LookupOrCreate(type, &min_func_, [this, type] { 142 const string type_string = DataTypeString(type); 143 VLOG(1) << "Building Min() for " << type_string; 144 xla::ComputationBuilder b(builder()->client(), "min<" + type_string + ">"); 145 xla::PrimitiveType xla_type; 146 TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); 147 auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); 148 auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); 149 b.Min(x, y); 150 return b.Build().ConsumeValueOrDie(); 151 }); 152 } 153 154 const xla::Computation* XlaContext::GetOrCreateAdd(const DataType type) { 155 return LookupOrCreate(type, &add_func_, [this, type] { 156 const string type_string = DataTypeString(type); 157 VLOG(1) << "Building Add() for " << type_string; 158 xla::ComputationBuilder b(builder()->client(), "add<" + type_string + ">"); 159 xla::PrimitiveType xla_type; 160 TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); 161 auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); 162 auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); 163 b.Add(x, y); 164 return b.Build().ConsumeValueOrDie(); 165 }); 166 } 167 168 const xla::Computation* XlaContext::GetOrCreateMul(const DataType type) { 169 return LookupOrCreate(type, &mul_func_, [this, type] { 170 const string type_string = DataTypeString(type); 171 VLOG(1) << "Building Mul() for " << type_string; 172 xla::ComputationBuilder b(builder()->client(), "mul<" + type_string + ">"); 173 xla::PrimitiveType xla_type; 174 TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); 175 auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); 176 auto y = 
b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); 177 b.Mul(x, y); 178 return b.Build().ConsumeValueOrDie(); 179 }); 180 } 181 182 const xla::Computation* XlaContext::LookupOrCreate( 183 DataType type, ComputationMap* out, 184 const std::function<xla::Computation()>& create) { 185 { 186 const auto& entry = (*out)[type]; 187 if (!entry.IsNull()) { 188 return &entry; 189 } 190 } 191 auto new_entry = create(); 192 { 193 // Somebody else might have made one concurrently. 194 auto& entry = (*out)[type]; 195 if (entry.IsNull()) { 196 entry = std::move(new_entry); 197 } 198 return &entry; 199 } 200 } 201 202 } // namespace tensorflow 203