1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/tf2xla/xla_compiler.h" 17 18 #include <numeric> 19 #include <vector> 20 21 #include "absl/memory/memory.h" 22 #include "tensorflow/compiler/tf2xla/graph_compiler.h" 23 #include "tensorflow/compiler/tf2xla/shape_util.h" 24 #include "tensorflow/compiler/tf2xla/sharding_util.h" 25 #include "tensorflow/compiler/tf2xla/side_effect_util.h" 26 #include "tensorflow/compiler/tf2xla/tf2xla_util.h" 27 #include "tensorflow/compiler/tf2xla/type_util.h" 28 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" 29 #include "tensorflow/compiler/tf2xla/xla_context.h" 30 #include "tensorflow/compiler/xla/client/client_library.h" 31 #include "tensorflow/compiler/xla/client/xla_builder.h" 32 #include "tensorflow/compiler/xla/client/xla_computation.h" 33 #include "tensorflow/compiler/xla/util.h" 34 #include "tensorflow/core/common_runtime/device.h" 35 #include "tensorflow/core/common_runtime/executor.h" 36 #include "tensorflow/core/common_runtime/function.h" 37 #include "tensorflow/core/common_runtime/graph_optimizer.h" 38 #include "tensorflow/core/framework/attr_value_util.h" 39 #include "tensorflow/core/framework/function.h" 40 #include "tensorflow/core/framework/node_def_util.h" 41 #include "tensorflow/core/framework/types.h" 42 #include 
"tensorflow/core/graph/algorithm.h" 43 #include "tensorflow/core/graph/graph_constructor.h" 44 #include "tensorflow/core/graph/node_builder.h" 45 #include "tensorflow/core/lib/core/error_codes.pb.h" 46 #include "tensorflow/core/lib/core/errors.h" 47 #include "tensorflow/core/lib/gtl/cleanup.h" 48 #include "tensorflow/core/lib/hash/hash.h" 49 #include "tensorflow/core/platform/logging.h" 50 #include "tensorflow/core/util/dump_graph.h" 51 52 namespace tensorflow { 53 namespace { 54 55 // Checks that arguments `args` match types `types`. 56 Status CheckSignature(const DataTypeVector& types, 57 absl::Span<const XlaCompiler::Argument> args) { 58 if (args.size() != types.size()) { 59 return errors::Internal("Compilation arguments have ", args.size(), 60 " elements while function has ", types.size()); 61 } 62 for (int i = 0; i < types.size(); ++i) { 63 // Don't perform type checks on resource variables and tensor 64 // lists (DT_VARIANT) as we have to trick the type system in order to 65 // plumb them through. DT_VARIANTS are wrapped in a DT_UINT8 tensor. 66 if (types[i] != args[i].type && types[i] != DT_RESOURCE && 67 types[i] != DT_VARIANT) { 68 return errors::Internal( 69 "Argument ", i, " has declared type ", DataTypeString(args[i].type), 70 " but function parameter has type ", DataTypeString(types[i])); 71 } 72 } 73 return Status::OK(); 74 } 75 76 // Uses the _Arg and _Retval nodes in the graph to determine a core assignment 77 // for each argument and return value. 
78 xla::StatusOr<std::pair<std::map<int, int>, std::map<int, int>>> 79 ComputeArgAndRetvalCores(const Graph& graph) { 80 auto get_sharding_for_node = [](const Node* n) -> xla::StatusOr<int> { 81 TF_ASSIGN_OR_RETURN( 82 auto sharding, 83 ParseShardingFromDevice(*n, std::numeric_limits<int32>::max())); 84 if (sharding.has_value()) { 85 TF_RET_CHECK(sharding.value().type() == 86 xla::OpSharding::Type::OpSharding_Type_MAXIMAL); 87 return sharding.value().tile_assignment_devices(0); 88 } else { 89 return -1; 90 } 91 }; 92 std::map<int, int> arg_cores; 93 std::map<int, int> retval_cores; 94 for (const Node* n : graph.nodes()) { 95 if (n->IsArg()) { 96 TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n)); 97 if (core < 0) continue; 98 int index; 99 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); 100 TF_RET_CHECK(index >= 0) << "Negative _Arg index"; 101 arg_cores[index] = core; 102 } else if (n->IsRetval()) { 103 TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n)); 104 if (core < 0) continue; 105 int index; 106 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); 107 TF_RET_CHECK(index >= 0) << "Negative _Retval index"; 108 TF_ASSIGN_OR_RETURN(retval_cores[index], get_sharding_for_node(n)); 109 retval_cores[index] = core; 110 } 111 } 112 return std::make_pair(std::move(arg_cores), std::move(retval_cores)); 113 } 114 115 Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph, 116 XlaCompilationDevice* device, FunctionLibraryRuntime* flib, 117 int64 step_id) { 118 // Resource cleanup is a bit messy. XlaContext is a ref-countd resource; the 119 // resource manager takes ownership via Create, and unrefs via Cleanup. We 120 // explicitly add a reference to ensure the refcount at entry is maintained at 121 // all exit points; Create and Cleanup are always called in this function. 122 // 123 // The Executor requires us to use ScopedStepContainer. We wrap it in a 124 // unique_ptr so we can capture the cleanup status in the end. 
125 xla_context->Ref(); 126 Status status; 127 auto step_container = absl::make_unique<ScopedStepContainer>( 128 step_id, [&status, device](const string& name) { 129 status = device->resource_manager()->Cleanup(name); 130 }); 131 TF_RETURN_IF_ERROR(device->resource_manager()->Create( 132 step_container->name(), XlaContext::kXlaContextResourceName, 133 xla_context)); 134 135 GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get()); 136 TF_RETURN_IF_ERROR(graph_compiler.Compile()); 137 // Explicitly clean up the step container, to capture the cleanup status. 138 step_container.reset(); 139 return Status::OK(); 140 } 141 142 // Builds the XLA computation. 143 // - `args` is the list of input arguments 144 // - `retvals` is the list of retvals produced by _Retval operators, in index 145 // order. 146 // - `args_core` and `retval_cores` are mapping from arg/return indices to core 147 // assignments. 148 // - If `return_updated_values_for_all_resources` is true, all resources will be 149 // included in `resource_updates`, regardless of whether their value changed. 150 // - Sets `*num_nonconst_outputs` to the number of outputs of the `computation`. 151 // - Sets `*resource_updates` to a description of resources whose values are 152 // written by the computation; the variable writes are the last 153 // - `resource_updates.size()` return values from the computation. Each entry in 154 // `resource_updates` is a ResourceUpdate, whose `index` is the index of a 155 // resource variable argument to the computation to be updated, and `type` is 156 // the type of the final output. 
157 Status BuildComputation( 158 const std::vector<XlaCompiler::Argument>& args, 159 const std::vector<XlaExpression>& retvals, 160 const std::map<int, int>& arg_cores, const std::map<int, int>& retval_cores, 161 const std::vector<std::unique_ptr<XlaResource>>& resources, 162 std::unique_ptr<xla::XlaOp> token_output, 163 const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, 164 bool return_updated_values_for_all_resources, bool always_return_tuple, 165 xla::XlaBuilder* builder, xla::XlaComputation* computation, 166 int* num_computation_outputs, int* num_nonconst_outputs, 167 std::vector<XlaCompiler::OutputDescription>* outputs, 168 std::vector<XlaCompiler::ResourceUpdate>* resource_updates, 169 xla::Shape* output_shape) { 170 // Attach a common operator name as metadata. This has no semantic effect it 171 // merely makes the HLO graph more readable when visualized via TensorBoard, 172 // since TensorBoard forms groups out of operators with similar names. 173 xla::OpMetadata retval_metadata; 174 retval_metadata.set_op_name("XLA_Retvals"); 175 builder->SetOpMetadata(retval_metadata); 176 auto cleanup = gtl::MakeCleanup([builder]() { builder->ClearOpMetadata(); }); 177 178 // Builds a no-op XLA computation. We need to set the sharding of outputs, but 179 // cannot change the sharding of the existing output op. To do this, we build 180 // a new identity op to which shardings can be applied. 181 auto identity_op = [builder](xla::XlaOp op) { 182 return xla::GetTupleElement(xla::Tuple(builder, {op}), 0); 183 }; 184 185 std::vector<xla::XlaOp> elems; 186 elems.reserve(retvals.size()); 187 188 // Keeps track of the layout of each retval. If a retval is not in this list, 189 // a descending layout is used. The first element is the output index, second 190 // element is the new layout. 
191 std::vector<std::pair<int64, xla::Layout>> retval_index_and_layout; 192 for (int i = 0; i < retvals.size(); ++i) { 193 XlaCompiler::OutputDescription& output = (*outputs)[i]; 194 const XlaExpression& retval = retvals[i]; 195 output.type = retval.dtype(); 196 switch (retval.kind()) { 197 case XlaExpression::Kind::kConstant: 198 output.is_constant = true; 199 output.constant_value = retval.constant_value(); 200 output.shape = output.constant_value.shape(); 201 break; 202 203 case XlaExpression::Kind::kTensorList: 204 TF_FALLTHROUGH_INTENDED; 205 case XlaExpression::Kind::kXlaOp: { 206 output.is_constant = false; 207 TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape()); 208 xla::XlaOp value = retval.handle(); 209 auto it = retval_cores.find(i); 210 xla::XlaScopedShardingAssignment assign_sharding( 211 builder, it == retval_cores.end() 212 ? absl::optional<xla::OpSharding>() 213 : xla::sharding_builder::AssignDevice(it->second)); 214 if (shape_representation_fn) { 215 // If there is a shape representation function, reshape the output 216 // tensor to the shape given by the representation shape function. 217 TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn( 218 output.shape, output.type)); 219 value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions())); 220 retval_index_and_layout.emplace_back(elems.size(), shape.layout()); 221 } else if (it != retval_cores.end()) { 222 // Apply the sharding to the output, if there is a core assignment. 223 value = identity_op(value); 224 } 225 226 elems.push_back(value); 227 break; 228 } 229 230 case XlaExpression::Kind::kResource: 231 output.is_constant = false; 232 output.input_index = retval.resource()->arg_num(); 233 output.shape = retval.resource()->shape(); 234 break; 235 236 case XlaExpression::Kind::kInvalid: 237 return errors::InvalidArgument( 238 "Invalid expression returned by computation. 
" 239 "This probably means a return value was not set."); 240 } 241 } 242 *num_nonconst_outputs = elems.size(); 243 244 // Add return values for resources whose values have changed. 245 std::vector<const XlaResource*> arg_resources; 246 arg_resources.reserve(resources.size()); 247 for (const auto& resource : resources) { 248 if (resource->arg_num() >= 0) { 249 arg_resources.push_back(resource.get()); 250 } 251 } 252 std::sort(arg_resources.begin(), arg_resources.end(), 253 [](const XlaResource* a, const XlaResource* b) { 254 return a->arg_num() < b->arg_num(); 255 }); 256 257 for (const XlaResource* resource : arg_resources) { 258 DCHECK_LT(resource->arg_num(), args.size()); 259 const XlaCompiler::Argument& arg = args[resource->arg_num()]; 260 auto it = arg_cores.find(resource->arg_num()); 261 const int core = it == arg_cores.end() ? -1 : it->second; 262 bool modified = !resource->value().IsIdenticalTo(resource->initial_value()); 263 // TensorArray gradients were modified if their values changed or there are 264 // any newly created gradients. 265 for (const auto& grad : resource->tensor_array_gradients()) { 266 modified = 267 modified || 268 !grad.second->value().IsIdenticalTo(grad.second->initial_value()) || 269 arg.tensor_array_gradients.count(grad.first) == 0; 270 } 271 if (return_updated_values_for_all_resources || modified) { 272 resource_updates->emplace_back(); 273 XlaCompiler::ResourceUpdate& update = resource_updates->back(); 274 update.input_index = resource->arg_num(); 275 update.type = resource->type(); 276 update.shape = resource->shape(); 277 update.modified = modified; 278 for (const auto& grad : resource->tensor_array_gradients()) { 279 update.tensor_array_gradients_accessed.insert(grad.first); 280 } 281 282 // Request that the value be returned on a specific core. 283 xla::XlaScopedShardingAssignment assign_sharding( 284 builder, core == -1 ? 
absl::optional<xla::OpSharding>() 285 : xla::sharding_builder::AssignDevice(core)); 286 287 xla::XlaOp handle; 288 TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); 289 290 // Ensures the correct sharding is applied to the output. 291 handle = identity_op(handle); 292 293 // Set layout of the retval to device representation layout. 294 if (resource->representation_shape().has_value()) { 295 retval_index_and_layout.emplace_back( 296 elems.size(), resource->representation_shape()->layout()); 297 } 298 elems.push_back(handle); 299 } 300 } 301 302 // If we have token output, append it as the last one. 303 if (token_output) { 304 elems.push_back(*token_output); 305 } 306 307 *num_computation_outputs = elems.size(); 308 309 // Builds the XLA computation. We *always* form a tuple here to ensure that 310 // the output value is the last thing added into the XLA computation, even 311 // if there is only one output value. 312 auto tuple = xla::Tuple(builder, elems); 313 if (!always_return_tuple && elems.size() == 1) { 314 xla::GetTupleElement(tuple, 0); 315 } 316 317 xla::StatusOr<xla::XlaComputation> computation_status = builder->Build(); 318 if (!computation_status.ok()) { 319 return computation_status.status(); 320 } 321 *computation = computation_status.ConsumeValueOrDie(); 322 323 TF_ASSIGN_OR_RETURN(const auto& program_shape, 324 computation->GetProgramShape()); 325 *output_shape = program_shape.result(); 326 // Update the output layout to the layout of retval. 
327 for (auto& index_and_layout : retval_index_and_layout) { 328 if (!always_return_tuple && elems.size() == 1) { 329 *output_shape->mutable_layout() = index_and_layout.second; 330 continue; 331 } 332 333 xla::Shape* output_sub_shape = xla::ShapeUtil::GetMutableSubshape( 334 output_shape, {index_and_layout.first}); 335 *output_sub_shape->mutable_layout() = index_and_layout.second; 336 } 337 return Status::OK(); 338 } 339 340 } // namespace 341 342 bool XlaCompiler::Argument::operator==( 343 const XlaCompiler::Argument& other) const { 344 if (std::tie(kind, resource_kind, type, name, initialized, max_array_size, 345 tensor_array_gradients) != 346 std::tie(other.kind, other.resource_kind, other.type, other.name, 347 other.initialized, other.max_array_size, 348 other.tensor_array_gradients)) { 349 return false; 350 } 351 if (absl::holds_alternative<xla::Shape>(shape)) { 352 if (!absl::holds_alternative<xla::Shape>(other.shape)) { 353 return false; 354 } 355 if (!xla::Shape::Equal()(absl::get<xla::Shape>(shape), 356 absl::get<xla::Shape>(other.shape))) { 357 return false; 358 } 359 } else { 360 if (!absl::holds_alternative<TensorShape>(other.shape)) { 361 return false; 362 } 363 if (absl::get<TensorShape>(shape) != absl::get<TensorShape>(other.shape)) { 364 return false; 365 } 366 } 367 if (constant_value.shape() != other.constant_value.shape()) { 368 return false; 369 } 370 return constant_value.tensor_data() == other.constant_value.tensor_data(); 371 } 372 373 string XlaCompiler::Argument::HumanString() const { 374 string common; 375 if (!name.empty()) { 376 common = absl::StrCat(" name=", name); 377 } 378 absl::StrAppend(&common, " type=", DataTypeString(type), 379 " shape=", ShapeHumanString()); 380 switch (kind) { 381 case kInvalid: 382 return "invalid"; 383 case kConstant: 384 return absl::StrCat("kind=constant", common, 385 " value=", constant_value.DebugString()); 386 case kResource: { 387 string output = absl::StrCat("kind=resource", common, " resource_kind=", 
388 XlaResource::KindToString(resource_kind), 389 " initialized=", initialized); 390 if (max_array_size >= 0) { 391 absl::StrAppend(&output, " max_array_size=", max_array_size); 392 } 393 if (!tensor_array_gradients.empty()) { 394 absl::StrAppend(&output, " tensor_array_gradients=", 395 absl::StrJoin(tensor_array_gradients, ",")); 396 } 397 return output; 398 } 399 case kParameter: 400 return absl::StrCat("kind=parameter", common); 401 case kToken: 402 return absl::StrCat("token", common); 403 } 404 } 405 406 std::vector<int64> XlaCompiler::Argument::DimensionSizes() const { 407 if (absl::holds_alternative<TensorShape>(shape)) { 408 return xla::InlinedVectorToVector( 409 absl::get<TensorShape>(shape).dim_sizes()); 410 } else { 411 return absl::get<xla::Shape>(shape).dimensions(); 412 } 413 } 414 415 string XlaCompiler::Argument::ShapeHumanString() const { 416 if (absl::holds_alternative<TensorShape>(shape)) { 417 return absl::get<TensorShape>(shape).DebugString(); 418 } else { 419 return absl::get<xla::Shape>(shape).DebugString(); 420 } 421 } 422 423 XlaCompiler::XlaCompiler(XlaCompiler::Options options) 424 : options_(options), 425 initialization_status_(Status::OK()), 426 next_step_id_(1), 427 device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)), 428 device_mgr_(absl::WrapUnique(device_)) { 429 CHECK(!options_.device_type.type_string().empty()); 430 if (options_.populate_resource_manager) { 431 initialization_status_ = 432 (*options_.populate_resource_manager)(device_->resource_manager()); 433 } 434 435 local_flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), 436 FunctionDefLibrary{})); 437 local_pflr_.reset(new ProcessFunctionLibraryRuntime( 438 &device_mgr_, Env::Default(), options.graph_def_version, 439 local_flib_def_.get(), OptimizerOptions(), 440 nullptr /* custom_kernel_creator */)); 441 pflr_.reset(new ProcessFunctionLibraryRuntime( 442 &device_mgr_, Env::Default(), options.graph_def_version, options.flib_def, 443 
OptimizerOptions(), nullptr /* custom_kernel_creator */)); 444 445 local_flib_runtime_ = local_pflr_->GetFLR(device_->name()); 446 flib_runtime_ = pflr_->GetFLR(device_->name()); 447 448 // The default shape representation function is the identity. 449 if (!options_.shape_representation_fn) { 450 options_.shape_representation_fn = 451 [](const TensorShape& shape, 452 DataType dtype) -> xla::StatusOr<xla::Shape> { 453 xla::Shape xla_shape; 454 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape)); 455 return xla_shape; 456 }; 457 } 458 } 459 460 XlaCompiler::~XlaCompiler() = default; 461 462 int64 XlaCompiler::NextStepId() { return next_step_id_++; } 463 464 uint64 XlaCompiler::SignatureHash::operator()( 465 const std::pair<string, std::vector<Argument>>& signature) const { 466 return std::hash<string>()(signature.first); 467 } 468 469 static Status GetFunctionBody(const NameAttrList& function, 470 FunctionLibraryRuntime* flib_runtime, 471 const FunctionBody** fbody) { 472 FunctionLibraryRuntime::Handle handle; 473 TF_RETURN_IF_ERROR(flib_runtime->Instantiate( 474 function.name(), AttrSlice(&function.attr()), &handle)); 475 476 *fbody = flib_runtime->GetFunctionBody(handle); 477 TF_RET_CHECK(*fbody); 478 return Status::OK(); 479 } 480 481 Status XlaCompiler::FindFunctionBody(const NameAttrList& function, 482 const FunctionBody** fbody) { 483 // The function may be in either the local_flib_runtime_ or flib_runtime_. 484 // Look up the function in local first and if it is not found then look up the 485 // function in flib_runtime_. 
486 auto status = GetFunctionBody(function, local_flib_runtime_, fbody); 487 if (!status.ok()) { 488 if (!errors::IsNotFound(status)) { 489 return status; 490 } 491 TF_RETURN_WITH_CONTEXT_IF_ERROR( 492 GetFunctionBody(function, flib_runtime_, fbody), 493 "Local lookup failed with: ", status.error_message()); 494 VLOG(4) << "Function " << function.name() << " in flib_runtime_"; 495 } else { 496 VLOG(4) << "Function " << function.name() << " in local_flib_runtime_"; 497 } 498 return Status::OK(); 499 } 500 501 std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) { 502 std::unique_ptr<Graph> graph(new Graph(options_.flib_def)); 503 CopyGraph(*fbody->graph, graph.get()); 504 OptimizerOptions opts; 505 opts.set_opt_level(OptimizerOptions::L0); 506 opts.set_do_common_subexpression_elimination(false); 507 opts.set_do_function_inlining(true); 508 opts.set_do_constant_folding(true); 509 GraphOptimizer optimizer(opts); 510 // Do not constant fold nodes that output DT_VARIANT type tensors. 511 // XLA does not support Const nodes of Variant type since it needs 512 // to know the original ops to be able to compile them to the relevant 513 // XLA form. 514 // TODO(srbs): This filter is a little conservative. E.g. a subgraph of 515 // the form: 516 // Const 517 // | 518 // EmptyTensorList -> TensorListPushBack -> TensorListPopBack -> Op 519 // | 520 // (Discard popped list) 521 // 522 // Would have been reduced to "Const -> Op" without this filter. 523 // However since we are only allowed to specify the filter at the "Node" 524 // level there is no good way to allow the above behavior. So we 525 // disallow any sort of constant folding on Variant nodes for now. 
526 auto cf_consider_fn = [](const Node* n) { 527 for (const auto& output_arg : n->op_def().output_arg()) { 528 if (output_arg.type() == DT_VARIANT) { 529 return false; 530 } 531 } 532 return true; 533 }; 534 GraphOptimizer::Options graph_optimizer_options; 535 graph_optimizer_options.cf_consider_fn = cf_consider_fn; 536 optimizer.Optimize(flib_runtime_, flib_runtime_->env(), 537 /*device=*/nullptr, &graph, graph_optimizer_options); 538 539 return graph; 540 } 541 542 Status XlaCompiler::CompileFunction( 543 const XlaCompiler::CompileOptions& options, const NameAttrList& function, 544 absl::Span<const XlaCompiler::Argument> args, 545 XlaCompiler::CompilationResult* result) { 546 const string function_id = 547 Canonicalize(function.name(), AttrSlice(&function.attr())); 548 VLOG(1) << "XlaCompiler::CompileFunction " << function_id; 549 550 const std::vector<XlaCompiler::Argument> arg_vector(args.begin(), args.end()); 551 auto it = cache_.find({function_id, arg_vector}); 552 if (it != cache_.end()) { 553 *result = it->second; 554 return Status::OK(); 555 } 556 557 const FunctionBody* fbody; 558 TF_RETURN_IF_ERROR(FindFunctionBody(function, &fbody)); 559 560 TF_RETURN_WITH_CONTEXT_IF_ERROR( 561 CheckSignature(fbody->arg_types, args), 562 "Signature check failure while compiling: ", function.name()); 563 564 std::unique_ptr<Graph> graph = GetGraph(fbody); 565 566 // Clear the "_kernel" attribute if it is set to "host". This is used to 567 // indicate that a computation should happen on the host instead of the 568 // accelerator, but doesn't make sense in XLA. 569 const char* const kKernelAttr = "_kernel"; 570 for (Node* n : graph->nodes()) { 571 string value; 572 if (GetNodeAttrSimple(n->attrs(), kKernelAttr, &value) && value == "host") { 573 n->ClearAttr(kKernelAttr); 574 } 575 } 576 577 // _Arg and _Retval nodes don't exist in the stored subgraph for the function; 578 // they are added by the function body looked up. 
Therefore, they don't have 579 // core assignments here. 580 // Attempt to assign a core to each _Retval and _Arg. Chooses the 581 // lowest-numbered core that consumes the argument. We choose the 582 // lowest-numbered core so the assignment is deterministic. 583 for (Node* n : graph->nodes()) { 584 if (n->IsArg()) { 585 TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true)); 586 } 587 } 588 // Do _Retval as a second loop, in case the retval's input is an _Arg (which 589 // may have gotten a device assignment from the first loop). 590 for (Node* n : graph->nodes()) { 591 if (n->IsRetval()) { 592 TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false)); 593 } 594 } 595 596 if (VLOG_IS_ON(2)) { 597 VLOG(2) << "XlaCompiler::CompileFunction: " 598 << DumpGraphToFile( 599 absl::StrCat("xla_compile_function_", function_id), *graph); 600 } 601 602 VLOG(1) << "===================================================="; 603 TF_RETURN_IF_ERROR( 604 CompileGraph(options, function_id, std::move(graph), args, {}, result)); 605 VLOG(1) << "===================================================="; 606 607 cache_[{function_id, arg_vector}] = *result; 608 return Status::OK(); 609 } 610 611 // Computes the XLA shape for argument 'arg'. 
612 Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, 613 bool is_entry_computation, 614 xla::Shape* xla_shape) const { 615 switch (arg.kind) { 616 case XlaCompiler::Argument::kConstant: 617 LOG(FATAL) << "Unreachable case"; 618 case XlaCompiler::Argument::kParameter: { 619 if (is_entry_computation) { 620 TensorShape shape; 621 if (absl::holds_alternative<TensorShape>(arg.shape)) { 622 shape = absl::get<TensorShape>(arg.shape); 623 } else { 624 TF_RETURN_IF_ERROR( 625 XLAShapeToTensorShape(absl::get<xla::Shape>(arg.shape), &shape)); 626 } 627 TF_ASSIGN_OR_RETURN(*xla_shape, 628 options_.shape_representation_fn(shape, arg.type)); 629 } else { 630 if (absl::holds_alternative<xla::Shape>(arg.shape)) { 631 *xla_shape = absl::get<xla::Shape>(arg.shape); 632 } else { 633 TF_RETURN_IF_ERROR(TensorShapeToXLAShape( 634 arg.type, absl::get<TensorShape>(arg.shape), xla_shape)); 635 } 636 } 637 return Status::OK(); 638 } 639 case XlaCompiler::Argument::kResource: { 640 TF_RET_CHECK(arg.initialized); 641 642 switch (arg.resource_kind) { 643 case XlaResource::kVariable: { 644 TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape)); 645 TF_ASSIGN_OR_RETURN(*xla_shape, 646 options_.shape_representation_fn( 647 absl::get<TensorShape>(arg.shape), arg.type)); 648 649 return Status::OK(); 650 } 651 case XlaResource::kTensorArray: { 652 if (arg.max_array_size < 0) { 653 return errors::InvalidArgument( 654 "Negative max_array_size in XLAShapeForArgument"); 655 } 656 TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape)); 657 TensorShape shape; 658 shape.AddDim(arg.max_array_size); 659 shape.AppendShape(absl::get<TensorShape>(arg.shape)); 660 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape)); 661 662 if (!arg.tensor_array_gradients.empty()) { 663 std::vector<xla::Shape> tuple_shape( 664 arg.tensor_array_gradients.size() + 1, *xla_shape); 665 *xla_shape = xla::ShapeUtil::MakeTupleShape(tuple_shape); 666 } 667 return Status::OK(); 
668 } 669 case XlaResource::kStack: { 670 if (arg.max_array_size < 0) { 671 return errors::InvalidArgument( 672 "Negative max_array_size in XLAShapeForArgument"); 673 } 674 TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape)); 675 TensorShape shape; 676 shape.AddDim(arg.max_array_size); 677 shape.AppendShape(absl::get<TensorShape>(arg.shape)); 678 xla::Shape buffer_shape; 679 TF_RETURN_IF_ERROR( 680 TensorShapeToXLAShape(arg.type, shape, &buffer_shape)); 681 *xla_shape = xla::ShapeUtil::MakeTupleShape( 682 {buffer_shape, xla::ShapeUtil::MakeShape(xla::S32, {})}); 683 return Status::OK(); 684 } 685 686 case XlaResource::kInvalid: 687 return errors::Internal( 688 "Invalid resource type in XLAShapeForArgument()"); 689 } 690 } 691 case XlaCompiler::Argument::kToken: { 692 *xla_shape = xla::ShapeUtil::MakeTokenShape(); 693 return Status::OK(); 694 } 695 case XlaCompiler::Argument::kInvalid: 696 return errors::Internal("Invalid argument type in XLAShapeForArgument()"); 697 } 698 } 699 700 // Builds XLA computations for each of the arguments to the computation. 701 // `args` are the arguments to the computation. 702 Status XlaCompiler::BuildArguments( 703 const Graph& graph, const std::vector<XlaCompiler::Argument>& args, 704 bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context, 705 const std::map<int, int>& arg_cores, 706 std::vector<XlaExpression>* arg_expressions, 707 std::vector<int>* input_to_args, std::vector<xla::Shape>* input_shapes, 708 bool is_entry_computation) { 709 arg_expressions->resize(args.size()); 710 711 // Argument numbers of arguments and resources that are to be passed to the 712 // XLA computation as runtime parameters. `input_to_args[a] = b` means that 713 // the a'th XLA input corresponds to the b'th original arg indexes. 714 input_to_args->clear(); 715 input_to_args->reserve(args.size()); 716 717 // Fills in constant arguments, and computes non-constant argument order. 
718 for (std::vector<XlaCompiler::Argument>::size_type i = 0; i < args.size(); 719 ++i) { 720 const XlaCompiler::Argument& arg = args[i]; 721 XlaExpression& arg_expression = (*arg_expressions)[i]; 722 switch (arg.kind) { 723 case XlaCompiler::Argument::kResource: { 724 TF_RET_CHECK(arg.resource_kind != XlaResource::kInvalid); 725 TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape)); 726 // TODO(phawkins): this code assumes that resource arguments do not 727 // alias. 728 XlaResource* resource = 729 context->AddResource(absl::make_unique<XlaResource>( 730 arg.resource_kind, i, arg.name, arg.type, 731 absl::get<TensorShape>(arg.shape), xla::XlaOp(), 732 /*max_array_size=*/arg.max_array_size, 733 /*tensor_array_gradients=*/arg.tensor_array_gradients, 734 /*tensor_array_multiple_writes_aggregate=*/true)); 735 arg_expression = XlaExpression::Resource(resource); 736 if (arg.initialized) { 737 input_to_args->push_back(i); 738 } 739 break; 740 } 741 case XlaCompiler::Argument::kParameter: 742 case XlaCompiler::Argument::kToken: { 743 input_to_args->push_back(i); 744 break; 745 } 746 case XlaCompiler::Argument::kConstant: 747 arg_expression = XlaExpression::Constant(arg.constant_value); 748 break; 749 case XlaCompiler::Argument::kInvalid: 750 return errors::Internal( 751 "Unreachable case in BuildArguments() while filling constant args"); 752 } 753 } 754 755 if (input_to_args->empty()) { 756 return Status::OK(); 757 } 758 759 // `arg_to_inputs[c] = d` means that the c'th original arg index corresponds 760 // to the d'th XLA input. Note that the value -1 corresponds to constants, or 761 // other args that don't correspond to an input. 
  // Inverse of `input_to_args`: for every TF argument index, the index of the
  // XLA input it maps to, or -1 when the argument has no XLA input.
  std::vector<int> arg_to_inputs(args.size(), -1);
  for (int i = 0; i < input_to_args->size(); i++) {
    arg_to_inputs[input_to_args->at(i)] = i;
  }

  std::vector<xla::Shape> arg_shapes(input_to_args->size());
  for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
    // Computes the shapes of non-constant arguments.
    TF_RETURN_IF_ERROR(XLAShapeForArgument(
        args[(*input_to_args)[i]], is_entry_computation, &arg_shapes[i]));
  }

  // With a tuple argument the computation has a single input whose shape is
  // the tuple of all argument shapes; otherwise there is one input per
  // argument.
  if (use_tuple_arg) {
    input_shapes->push_back(xla::ShapeUtil::MakeTupleShape(arg_shapes));
  } else {
    *input_shapes = arg_shapes;
  }

  // Attach a common operator name as metadata. This has no semantic effect; it
  // merely makes the HLO graph more readable when visualized via TensorBoard,
  // since TensorBoard forms groups out of operators with similar names.
  xla::OpMetadata arg_metadata;
  arg_metadata.set_op_name("XLA_Args");
  builder->SetOpMetadata(arg_metadata);

  // Build parameter handles for non-constant arguments.
  std::vector<xla::XlaOp> arg_handles(input_to_args->size());
  if (use_tuple_arg) {
    xla::XlaOp tuple;
    if (is_entry_computation) {
      // Build a tuple sharding with one entry per input. Inputs without an
      // entry in `arg_cores` are assigned to core 0.
      xla::OpSharding tuple_sharding;
      tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE);
      for (int64 parameter : *input_to_args) {
        auto it = arg_cores.find(parameter);
        const int core = it == arg_cores.end() ? 0 : it->second;
        *tuple_sharding.add_tuple_shardings() =
            xla::sharding_builder::AssignDevice(core);
      }
      // The scoped assignment is alive only while the tuple parameter is
      // created.
      xla::XlaScopedShardingAssignment assign_tuple_sharding(builder,
                                                             tuple_sharding);
      tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple");
    } else {
      tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple");
    }

    // All inputs live inside parameter 0, so both the dynamic-size operand and
    // its target are addressed as tuple indices of that parameter.
    // NOTE(review): `.first` is presumably the dynamic dimension and `.second`
    // the TF argument holding its size -- confirm against the declaration of
    // dynamic_dim_to_arg_num_map.
    for (int i = 0; i < input_to_args->size(); ++i) {
      const XlaCompiler::Argument& arg = args[input_to_args->at(i)];
      for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) {
        int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second);
        TF_RETURN_IF_ERROR(builder->SetDynamicBinding(
            /*dynamic_size_param_num=*/0, {dynamic_size_param_index},
            /*target_param_num=*/0, /*target_param_index=*/{i},
            dim_and_arg_num.first));
      }
    }

    // Extract each input from the tuple, under a per-input device assignment
    // when a core is known and with no sharding otherwise.
    // NOTE(review): `arg_cores` is looked up by XLA input index `i` here,
    // whereas the tuple-sharding loop above looks it up by TF argument index
    // (the values of `input_to_args`). Confirm which key `arg_cores` uses.
    for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
      auto it = arg_cores.find(i);
      const int core = it == arg_cores.end() ? -1 : it->second;
      xla::XlaScopedShardingAssignment assign_sharding(
          builder, core == -1 ? absl::optional<xla::OpSharding>()
                              : xla::sharding_builder::AssignDevice(core));
      arg_handles[i] = xla::GetTupleElement(tuple, i);
    }
  } else {
    // One XLA parameter per input, named "arg<i>".
    // NOTE(review): same `arg_cores` key question as in the tuple branch.
    for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
      auto it = arg_cores.find(i);
      const int core = it == arg_cores.end() ? -1 : it->second;
      xla::XlaScopedShardingAssignment assign_sharding(
          builder, core == -1 ? absl::optional<xla::OpSharding>()
                              : xla::sharding_builder::AssignDevice(core));
      arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i],
                                      absl::StrCat("arg", i));
    }

    // Without a tuple, the dynamic-size operand and its target are separate
    // parameters, so they are addressed by parameter number with empty tuple
    // indices.
    for (int i = 0; i < input_to_args->size(); ++i) {
      const XlaCompiler::Argument& arg = args[input_to_args->at(i)];
      for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) {
        int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second);
        TF_RETURN_IF_ERROR(builder->SetDynamicBinding(
            /*dynamic_size_param_num=*/dynamic_size_param_index, {},
            /*target_param_num=*/i, /*target_param_index=*/{},
            dim_and_arg_num.first));
      }
    }
  }

  builder->ClearOpMetadata();

  // Fill in the handles in non-constant arguments, and reshape parameters
  // back to their correct shapes.
  VLOG(2) << "XLA computation inputs:";
  for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
    const XlaCompiler::Argument& arg = args[input_to_args->at(i)];
    VLOG(2) << " XLA arg " << i
            << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i])
            << " name: " << arg.name << " TF arg " << input_to_args->at(i);
    XlaExpression& arg_expression = (*arg_expressions)[input_to_args->at(i)];
    switch (arg.kind) {
      case XlaCompiler::Argument::kResource: {
        // Resource arguments must be initialized; the resource already held by
        // the expression is populated from the parameter handle.
        TF_RET_CHECK(arg.initialized);
        XlaResource* resource = arg_expression.resource();
        TF_RETURN_IF_ERROR(resource->SetFromPack(arg.tensor_array_gradients,
                                                 arg_handles[i], builder));
        VLOG(2) << " resource: num_gradients: "
                << arg.tensor_array_gradients.size();
        break;
      }
      case XlaCompiler::Argument::kParameter:
        // Reshape parameters back to their correct shapes.
        // TODO(b/76097077): propagate device assignments onto arguments and
        // return values of functions, and then reshape unconditionally.
        if (is_entry_computation) {
          arg_expression = XlaExpression::XlaOp(
              xla::Reshape(arg_handles[i], arg.DimensionSizes()), arg.type);
        } else {
          arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
        }
        break;
      case XlaCompiler::Argument::kToken: {
        // Tokens are passed through as-is.
        arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
        break;
      }
      case XlaCompiler::Argument::kConstant:
      case XlaCompiler::Argument::kInvalid:
        // Only arguments listed in `input_to_args` are visited here; reaching
        // this case for one of them is reported as an internal error.
        return errors::Internal(
            "Unreachable case in BuildArguments() while filling handles");
    }
  }

  return Status::OK();
}

// Compiles the single op described by `node_def`.
//
// Builds a temporary graph containing that node, one _Arg node per entry of
// `args` feeding the node's inputs, and one _Retval node per entry of
// `result_types` consuming the node's outputs, then hands the graph to
// CompileGraph (named after the node, with no user aliases). The compilation
// output is written to `*result`.
//
// Returns the first error produced while constructing the graph, or the
// status returned by CompileGraph.
Status XlaCompiler::CompileSingleOp(
    const XlaCompiler::CompileOptions& options, const NodeDef& node_def,
    absl::Span<const XlaCompiler::Argument> args,
    absl::Span<const DataType> result_types, CompilationResult* result) {
  // TODO(b/74182462): We implement this by creating a new dummy Graph including
  // _Arg nodes, and let CompileGraph walk it. This could be optimized.
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  Status status;
  // First create the actual node we care about computing.
  Node* main_node = graph->AddNode(node_def, &status);
  TF_RETURN_IF_ERROR(status);

  // Create dummy _Arg nodes. Link these to `node` and also via a control
  // dependency edge to the _SOURCE node.
  for (int64 i = 0; i < args.size(); ++i) {
    Node* node;
    string arg_name = absl::StrCat("_arg", i);
    // Resource arguments are declared as DT_RESOURCE regardless of their
    // `type` field. Note that this `status` shadows the outer one.
    Status status =
        NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp)
            .ControlInput(graph->source_node())
            .Attr("T", args[i].kind == Argument::kResource ? DT_RESOURCE
                                                           : args[i].type)
            .Attr("index", i)
            .Finalize(graph.get(), &node);
    TF_RETURN_IF_ERROR(status);
    // Output 0 of the i-th _Arg node feeds input i of the op.
    graph->AddEdge(node, 0, main_node, i);
  }

  // Similarly with return values, create dummy _Retval nodes fed by `node`.
  for (int64 i = 0; i < result_types.size(); ++i) {
    Node* node;
    string retval_name = absl::StrCat("_retval", i);
    Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp)
                        .Input(main_node, i)
                        .Attr("T", result_types[i])
                        .Attr("index", i)
                        .Finalize(graph.get(), &node);
    TF_RETURN_IF_ERROR(status);
  }
  FixupSourceAndSinkEdges(graph.get());

  return CompileGraph(options, node_def.name(), std::move(graph), args, {},
                      result);
}

namespace {

// Check that the ops of all non-functional nodes have been registered.
//
// Nodes whose op is the gradient op, or whose op names a function found in
// `flib_def`, are skipped. Every other op is looked up in the global op
// registry, and the first failed lookup is returned.
Status ValidateFunctionDef(const FunctionDef* fdef,
                           const FunctionLibraryDefinition& flib_def) {
  for (const NodeDef& node : fdef->node_def()) {
    const string& op = node.op();
    if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) {
      continue;
    }
    // Only the lookup status matters; the OpDef itself is not inspected.
    const OpDef* op_def;
    TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(op, &op_def));
  }
  return Status::OK();
}

// If node is PartitionedCall or StatefulPartitionedCall, returns the
// name from the "f" attr, else returns node.def().op().
// Returned pointer points to the internal string either in node's attributes
// or in its NodeDef. This pointer is valid as long as the node has not been
// modified.
//
// Returns an error if the function attribute cannot be found, or
// InvalidArgument if the attribute has no `func` field set.
Status GetPotentialFunctionName(const Node& node, const string** name) {
  if (node.IsPartitionedCall()) {
    const AttrValue* attr_value;
    TF_RETURN_IF_ERROR(
        node.attrs().Find(FunctionLibraryDefinition::kFuncAttr, &attr_value));
    if (!attr_value->has_func()) {
      return errors::InvalidArgument(
          "The attribute value for attribute 'f' in node ", node.DebugString(),
          " does not have 'func' field set");
    }
    *name = &attr_value->func().name();
    return Status::OK();
  }
  *name = &node.type_string();
  return Status::OK();
}

// Check that the graph doesn't have any invalid nodes (e.g. incompatible with
// given device_type, invalid data type, missing attributes...)
//
// Gradient-op nodes are skipped. A node that resolves to a function in
// `flib_def` has that function's body checked with ValidateFunctionDef. Any
// other node must have a registered OpDef, a NodeDef that validates against
// it, and a kernel registered for `device_type`. `name` is the graph name
// used in error messages.
Status ValidateGraph(const Graph* graph,
                     const FunctionLibraryDefinition& flib_def,
                     const DeviceType& device_type, const string& name) {
  // Wraps a failed status into an InvalidArgument error that names the graph,
  // the device type and the offending node; passes OK through unchanged.
  auto maybe_error = [&](const Node* node, const Status& s) -> Status {
    if (!s.ok()) {
      return errors::InvalidArgument(absl::StrCat(
          "Detected unsupported operations when trying to compile graph ", name,
          " on ", device_type.type_string(), ": ", node->def().op(), " (",
          s.error_message(), ")", FormatNodeForError(*node)));
    }
    return Status::OK();
  };

  for (const Node* node : graph->nodes()) {
    if (node->type_string() == FunctionLibraryDefinition::kGradientOp) {
      continue;
    }
    const string* function_name;
    TF_RETURN_IF_ERROR(GetPotentialFunctionName(*node, &function_name));
    const FunctionDef* fdef = flib_def.Find(*function_name);
    Status s;
    if (fdef) {
      // Function call: validate the ops used by the function body instead of
      // looking for an op and kernel for the call node itself.
      s = ValidateFunctionDef(fdef, flib_def);
      TF_RETURN_IF_ERROR(maybe_error(node, s));
      continue;
    }
    const OpDef* op_def;
    s = OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def);
    TF_RETURN_IF_ERROR(maybe_error(node, s));
    // A NodeDef validation failure is returned as-is, without the wrapping
    // applied by maybe_error.
    TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def));
    s = FindKernelDef(device_type, node->def(), nullptr, nullptr);
    TF_RETURN_IF_ERROR(maybe_error(node, s));
  }
  return Status::OK();
}

// Converts the value of any expressions whose values are known at compile-time
// to constants.
1019 Status ResolveConstantExpressionsToConstants( 1020 xla::Client* client, absl::Span<XlaExpression> expressions) { 1021 for (XlaExpression& expression : expressions) { 1022 if (expression.kind() == XlaExpression::Kind::kXlaOp) { 1023 TF_ASSIGN_OR_RETURN(absl::optional<Tensor> constant, 1024 expression.ResolveConstant(client)); 1025 if (constant.has_value()) { 1026 expression = XlaExpression::Constant(*constant); 1027 } 1028 } 1029 } 1030 return Status::OK(); 1031 } 1032 1033 void ConvertConstantsToExpressions(xla::XlaBuilder* builder, 1034 absl::Span<XlaExpression> expressions) { 1035 for (XlaExpression& expression : expressions) { 1036 if (expression.kind() == XlaExpression::Kind::kConstant) { 1037 expression = 1038 XlaExpression::XlaOp(expression.AsXlaOp(builder), expression.dtype()); 1039 } 1040 } 1041 } 1042 1043 } // namespace 1044 1045 Status XlaCompiler::CompileGraph( 1046 const XlaCompiler::CompileOptions& options, string const& name, 1047 std::unique_ptr<Graph> graph, absl::Span<const XlaCompiler::Argument> args, 1048 absl::Span<const xla::XlaBuilder::InputOutputAlias> user_aliases, 1049 CompilationResult* result) { 1050 VLOG(1) << "Executing graph symbolically to populate XlaBuilder."; 1051 1052 TF_RETURN_IF_ERROR(PropagateConstIntoFunctionalNodes( 1053 graph.get(), options_.flib_def, local_flib_def_.get())); 1054 if (VLOG_IS_ON(2)) { 1055 VLOG(2) << "XlaCompiler::CompileGraph: " 1056 << DumpGraphToFile(absl::StrCat("xla_compile_graph_", name), *graph, 1057 flib_runtime_->GetFunctionLibraryDefinition()); 1058 } 1059 1060 // Report the error here if initialization failed. 1061 TF_RETURN_IF_ERROR(initialization_status_); 1062 1063 // Detect invalid nodes. 1064 // FunctionalizeControlFlow may remove some nodes from the graph. 
1065 TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def, 1066 options_.device_type, name)); 1067 1068 xla::XlaBuilder builder(name); 1069 XlaContext* context = new XlaContext(this, &builder); 1070 core::ScopedUnref context_unref(context); 1071 1072 std::vector<XlaCompiler::Argument> real_args(args.begin(), args.end()); 1073 int token_input_index = -1; 1074 std::unique_ptr<xla::XlaOp> token_output; 1075 if (options.add_token_input_output) { 1076 // Add extra token input. 1077 token_input_index = real_args.size(); 1078 1079 XlaCompiler::Argument token_arg; 1080 token_arg.kind = XlaCompiler::Argument::kToken; 1081 real_args.push_back(token_arg); 1082 } 1083 1084 std::map<int, int> arg_cores; 1085 std::map<int, int> retval_cores; 1086 TF_ASSIGN_OR_RETURN(std::tie(arg_cores, retval_cores), 1087 ComputeArgAndRetvalCores(*graph)); 1088 1089 std::vector<XlaExpression> arg_expressions; 1090 TF_RETURN_IF_ERROR(BuildArguments( 1091 *graph, real_args, options.use_tuple_arg, &builder, context, arg_cores, 1092 &arg_expressions, &result->input_mapping, &result->xla_input_shapes, 1093 options.is_entry_computation)); 1094 context->set_args(std::move(arg_expressions)); 1095 1096 // Propagate any aliases given to us by the user. 1097 for (const xla::XlaBuilder::InputOutputAlias& alias : user_aliases) { 1098 builder.SetUpAlias(alias.output_index, alias.param_number, 1099 alias.param_index); 1100 } 1101 1102 PushNodeTokenMapping(); 1103 // Use std::set instead of std::unordered_set to ensure determinism. 1104 std::set<std::string> output_node_token_inputs; 1105 if (token_input_index != -1) { 1106 // Original token comes from input. 1107 auto arg_expression = context->args()[token_input_index]; 1108 TF_RETURN_IF_ERROR( 1109 SetNodeToken(kXlaTokenArgNodeName, arg_expression.handle())); 1110 1111 // Calculate token inputs for output token. 
1112 output_node_token_inputs = CalculateTokenInputsForOutputToken(*graph); 1113 1114 // If there's no side-effecting op in the graph, use token input as token 1115 // output. 1116 if (output_node_token_inputs.empty()) { 1117 output_node_token_inputs.insert(kXlaTokenArgNodeName); 1118 } 1119 } else if (options.is_entry_computation) { 1120 // Original token is manually created. 1121 if (HasSideEffectingNodes(*graph)) { 1122 TF_RETURN_IF_ERROR( 1123 SetNodeToken(kXlaTokenArgNodeName, xla::CreateToken(&builder))); 1124 } 1125 } 1126 1127 TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_, 1128 flib_runtime_, NextStepId())); 1129 if (token_input_index != -1) { 1130 // Add extra token output. 1131 std::vector<xla::XlaOp> token_inputs; 1132 for (const auto& node_name : output_node_token_inputs) { 1133 auto token_or = GetNodeToken(node_name); 1134 TF_RETURN_IF_ERROR(token_or.status()); 1135 token_inputs.push_back(token_or.ValueOrDie()); 1136 } 1137 token_output.reset(new xla::XlaOp(xla::AfterAll(&builder, token_inputs))); 1138 } 1139 TF_RETURN_IF_ERROR(PopNodeTokenMapping()); 1140 1141 int num_nonconst_outputs; 1142 int num_computation_outputs; 1143 result->computation = std::make_shared<xla::XlaComputation>(); 1144 result->outputs.resize(context->retvals().size()); 1145 std::vector<XlaExpression> retvals = context->retvals(); 1146 if (options.resolve_compile_time_constants) { 1147 Status status = ResolveConstantExpressionsToConstants( 1148 client(), absl::Span<XlaExpression>(retvals)); 1149 1150 // If the HloEvaluator has not implemented an expression, just evaluate it 1151 // at runtime. 
1152 if (status.code() == error::UNIMPLEMENTED) { 1153 ConvertConstantsToExpressions(&builder, 1154 absl::Span<XlaExpression>(retvals)); 1155 } else { 1156 TF_RETURN_IF_ERROR(status); 1157 } 1158 } else { 1159 ConvertConstantsToExpressions(&builder, absl::Span<XlaExpression>(retvals)); 1160 } 1161 TF_RETURN_IF_ERROR(BuildComputation( 1162 real_args, retvals, arg_cores, retval_cores, context->resources(), 1163 std::move(token_output), 1164 options.is_entry_computation ? options_.shape_representation_fn 1165 : ShapeRepresentationFn{}, 1166 options.return_updated_values_for_all_resources, 1167 options.always_return_tuple, &builder, result->computation.get(), 1168 &num_computation_outputs, &num_nonconst_outputs, &result->outputs, 1169 &result->resource_updates, &result->xla_output_shape)); 1170 1171 VLOG(2) << "Outputs: total: " << context->retvals().size() 1172 << " nonconstant: " << num_nonconst_outputs; 1173 VLOG(2) << "XLA output shape: " 1174 << xla::ShapeUtil::HumanStringWithLayout(result->xla_output_shape); 1175 return Status::OK(); 1176 } 1177 1178 Status XlaCompiler::GetChannelHandle(const string& key, 1179 xla::ChannelHandle* channel) { 1180 auto result = channels_.emplace(key, xla::ChannelHandle()); 1181 if (result.second) { 1182 TF_ASSIGN_OR_RETURN(result.first->second, client()->CreateChannelHandle()); 1183 } 1184 *channel = result.first->second; 1185 VLOG(1) << "Channel: " << key << " " << channel->DebugString(); 1186 return Status::OK(); 1187 } 1188 1189 Status XlaCompiler::GetHostToDeviceChannelHandle(const string& key, 1190 xla::ChannelHandle* channel) { 1191 auto result = channels_.emplace(key, xla::ChannelHandle()); 1192 if (result.second) { 1193 TF_ASSIGN_OR_RETURN(result.first->second, 1194 client()->CreateHostToDeviceChannelHandle()); 1195 } 1196 *channel = result.first->second; 1197 VLOG(1) << "Host to device channel: " << key << " " << channel->DebugString(); 1198 return Status::OK(); 1199 } 1200 1201 Status 
XlaCompiler::GetDeviceToHostChannelHandle(const string& key, 1202 xla::ChannelHandle* channel) { 1203 auto result = channels_.emplace(key, xla::ChannelHandle()); 1204 if (result.second) { 1205 TF_ASSIGN_OR_RETURN(result.first->second, 1206 client()->CreateDeviceToHostChannelHandle()); 1207 } 1208 *channel = result.first->second; 1209 VLOG(1) << "Device to host channel: " << key << " " << channel->DebugString(); 1210 return Status::OK(); 1211 } 1212 1213 namespace { 1214 1215 void SetTransfer(const string& key, absl::Span<const DataType> types, 1216 absl::Span<const TensorShape> shapes, 1217 tf2xla::HostTransferMetadata* transfer) { 1218 transfer->set_key(key); 1219 CHECK(types.size() == shapes.size()); 1220 for (int i = 0; i < types.size(); ++i) { 1221 tf2xla::TensorMetadata* metadata = transfer->add_metadata(); 1222 metadata->set_type(types[i]); 1223 shapes[i].AsProto(metadata->mutable_shape()); 1224 } 1225 } 1226 1227 } // namespace 1228 1229 Status XlaCompiler::SetDeviceToHostMetadata( 1230 const string& key, absl::Span<const DataType> types, 1231 absl::Span<const TensorShape> shapes) { 1232 if (host_compute_sends_.find(key) != host_compute_sends_.end()) { 1233 return errors::InvalidArgument( 1234 "Duplicate calls to SetDeviceToHostMetadata with key ", key); 1235 } 1236 tf2xla::HostTransferMetadata& transfer = host_compute_sends_[key]; 1237 SetTransfer(key, types, shapes, &transfer); 1238 return Status::OK(); 1239 } 1240 1241 Status XlaCompiler::GetDeviceToHostShapes( 1242 const string& key, std::vector<TensorShape>* shapes) const { 1243 const auto iter = host_compute_sends_.find(key); 1244 if (iter == host_compute_sends_.end()) { 1245 return errors::InvalidArgument( 1246 "No host compute send shapes registered for key ", key); 1247 } 1248 shapes->clear(); 1249 for (int i = 0; i < iter->second.metadata_size(); ++i) { 1250 TensorShape shape(iter->second.metadata(i).shape()); 1251 shapes->push_back(shape); 1252 } 1253 return Status::OK(); 1254 } 1255 1256 Status 
XlaCompiler::SetHostToDeviceMetadata( 1257 const string& key, absl::Span<const DataType> types, 1258 absl::Span<const TensorShape> shapes) { 1259 if (host_compute_recvs_.find(key) != host_compute_sends_.end()) { 1260 return errors::InvalidArgument( 1261 "Duplicate calls to SetHostToDeviceMetadata with key ", key); 1262 } 1263 tf2xla::HostTransferMetadata& transfer = host_compute_recvs_[key]; 1264 SetTransfer(key, types, shapes, &transfer); 1265 return Status::OK(); 1266 } 1267 1268 Status XlaCompiler::GetHostComputeControlDependency( 1269 const string& host_compute_name, xla::XlaOp* handle) { 1270 const auto iter = host_compute_control_output_.find(host_compute_name); 1271 if (iter == host_compute_control_output_.end()) { 1272 return errors::InvalidArgument( 1273 "No registered control handle for host compute Op '", host_compute_name, 1274 "'"); 1275 } else { 1276 *handle = iter->second; 1277 } 1278 return Status::OK(); 1279 } 1280 1281 Status XlaCompiler::SetHostComputeControlDependency( 1282 const string& host_compute_name, const xla::XlaOp& handle) { 1283 if (host_compute_control_output_.find(host_compute_name) != 1284 host_compute_control_output_.end()) { 1285 return errors::InvalidArgument( 1286 "Duplicate control handles registered for for host compute Op ", 1287 host_compute_name); 1288 } 1289 host_compute_control_output_[host_compute_name] = handle; 1290 return Status::OK(); 1291 } 1292 1293 void XlaCompiler::PushNodeTokenMapping() { 1294 node_token_mapping_stack_.emplace(std::map<string, xla::XlaOp>{}); 1295 } 1296 1297 Status XlaCompiler::PopNodeTokenMapping() { 1298 if (node_token_mapping_stack_.empty()) { 1299 return errors::FailedPrecondition( 1300 "Calling PopNodeTokenMapping() when node_token_mapping_stack_ is " 1301 "empty."); 1302 } 1303 node_token_mapping_stack_.pop(); 1304 return Status::OK(); 1305 } 1306 1307 Status XlaCompiler::SetNodeToken(const string& node_name, 1308 const xla::XlaOp& op) { 1309 if (node_token_mapping_stack_.empty()) { 1310 
return errors::FailedPrecondition( 1311 "Calling SetNodeToken() when node_token_mapping_stack_ is " 1312 "empty."); 1313 } 1314 auto insert_result = node_token_mapping_stack_.top().insert({node_name, op}); 1315 if (!insert_result.second) { 1316 return errors::FailedPrecondition("Token mapping already exists for node ", 1317 node_name); 1318 } 1319 return Status::OK(); 1320 } 1321 1322 xla::StatusOr<xla::XlaOp> XlaCompiler::GetNodeToken(const string& node_name) { 1323 if (node_token_mapping_stack_.empty()) { 1324 return errors::FailedPrecondition( 1325 "Calling GetNodeToken() when node_token_mapping_stack_ is " 1326 "empty."); 1327 } 1328 auto iter = node_token_mapping_stack_.top().find(node_name); 1329 if (iter == node_token_mapping_stack_.top().end()) { 1330 return errors::FailedPrecondition("Cannot find token mapping for node ", 1331 node_name); 1332 } 1333 return iter->second; 1334 } 1335 1336 } // namespace tensorflow 1337