      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
     17 
     18 #include <numeric>
     19 #include <vector>
     20 
     21 #include "absl/memory/memory.h"
     22 #include "tensorflow/compiler/tf2xla/graph_compiler.h"
     23 #include "tensorflow/compiler/tf2xla/shape_util.h"
     24 #include "tensorflow/compiler/tf2xla/sharding_util.h"
     25 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
     26 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
     27 #include "tensorflow/compiler/tf2xla/type_util.h"
     28 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
     29 #include "tensorflow/compiler/tf2xla/xla_context.h"
     30 #include "tensorflow/compiler/xla/client/client_library.h"
     31 #include "tensorflow/compiler/xla/client/xla_builder.h"
     32 #include "tensorflow/compiler/xla/client/xla_computation.h"
     33 #include "tensorflow/compiler/xla/util.h"
     34 #include "tensorflow/core/common_runtime/device.h"
     35 #include "tensorflow/core/common_runtime/executor.h"
     36 #include "tensorflow/core/common_runtime/function.h"
     37 #include "tensorflow/core/common_runtime/graph_optimizer.h"
     38 #include "tensorflow/core/framework/attr_value_util.h"
     39 #include "tensorflow/core/framework/function.h"
     40 #include "tensorflow/core/framework/node_def_util.h"
     41 #include "tensorflow/core/framework/types.h"
     42 #include "tensorflow/core/graph/algorithm.h"
     43 #include "tensorflow/core/graph/graph_constructor.h"
     44 #include "tensorflow/core/graph/node_builder.h"
     45 #include "tensorflow/core/lib/core/error_codes.pb.h"
     46 #include "tensorflow/core/lib/core/errors.h"
     47 #include "tensorflow/core/lib/gtl/cleanup.h"
     48 #include "tensorflow/core/lib/hash/hash.h"
     49 #include "tensorflow/core/platform/logging.h"
     50 #include "tensorflow/core/util/dump_graph.h"
     51 
     52 namespace tensorflow {
     53 namespace {
     54 
     55 // Checks that arguments `args` match types `types`.
     56 Status CheckSignature(const DataTypeVector& types,
     57                       absl::Span<const XlaCompiler::Argument> args) {
     58   if (args.size() != types.size()) {
     59     return errors::Internal("Compilation arguments have ", args.size(),
     60                             " elements while function has ", types.size());
     61   }
     62   for (int i = 0; i < types.size(); ++i) {
     63     // Don't perform type checks on resource variables and tensor
     64     // lists (DT_VARIANT) as we have to trick the type system in order to
     65     // plumb them through. DT_VARIANTS are wrapped in a DT_UINT8 tensor.
     66     if (types[i] != args[i].type && types[i] != DT_RESOURCE &&
     67         types[i] != DT_VARIANT) {
     68       return errors::Internal(
     69           "Argument ", i, " has declared type ", DataTypeString(args[i].type),
     70           " but function parameter has type ", DataTypeString(types[i]));
     71     }
     72   }
     73   return Status::OK();
     74 }
     75 
     76 // Uses the _Arg and _Retval nodes in the graph to determine a core assignment
     77 // for each argument and return value.
     78 xla::StatusOr<std::pair<std::map<int, int>, std::map<int, int>>>
     79 ComputeArgAndRetvalCores(const Graph& graph) {
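             // Returns the core a node is assigned to via its MAXIMAL sharding, or -1 if
             // the node has no sharding.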
     80   auto get_sharding_for_node = [](const Node* n) -> xla::StatusOr<int> {
     81     TF_ASSIGN_OR_RETURN(
     82         auto sharding,
     83         ParseShardingFromDevice(*n, std::numeric_limits<int32>::max()));
     84     if (sharding.has_value()) {
     85       TF_RET_CHECK(sharding.value().type() ==
     86                    xla::OpSharding::Type::OpSharding_Type_MAXIMAL);
     87       return sharding.value().tile_assignment_devices(0);
     88     } else {
     89       return -1;
     90     }
     91   };
     92   std::map<int, int> arg_cores;
     93   std::map<int, int> retval_cores;
     94   for (const Node* n : graph.nodes()) {
     95     if (n->IsArg()) {
     96       TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n));
     97       if (core < 0) continue;
     98       int index;
     99       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
    100       TF_RET_CHECK(index >= 0) << "Negative _Arg index";
    101       arg_cores[index] = core;
    102     } else if (n->IsRetval()) {
    103       TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n));
    104       if (core < 0) continue;
    105       int index;
    106       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
    107       TF_RET_CHECK(index >= 0) << "Negative _Retval index";
     108       retval_cores[index] = core;
    110     }
    111   }
    112   return std::make_pair(std::move(arg_cores), std::move(retval_cores));
    113 }
    114 
    115 Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
    116                     XlaCompilationDevice* device, FunctionLibraryRuntime* flib,
    117                     int64 step_id) {
     118   // Resource cleanup is a bit messy. XlaContext is a ref-counted resource; the
    119   // resource manager takes ownership via Create, and unrefs via Cleanup.  We
    120   // explicitly add a reference to ensure the refcount at entry is maintained at
    121   // all exit points; Create and Cleanup are always called in this function.
    122   //
    123   // The Executor requires us to use ScopedStepContainer. We wrap it in a
     124   // unique_ptr so we can capture the cleanup status at the end.
    125   xla_context->Ref();
    126   Status status;
    127   auto step_container = absl::make_unique<ScopedStepContainer>(
    128       step_id, [&status, device](const string& name) {
    129         status = device->resource_manager()->Cleanup(name);
    130       });
    131   TF_RETURN_IF_ERROR(device->resource_manager()->Create(
    132       step_container->name(), XlaContext::kXlaContextResourceName,
    133       xla_context));
    134 
    135   GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get());
    136   TF_RETURN_IF_ERROR(graph_compiler.Compile());
    137   // Explicitly clean up the step container, to capture the cleanup status.
    138   step_container.reset();
    139   return Status::OK();
    140 }
    141 
    142 // Builds the XLA computation.
    143 // - `args` is the list of input arguments
    144 // - `retvals` is the list of retvals produced by _Retval operators, in index
    145 //   order.
     146 // - `arg_cores` and `retval_cores` are mappings from arg/return indices to core
    147 //   assignments.
    148 // - If `return_updated_values_for_all_resources` is true, all resources will be
    149 //   included in `resource_updates`, regardless of whether their value changed.
    150 // - Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
    151 // - Sets `*resource_updates` to a description of resources whose values are
    152 //   written by the computation; the variable writes are the last
     153 //   `resource_updates.size()` return values from the computation. Each entry in
    154 //   `resource_updates` is a ResourceUpdate, whose `index` is the index of a
    155 //   resource variable argument to the computation to be updated, and `type` is
    156 //   the type of the final output.
    157 Status BuildComputation(
    158     const std::vector<XlaCompiler::Argument>& args,
    159     const std::vector<XlaExpression>& retvals,
    160     const std::map<int, int>& arg_cores, const std::map<int, int>& retval_cores,
    161     const std::vector<std::unique_ptr<XlaResource>>& resources,
    162     std::unique_ptr<xla::XlaOp> token_output,
    163     const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
    164     bool return_updated_values_for_all_resources, bool always_return_tuple,
    165     xla::XlaBuilder* builder, xla::XlaComputation* computation,
    166     int* num_computation_outputs, int* num_nonconst_outputs,
    167     std::vector<XlaCompiler::OutputDescription>* outputs,
    168     std::vector<XlaCompiler::ResourceUpdate>* resource_updates,
    169     xla::Shape* output_shape) {
     170   // Attach a common operator name as metadata. This has no semantic effect; it
    171   // merely makes the HLO graph more readable when visualized via TensorBoard,
    172   // since TensorBoard forms groups out of operators with similar names.
    173   xla::OpMetadata retval_metadata;
    174   retval_metadata.set_op_name("XLA_Retvals");
    175   builder->SetOpMetadata(retval_metadata);
    176   auto cleanup = gtl::MakeCleanup([builder]() { builder->ClearOpMetadata(); });
    177 
    178   // Builds a no-op XLA computation. We need to set the sharding of outputs, but
    179   // cannot change the sharding of the existing output op. To do this, we build
    180   // a new identity op to which shardings can be applied.
    181   auto identity_op = [builder](xla::XlaOp op) {
    182     return xla::GetTupleElement(xla::Tuple(builder, {op}), 0);
    183   };
    184 
    185   std::vector<xla::XlaOp> elems;
    186   elems.reserve(retvals.size());
    187 
    188   // Keeps track of the layout of each retval. If a retval is not in this list,
    189   // a descending layout is used. The first element is the output index, second
    190   // element is the new layout.
    191   std::vector<std::pair<int64, xla::Layout>> retval_index_and_layout;
    192   for (int i = 0; i < retvals.size(); ++i) {
    193     XlaCompiler::OutputDescription& output = (*outputs)[i];
    194     const XlaExpression& retval = retvals[i];
    195     output.type = retval.dtype();
    196     switch (retval.kind()) {
    197       case XlaExpression::Kind::kConstant:
    198         output.is_constant = true;
    199         output.constant_value = retval.constant_value();
    200         output.shape = output.constant_value.shape();
    201         break;
    202 
    203       case XlaExpression::Kind::kTensorList:
    204         TF_FALLTHROUGH_INTENDED;
    205       case XlaExpression::Kind::kXlaOp: {
    206         output.is_constant = false;
    207         TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape());
    208         xla::XlaOp value = retval.handle();
    209         auto it = retval_cores.find(i);
    210         xla::XlaScopedShardingAssignment assign_sharding(
    211             builder, it == retval_cores.end()
    212                          ? absl::optional<xla::OpSharding>()
    213                          : xla::sharding_builder::AssignDevice(it->second));
    214         if (shape_representation_fn) {
    215           // If there is a shape representation function, reshape the output
     216           // tensor to the shape given by the shape representation function.
    217           TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn(
    218                                                     output.shape, output.type));
    219           value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions()));
    220           retval_index_and_layout.emplace_back(elems.size(), shape.layout());
    221         } else if (it != retval_cores.end()) {
    222           // Apply the sharding to the output, if there is a core assignment.
    223           value = identity_op(value);
    224         }
    225 
    226         elems.push_back(value);
    227         break;
    228       }
    229 
    230       case XlaExpression::Kind::kResource:
    231         output.is_constant = false;
    232         output.input_index = retval.resource()->arg_num();
    233         output.shape = retval.resource()->shape();
    234         break;
    235 
    236       case XlaExpression::Kind::kInvalid:
    237         return errors::InvalidArgument(
    238             "Invalid expression returned by computation. "
    239             "This probably means a return value was not set.");
    240     }
    241   }
    242   *num_nonconst_outputs = elems.size();
    243 
    244   // Add return values for resources whose values have changed.
    245   std::vector<const XlaResource*> arg_resources;
    246   arg_resources.reserve(resources.size());
    247   for (const auto& resource : resources) {
    248     if (resource->arg_num() >= 0) {
    249       arg_resources.push_back(resource.get());
    250     }
    251   }
    252   std::sort(arg_resources.begin(), arg_resources.end(),
    253             [](const XlaResource* a, const XlaResource* b) {
    254               return a->arg_num() < b->arg_num();
    255             });
    256 
    257   for (const XlaResource* resource : arg_resources) {
    258     DCHECK_LT(resource->arg_num(), args.size());
    259     const XlaCompiler::Argument& arg = args[resource->arg_num()];
    260     auto it = arg_cores.find(resource->arg_num());
    261     const int core = it == arg_cores.end() ? -1 : it->second;
    262     bool modified = !resource->value().IsIdenticalTo(resource->initial_value());
    263     // TensorArray gradients were modified if their values changed or there are
    264     // any newly created gradients.
    265     for (const auto& grad : resource->tensor_array_gradients()) {
    266       modified =
    267           modified ||
    268           !grad.second->value().IsIdenticalTo(grad.second->initial_value()) ||
    269           arg.tensor_array_gradients.count(grad.first) == 0;
    270     }
    271     if (return_updated_values_for_all_resources || modified) {
    272       resource_updates->emplace_back();
    273       XlaCompiler::ResourceUpdate& update = resource_updates->back();
    274       update.input_index = resource->arg_num();
    275       update.type = resource->type();
    276       update.shape = resource->shape();
    277       update.modified = modified;
    278       for (const auto& grad : resource->tensor_array_gradients()) {
    279         update.tensor_array_gradients_accessed.insert(grad.first);
    280       }
    281 
    282       // Request that the value be returned on a specific core.
    283       xla::XlaScopedShardingAssignment assign_sharding(
    284           builder, core == -1 ? absl::optional<xla::OpSharding>()
    285                               : xla::sharding_builder::AssignDevice(core));
    286 
    287       xla::XlaOp handle;
    288       TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
    289 
    290       // Ensures the correct sharding is applied to the output.
    291       handle = identity_op(handle);
    292 
    293       // Set layout of the retval to device representation layout.
    294       if (resource->representation_shape().has_value()) {
    295         retval_index_and_layout.emplace_back(
    296             elems.size(), resource->representation_shape()->layout());
    297       }
    298       elems.push_back(handle);
    299     }
    300   }
    301 
    302   // If we have token output, append it as the last one.
    303   if (token_output) {
    304     elems.push_back(*token_output);
    305   }
    306 
    307   *num_computation_outputs = elems.size();
    308 
    309   // Builds the XLA computation. We *always* form a tuple here to ensure that
    310   // the output value is the last thing added into the XLA computation, even
    311   // if there is only one output value.
    312   auto tuple = xla::Tuple(builder, elems);
    313   if (!always_return_tuple && elems.size() == 1) {
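             // The result of GetTupleElement is deliberately discarded: emitting this op
             // makes the single element, rather than the tuple, the last op added to the
             // builder and hence the root of the computation built below.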
    314     xla::GetTupleElement(tuple, 0);
    315   }
    316 
    317   xla::StatusOr<xla::XlaComputation> computation_status = builder->Build();
    318   if (!computation_status.ok()) {
    319     return computation_status.status();
    320   }
    321   *computation = computation_status.ConsumeValueOrDie();
    322 
    323   TF_ASSIGN_OR_RETURN(const auto& program_shape,
    324                       computation->GetProgramShape());
    325   *output_shape = program_shape.result();
    326   // Update the output layout to the layout of retval.
    327   for (auto& index_and_layout : retval_index_and_layout) {
    328     if (!always_return_tuple && elems.size() == 1) {
    329       *output_shape->mutable_layout() = index_and_layout.second;
    330       continue;
    331     }
    332 
    333     xla::Shape* output_sub_shape = xla::ShapeUtil::GetMutableSubshape(
    334         output_shape, {index_and_layout.first});
    335     *output_sub_shape->mutable_layout() = index_and_layout.second;
    336   }
    337   return Status::OK();
    338 }
    339 
    340 }  // namespace
    341 
    342 bool XlaCompiler::Argument::operator==(
    343     const XlaCompiler::Argument& other) const {
    344   if (std::tie(kind, resource_kind, type, name, initialized, max_array_size,
    345                tensor_array_gradients) !=
    346       std::tie(other.kind, other.resource_kind, other.type, other.name,
    347                other.initialized, other.max_array_size,
    348                other.tensor_array_gradients)) {
    349     return false;
    350   }
    351   if (absl::holds_alternative<xla::Shape>(shape)) {
    352     if (!absl::holds_alternative<xla::Shape>(other.shape)) {
    353       return false;
    354     }
    355     if (!xla::Shape::Equal()(absl::get<xla::Shape>(shape),
    356                              absl::get<xla::Shape>(other.shape))) {
    357       return false;
    358     }
    359   } else {
    360     if (!absl::holds_alternative<TensorShape>(other.shape)) {
    361       return false;
    362     }
    363     if (absl::get<TensorShape>(shape) != absl::get<TensorShape>(other.shape)) {
    364       return false;
    365     }
    366   }
    367   if (constant_value.shape() != other.constant_value.shape()) {
    368     return false;
    369   }
    370   return constant_value.tensor_data() == other.constant_value.tensor_data();
    371 }
    372 
    373 string XlaCompiler::Argument::HumanString() const {
    374   string common;
    375   if (!name.empty()) {
    376     common = absl::StrCat(" name=", name);
    377   }
    378   absl::StrAppend(&common, " type=", DataTypeString(type),
    379                   " shape=", ShapeHumanString());
    380   switch (kind) {
    381     case kInvalid:
    382       return "invalid";
    383     case kConstant:
    384       return absl::StrCat("kind=constant", common,
    385                           " value=", constant_value.DebugString());
    386     case kResource: {
    387       string output = absl::StrCat("kind=resource", common, " resource_kind=",
    388                                    XlaResource::KindToString(resource_kind),
    389                                    " initialized=", initialized);
    390       if (max_array_size >= 0) {
    391         absl::StrAppend(&output, " max_array_size=", max_array_size);
    392       }
    393       if (!tensor_array_gradients.empty()) {
    394         absl::StrAppend(&output, " tensor_array_gradients=",
    395                         absl::StrJoin(tensor_array_gradients, ","));
    396       }
    397       return output;
    398     }
    399     case kParameter:
    400       return absl::StrCat("kind=parameter", common);
    401     case kToken:
    402       return absl::StrCat("token", common);
    403   }
    404 }
    405 
    406 std::vector<int64> XlaCompiler::Argument::DimensionSizes() const {
    407   if (absl::holds_alternative<TensorShape>(shape)) {
    408     return xla::InlinedVectorToVector(
    409         absl::get<TensorShape>(shape).dim_sizes());
    410   } else {
    411     return absl::get<xla::Shape>(shape).dimensions();
    412   }
    413 }
    414 
    415 string XlaCompiler::Argument::ShapeHumanString() const {
    416   if (absl::holds_alternative<TensorShape>(shape)) {
    417     return absl::get<TensorShape>(shape).DebugString();
    418   } else {
    419     return absl::get<xla::Shape>(shape).DebugString();
    420   }
    421 }
    422 
    423 XlaCompiler::XlaCompiler(XlaCompiler::Options options)
    424     : options_(options),
    425       initialization_status_(Status::OK()),
    426       next_step_id_(1),
    427       device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)),
    428       device_mgr_(absl::WrapUnique(device_)) {
    429   CHECK(!options_.device_type.type_string().empty());
    430   if (options_.populate_resource_manager) {
    431     initialization_status_ =
    432         (*options_.populate_resource_manager)(device_->resource_manager());
    433   }
    434 
    435   local_flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(),
    436                                                       FunctionDefLibrary{}));
    437   local_pflr_.reset(new ProcessFunctionLibraryRuntime(
    438       &device_mgr_, Env::Default(), options.graph_def_version,
    439       local_flib_def_.get(), OptimizerOptions(),
    440       nullptr /* custom_kernel_creator */));
    441   pflr_.reset(new ProcessFunctionLibraryRuntime(
    442       &device_mgr_, Env::Default(), options.graph_def_version, options.flib_def,
    443       OptimizerOptions(), nullptr /* custom_kernel_creator */));
    444 
    445   local_flib_runtime_ = local_pflr_->GetFLR(device_->name());
    446   flib_runtime_ = pflr_->GetFLR(device_->name());
    447 
    448   // The default shape representation function is the identity.
    449   if (!options_.shape_representation_fn) {
    450     options_.shape_representation_fn =
    451         [](const TensorShape& shape,
    452            DataType dtype) -> xla::StatusOr<xla::Shape> {
    453       xla::Shape xla_shape;
    454       TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
    455       return xla_shape;
    456     };
    457   }
    458 }
    459 
    460 XlaCompiler::~XlaCompiler() = default;
    461 
    462 int64 XlaCompiler::NextStepId() { return next_step_id_++; }
    463 
    464 uint64 XlaCompiler::SignatureHash::operator()(
    465     const std::pair<string, std::vector<Argument>>& signature) const {
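           // Only the function name is hashed. Signatures that share a name but differ
           // in arguments collide here and are distinguished by the cache's equality
           // comparison on the full pair.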
    466   return std::hash<string>()(signature.first);
    467 }
    468 
    469 static Status GetFunctionBody(const NameAttrList& function,
    470                               FunctionLibraryRuntime* flib_runtime,
    471                               const FunctionBody** fbody) {
    472   FunctionLibraryRuntime::Handle handle;
    473   TF_RETURN_IF_ERROR(flib_runtime->Instantiate(
    474       function.name(), AttrSlice(&function.attr()), &handle));
    475 
    476   *fbody = flib_runtime->GetFunctionBody(handle);
    477   TF_RET_CHECK(*fbody);
    478   return Status::OK();
    479 }
    480 
    481 Status XlaCompiler::FindFunctionBody(const NameAttrList& function,
    482                                      const FunctionBody** fbody) {
    483   // The function may be in either the local_flib_runtime_ or flib_runtime_.
     484   // Look it up in local_flib_runtime_ first; if it is not found there, fall
     485   // back to flib_runtime_.
    486   auto status = GetFunctionBody(function, local_flib_runtime_, fbody);
    487   if (!status.ok()) {
    488     if (!errors::IsNotFound(status)) {
    489       return status;
    490     }
    491     TF_RETURN_WITH_CONTEXT_IF_ERROR(
    492         GetFunctionBody(function, flib_runtime_, fbody),
    493         "Local lookup failed with: ", status.error_message());
    494     VLOG(4) << "Function " << function.name() << " in flib_runtime_";
    495   } else {
    496     VLOG(4) << "Function " << function.name() << " in local_flib_runtime_";
    497   }
    498   return Status::OK();
    499 }
    500 
    501 std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
    502   std::unique_ptr<Graph> graph(new Graph(options_.flib_def));
    503   CopyGraph(*fbody->graph, graph.get());
    504   OptimizerOptions opts;
    505   opts.set_opt_level(OptimizerOptions::L0);
    506   opts.set_do_common_subexpression_elimination(false);
    507   opts.set_do_function_inlining(true);
    508   opts.set_do_constant_folding(true);
    509   GraphOptimizer optimizer(opts);
    510   // Do not constant fold nodes that output DT_VARIANT type tensors.
    511   // XLA does not support Const nodes of Variant type since it needs
    512   // to know the original ops to be able to compile them to the relevant
    513   // XLA form.
    514   // TODO(srbs): This filter is a little conservative. E.g. a subgraph of
    515   // the form:
    516   //                          Const
    517   //                            |
    518   // EmptyTensorList -> TensorListPushBack -> TensorListPopBack -> Op
    519   //                                                  |
    520   //                                        (Discard popped list)
    521   //
    522   // Would have been reduced to "Const -> Op" without this filter.
    523   // However since we are only allowed to specify the filter at the "Node"
    524   // level there is no good way to allow the above behavior. So we
    525   // disallow any sort of constant folding on Variant nodes for now.
    526   auto cf_consider_fn = [](const Node* n) {
    527     for (const auto& output_arg : n->op_def().output_arg()) {
    528       if (output_arg.type() == DT_VARIANT) {
    529         return false;
    530       }
    531     }
    532     return true;
    533   };
    534   GraphOptimizer::Options graph_optimizer_options;
    535   graph_optimizer_options.cf_consider_fn = cf_consider_fn;
    536   optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
    537                      /*device=*/nullptr, &graph, graph_optimizer_options);
    538 
    539   return graph;
    540 }
    541 
    542 Status XlaCompiler::CompileFunction(
    543     const XlaCompiler::CompileOptions& options, const NameAttrList& function,
    544     absl::Span<const XlaCompiler::Argument> args,
    545     XlaCompiler::CompilationResult* result) {
    546   const string function_id =
    547       Canonicalize(function.name(), AttrSlice(&function.attr()));
    548   VLOG(1) << "XlaCompiler::CompileFunction " << function_id;
    549 
    550   const std::vector<XlaCompiler::Argument> arg_vector(args.begin(), args.end());
    551   auto it = cache_.find({function_id, arg_vector});
    552   if (it != cache_.end()) {
    553     *result = it->second;
    554     return Status::OK();
    555   }
    556 
    557   const FunctionBody* fbody;
    558   TF_RETURN_IF_ERROR(FindFunctionBody(function, &fbody));
    559 
    560   TF_RETURN_WITH_CONTEXT_IF_ERROR(
    561       CheckSignature(fbody->arg_types, args),
    562       "Signature check failure while compiling: ", function.name());
    563 
    564   std::unique_ptr<Graph> graph = GetGraph(fbody);
    565 
    566   // Clear the "_kernel" attribute if it is set to "host". This is used to
    567   // indicate that a computation should happen on the host instead of the
    568   // accelerator, but doesn't make sense in XLA.
    569   const char* const kKernelAttr = "_kernel";
    570   for (Node* n : graph->nodes()) {
    571     string value;
    572     if (GetNodeAttrSimple(n->attrs(), kKernelAttr, &value) && value == "host") {
    573       n->ClearAttr(kKernelAttr);
    574     }
    575   }
    576 
    577   // _Arg and _Retval nodes don't exist in the stored subgraph for the function;
     578   // they are added when the function body is looked up. Therefore, they don't have
    579   // core assignments here.
    580   // Attempt to assign a core to each _Retval and _Arg. Chooses the
    581   // lowest-numbered core that consumes the argument. We choose the
    582   // lowest-numbered core so the assignment is deterministic.
    583   for (Node* n : graph->nodes()) {
    584     if (n->IsArg()) {
    585       TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true));
    586     }
    587   }
    588   // Do _Retval as a second loop, in case the retval's input is an _Arg (which
    589   // may have gotten a device assignment from the first loop).
    590   for (Node* n : graph->nodes()) {
    591     if (n->IsRetval()) {
    592       TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false));
    593     }
    594   }
    595 
    596   if (VLOG_IS_ON(2)) {
    597     VLOG(2) << "XlaCompiler::CompileFunction: "
    598             << DumpGraphToFile(
    599                    absl::StrCat("xla_compile_function_", function_id), *graph);
    600   }
    601 
    602   VLOG(1) << "====================================================";
    603   TF_RETURN_IF_ERROR(
    604       CompileGraph(options, function_id, std::move(graph), args, {}, result));
    605   VLOG(1) << "====================================================";
    606 
    607   cache_[{function_id, arg_vector}] = *result;
    608   return Status::OK();
    609 }
    610 
    611 // Computes the XLA shape for argument 'arg'.
    612 Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
    613                                         bool is_entry_computation,
    614                                         xla::Shape* xla_shape) const {
    615   switch (arg.kind) {
    616     case XlaCompiler::Argument::kConstant:
    617       LOG(FATAL) << "Unreachable case";
    618     case XlaCompiler::Argument::kParameter: {
    619       if (is_entry_computation) {
    620         TensorShape shape;
    621         if (absl::holds_alternative<TensorShape>(arg.shape)) {
    622           shape = absl::get<TensorShape>(arg.shape);
    623         } else {
    624           TF_RETURN_IF_ERROR(
    625               XLAShapeToTensorShape(absl::get<xla::Shape>(arg.shape), &shape));
    626         }
    627         TF_ASSIGN_OR_RETURN(*xla_shape,
    628                             options_.shape_representation_fn(shape, arg.type));
    629       } else {
    630         if (absl::holds_alternative<xla::Shape>(arg.shape)) {
    631           *xla_shape = absl::get<xla::Shape>(arg.shape);
    632         } else {
    633           TF_RETURN_IF_ERROR(TensorShapeToXLAShape(
    634               arg.type, absl::get<TensorShape>(arg.shape), xla_shape));
    635         }
    636       }
    637       return Status::OK();
    638     }
    639     case XlaCompiler::Argument::kResource: {
    640       TF_RET_CHECK(arg.initialized);
    641 
    642       switch (arg.resource_kind) {
    643         case XlaResource::kVariable: {
    644           TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
    645           TF_ASSIGN_OR_RETURN(*xla_shape,
    646                               options_.shape_representation_fn(
    647                                   absl::get<TensorShape>(arg.shape), arg.type));
    648 
    649           return Status::OK();
    650         }
    651         case XlaResource::kTensorArray: {
    652           if (arg.max_array_size < 0) {
    653             return errors::InvalidArgument(
    654                 "Negative max_array_size in XLAShapeForArgument");
    655           }
    656           TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
    657           TensorShape shape;
    658           shape.AddDim(arg.max_array_size);
    659           shape.AppendShape(absl::get<TensorShape>(arg.shape));
    660           TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape));
    661 
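                   // A TensorArray with gradients is packed as a tuple: the value buffer
                   // followed by one identically shaped buffer per gradient.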
    662           if (!arg.tensor_array_gradients.empty()) {
    663             std::vector<xla::Shape> tuple_shape(
    664                 arg.tensor_array_gradients.size() + 1, *xla_shape);
    665             *xla_shape = xla::ShapeUtil::MakeTupleShape(tuple_shape);
    666           }
    667           return Status::OK();
    668         }
    669         case XlaResource::kStack: {
    670           if (arg.max_array_size < 0) {
    671             return errors::InvalidArgument(
    672                 "Negative max_array_size in XLAShapeForArgument");
    673           }
    674           TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
    675           TensorShape shape;
    676           shape.AddDim(arg.max_array_size);
    677           shape.AppendShape(absl::get<TensorShape>(arg.shape));
    678           xla::Shape buffer_shape;
    679           TF_RETURN_IF_ERROR(
    680               TensorShapeToXLAShape(arg.type, shape, &buffer_shape));
    681           *xla_shape = xla::ShapeUtil::MakeTupleShape(
    682               {buffer_shape, xla::ShapeUtil::MakeShape(xla::S32, {})});
    683           return Status::OK();
    684         }
    685 
    686         case XlaResource::kInvalid:
    687           return errors::Internal(
    688               "Invalid resource type in XLAShapeForArgument()");
    689       }
    690     }
    691     case XlaCompiler::Argument::kToken: {
    692       *xla_shape = xla::ShapeUtil::MakeTokenShape();
    693       return Status::OK();
    694     }
    695     case XlaCompiler::Argument::kInvalid:
    696       return errors::Internal("Invalid argument type in XLAShapeForArgument()");
    697   }
    698 }
    699 
    700 // Builds XLA computations for each of the arguments to the computation.
    701 // `args` are the arguments to the computation.
    702 Status XlaCompiler::BuildArguments(
    703     const Graph& graph, const std::vector<XlaCompiler::Argument>& args,
    704     bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context,
    705     const std::map<int, int>& arg_cores,
    706     std::vector<XlaExpression>* arg_expressions,
    707     std::vector<int>* input_to_args, std::vector<xla::Shape>* input_shapes,
    708     bool is_entry_computation) {
    709   arg_expressions->resize(args.size());
    710 
    711   // Argument numbers of arguments and resources that are to be passed to the
    712   // XLA computation as runtime parameters. `input_to_args[a] = b` means that
    713   // the a'th XLA input corresponds to the b'th original arg indexes.
    714   input_to_args->clear();
    715   input_to_args->reserve(args.size());
    716 
    717   // Fills in constant arguments, and computes non-constant argument order.
    718   for (std::vector<XlaCompiler::Argument>::size_type i = 0; i < args.size();
    719        ++i) {
    720     const XlaCompiler::Argument& arg = args[i];
    721     XlaExpression& arg_expression = (*arg_expressions)[i];
    722     switch (arg.kind) {
    723       case XlaCompiler::Argument::kResource: {
    724         TF_RET_CHECK(arg.resource_kind != XlaResource::kInvalid);
    725         TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
    726         // TODO(phawkins): this code assumes that resource arguments do not
    727         // alias.
    728         XlaResource* resource =
    729             context->AddResource(absl::make_unique<XlaResource>(
    730                 arg.resource_kind, i, arg.name, arg.type,
    731                 absl::get<TensorShape>(arg.shape), xla::XlaOp(),
    732                 /*max_array_size=*/arg.max_array_size,
    733                 /*tensor_array_gradients=*/arg.tensor_array_gradients,
    734                 /*tensor_array_multiple_writes_aggregate=*/true));
    735         arg_expression = XlaExpression::Resource(resource);
    736         if (arg.initialized) {
    737           input_to_args->push_back(i);
    738         }
    739         break;
    740       }
    741       case XlaCompiler::Argument::kParameter:
    742       case XlaCompiler::Argument::kToken: {
    743         input_to_args->push_back(i);
    744         break;
    745       }
    746       case XlaCompiler::Argument::kConstant:
    747         arg_expression = XlaExpression::Constant(arg.constant_value);
    748         break;
    749       case XlaCompiler::Argument::kInvalid:
    750         return errors::Internal(
    751             "Unreachable case in BuildArguments() while filling constant args");
    752     }
    753   }
    754 
    755   if (input_to_args->empty()) {
    756     return Status::OK();
    757   }
    758 
    759   // `arg_to_inputs[c] = d` means that the c'th original arg index corresponds
    760   // to the d'th XLA input. Note that the value -1 corresponds to constants, or
    761   // other args that don't correspond to an input.
    762   std::vector<int> arg_to_inputs(args.size(), -1);
    763   for (int i = 0; i < input_to_args->size(); i++) {
    764     arg_to_inputs[input_to_args->at(i)] = i;
    765   }
    766 
    767   std::vector<xla::Shape> arg_shapes(input_to_args->size());
    768   for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
    769     // Computes the shapes of non-constant arguments.
    770     TF_RETURN_IF_ERROR(XLAShapeForArgument(
    771         args[(*input_to_args)[i]], is_entry_computation, &arg_shapes[i]));
    772   }
    773 
    774   if (use_tuple_arg) {
    775     input_shapes->push_back(xla::ShapeUtil::MakeTupleShape(arg_shapes));
    776   } else {
    777     *input_shapes = arg_shapes;
    778   }
    779 
     780   // Attach a common operator name as metadata. This has no semantic effect; it
    781   // merely makes the HLO graph more readable when visualized via TensorBoard,
    782   // since TensorBoard forms groups out of operators with similar names.
    783   xla::OpMetadata arg_metadata;
    784   arg_metadata.set_op_name("XLA_Args");
    785   builder->SetOpMetadata(arg_metadata);
    786 
    787   // Build parameter handles for non-constant arguments.
    788   std::vector<xla::XlaOp> arg_handles(input_to_args->size());
    789   if (use_tuple_arg) {
    790     xla::XlaOp tuple;
    791     if (is_entry_computation) {
    792       xla::OpSharding tuple_sharding;
    793       tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE);
    794       for (int64 parameter : *input_to_args) {
    795         auto it = arg_cores.find(parameter);
    796         const int core = it == arg_cores.end() ? 0 : it->second;
    797         *tuple_sharding.add_tuple_shardings() =
    798             xla::sharding_builder::AssignDevice(core);
    799       }
    800       xla::XlaScopedShardingAssignment assign_tuple_sharding(builder,
    801                                                              tuple_sharding);
    802       tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple");
    803     } else {
    804       tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple");
    805     }
    806 
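             // Bind each dynamic dimension of an argument to the tuple element
             // (recorded in dynamic_dim_to_arg_num_map) that carries its runtime size.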
    807     for (int i = 0; i < input_to_args->size(); ++i) {
    808       const XlaCompiler::Argument& arg = args[input_to_args->at(i)];
    809       for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) {
    810         int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second);
    811         TF_RETURN_IF_ERROR(builder->SetDynamicBinding(
    812             /*dynamic_size_param_num=*/0, {dynamic_size_param_index},
    813             /*target_param_num=*/0, /*target_param_index=*/{i},
    814             dim_and_arg_num.first));
    815       }
    816     }
    817 
    818     for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
    819       auto it = arg_cores.find(i);
    820       const int core = it == arg_cores.end() ? -1 : it->second;
    821       xla::XlaScopedShardingAssignment assign_sharding(
    822           builder, core == -1 ? absl::optional<xla::OpSharding>()
    823                               : xla::sharding_builder::AssignDevice(core));
    824       arg_handles[i] = xla::GetTupleElement(tuple, i);
    825     }
    826   } else {
    827     for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
    828       auto it = arg_cores.find(i);
    829       const int core = it == arg_cores.end() ? -1 : it->second;
    830       xla::XlaScopedShardingAssignment assign_sharding(
    831           builder, core == -1 ? absl::optional<xla::OpSharding>()
    832                               : xla::sharding_builder::AssignDevice(core));
    833       arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i],
    834                                       absl::StrCat("arg", i));
    835     }
    836 
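             // Same dynamic-dimension bindings as in the tuple case, except that each
             // runtime size is carried by a separate parameter rather than a tuple
             // element.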
    837     for (int i = 0; i < input_to_args->size(); ++i) {
    838       const XlaCompiler::Argument& arg = args[input_to_args->at(i)];
    839       for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) {
    840         int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second);
    841         TF_RETURN_IF_ERROR(builder->SetDynamicBinding(
    842             /*dynamic_size_param_num=*/dynamic_size_param_index, {},
    843             /*target_param_num=*/i, /*target_param_index=*/{},
    844             dim_and_arg_num.first));
    845       }
    846     }
    847   }
    848 
    849   builder->ClearOpMetadata();
    850 
    851   // Fill in the handles in non-constant arguments, and reshape parameters
    852   // back to their correct shapes.
    853   VLOG(2) << "XLA computation inputs:";
    854   for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
    855     const XlaCompiler::Argument& arg = args[input_to_args->at(i)];
    856     VLOG(2) << "  XLA arg " << i
    857             << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i])
    858             << " name: " << arg.name << " TF arg " << input_to_args->at(i);
    859     XlaExpression& arg_expression = (*arg_expressions)[input_to_args->at(i)];
    860     switch (arg.kind) {
    861       case XlaCompiler::Argument::kResource: {
    862         TF_RET_CHECK(arg.initialized);
    863         XlaResource* resource = arg_expression.resource();
    864         TF_RETURN_IF_ERROR(resource->SetFromPack(arg.tensor_array_gradients,
    865                                                  arg_handles[i], builder));
    866         VLOG(2) << "    resource: num_gradients: "
    867                 << arg.tensor_array_gradients.size();
    868         break;
    869       }
    870       case XlaCompiler::Argument::kParameter:
    871         // Reshape parameters back to their correct shapes.
    872         // TODO(b/76097077): propagate device assignments onto arguments and
    873         // return values of functions, and then reshape unconditionally.
    874         if (is_entry_computation) {
    875           arg_expression = XlaExpression::XlaOp(
    876               xla::Reshape(arg_handles[i], arg.DimensionSizes()), arg.type);
    877         } else {
    878           arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
    879         }
    880         break;
    881       case XlaCompiler::Argument::kToken: {
    882         arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
    883         break;
    884       }
    885       case XlaCompiler::Argument::kConstant:
    886       case XlaCompiler::Argument::kInvalid:
    887         return errors::Internal(
    888             "Unreachable case in BuildArguments() while filling handles");
    889     }
    890   }
    891 
    892   return Status::OK();
    893 }
    894 
    895 Status XlaCompiler::CompileSingleOp(
    896     const XlaCompiler::CompileOptions& options, const NodeDef& node_def,
    897     absl::Span<const XlaCompiler::Argument> args,
    898     absl::Span<const DataType> result_types, CompilationResult* result) {
    899   // TODO(b/74182462): We implement this by creating a new dummy Graph including
    900   // _Arg nodes, and let CompileGraph walk it. This could be optimized.
    901   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    902 
    903   Status status;
    904   // First create the actual node we care about computing.
    905   Node* main_node = graph->AddNode(node_def, &status);
    906   TF_RETURN_IF_ERROR(status);
    907 
     908   // Create dummy _Arg nodes. Link these to `main_node` and also via a control
    909   // dependency edge to the _SOURCE node.
    910   for (int64 i = 0; i < args.size(); ++i) {
    911     Node* node;
    912     string arg_name = absl::StrCat("_arg", i);
    913     Status status =
    914         NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp)
    915             .ControlInput(graph->source_node())
    916             .Attr("T", args[i].kind == Argument::kResource ? DT_RESOURCE
    917                                                            : args[i].type)
    918             .Attr("index", i)
    919             .Finalize(graph.get(), &node);
    920     TF_RETURN_IF_ERROR(status);
    921     graph->AddEdge(node, 0, main_node, i);
    922   }
    923 
     924   // Similarly with return values, create dummy _Retval nodes fed by `main_node`.
    925   for (int64 i = 0; i < result_types.size(); ++i) {
    926     Node* node;
    927     string retval_name = absl::StrCat("_retval", i);
    928     Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp)
    929                         .Input(main_node, i)
    930                         .Attr("T", result_types[i])
    931                         .Attr("index", i)
    932                         .Finalize(graph.get(), &node);
    933     TF_RETURN_IF_ERROR(status);
    934   }
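           // Add control edges so that every node is reachable from _SOURCE and
           // reaches _SINK.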
    935   FixupSourceAndSinkEdges(graph.get());
    936 
    937   return CompileGraph(options, node_def.name(), std::move(graph), args, {},
    938                       result);
    939 }
    940 
    941 namespace {
    942 
    943 // Check that the ops of all non-functional nodes have been registered.
    944 Status ValidateFunctionDef(const FunctionDef* fdef,
    945                            const FunctionLibraryDefinition& flib_def) {
    946   for (const NodeDef& node : fdef->node_def()) {
    947     const string& op = node.op();
    948     if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) {
    949       continue;
    950     }
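             // The looked-up OpDef is unused; the lookup only checks that the op is
             // registered.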
    951     const OpDef* op_def;
    952     TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(op, &op_def));
    953   }
    954   return Status::OK();
    955 }
    956 
    957 // If node is PartitionedCall or StatefulPartitionedCall, returns the
    958 // name from the "f" attr, else returns node.def().op().
    959 // Returned pointer points to the internal string either in node's attributes
    960 // or in its NodeDef. This pointer is valid as long as the node has not been
    961 // modified.
    962 Status GetPotentialFunctionName(const Node& node, const string** name) {
    963   if (node.IsPartitionedCall()) {
    964     const AttrValue* attr_value;
    965     TF_RETURN_IF_ERROR(
    966         node.attrs().Find(FunctionLibraryDefinition::kFuncAttr, &attr_value));
    967     if (!attr_value->has_func()) {
    968       return errors::InvalidArgument(
    969           "The attribute value for attribute 'f' in node ", node.DebugString(),
    970           " does not have 'func' field set");
    971     }
    972     *name = &attr_value->func().name();
    973     return Status::OK();
    974   }
    975   *name = &node.type_string();
    976   return Status::OK();
    977 }
    978 
    979 // Check that the graph doesn't have any invalid nodes (e.g. incompatible with
    980 // given device_type, invalid data type, missing attributes...)
    981 Status ValidateGraph(const Graph* graph,
    982                      const FunctionLibraryDefinition& flib_def,
    983                      const DeviceType& device_type, const string& name) {
    984   auto maybe_error = [&](const Node* node, const Status& s) -> Status {
    985     if (!s.ok()) {
    986       return errors::InvalidArgument(absl::StrCat(
    987           "Detected unsupported operations when trying to compile graph ", name,
    988           " on ", device_type.type_string(), ": ", node->def().op(), " (",
    989           s.error_message(), ")", FormatNodeForError(*node)));
    990     }
    991     return Status::OK();
    992   };
    993 
    994   for (const Node* node : graph->nodes()) {
    995     if (node->type_string() == FunctionLibraryDefinition::kGradientOp) {
    996       continue;
    997     }
    998     const string* function_name;
    999     TF_RETURN_IF_ERROR(GetPotentialFunctionName(*node, &function_name));
   1000     const FunctionDef* fdef = flib_def.Find(*function_name);
   1001     Status s;
   1002     if (fdef) {
   1003       s = ValidateFunctionDef(fdef, flib_def);
   1004       TF_RETURN_IF_ERROR(maybe_error(node, s));
   1005       continue;
   1006     }
   1007     const OpDef* op_def;
   1008     s = OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def);
   1009     TF_RETURN_IF_ERROR(maybe_error(node, s));
   1010     TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def));
   1011     s = FindKernelDef(device_type, node->def(), nullptr, nullptr);
   1012     TF_RETURN_IF_ERROR(maybe_error(node, s));
   1013   }
   1014   return Status::OK();
   1015 }
   1016 
   1017 // Converts the value of any expressions whose values are known at compile-time
   1018 // to constants.
   1019 Status ResolveConstantExpressionsToConstants(
   1020     xla::Client* client, absl::Span<XlaExpression> expressions) {
   1021   for (XlaExpression& expression : expressions) {
   1022     if (expression.kind() == XlaExpression::Kind::kXlaOp) {
   1023       TF_ASSIGN_OR_RETURN(absl::optional<Tensor> constant,
   1024                           expression.ResolveConstant(client));
   1025       if (constant.has_value()) {
   1026         expression = XlaExpression::Constant(*constant);
   1027       }
   1028     }
   1029   }
   1030   return Status::OK();
   1031 }
   1032 
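         // Rewrites constant expressions as XlaOps in `builder` so that they are
         // emitted as ordinary (non-constant) outputs of the computation.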
   1033 void ConvertConstantsToExpressions(xla::XlaBuilder* builder,
   1034                                    absl::Span<XlaExpression> expressions) {
   1035   for (XlaExpression& expression : expressions) {
   1036     if (expression.kind() == XlaExpression::Kind::kConstant) {
   1037       expression =
   1038           XlaExpression::XlaOp(expression.AsXlaOp(builder), expression.dtype());
   1039     }
   1040   }
   1041 }
   1042 
   1043 }  // namespace
   1044 
   1045 Status XlaCompiler::CompileGraph(
   1046     const XlaCompiler::CompileOptions& options, string const& name,
   1047     std::unique_ptr<Graph> graph, absl::Span<const XlaCompiler::Argument> args,
   1048     absl::Span<const xla::XlaBuilder::InputOutputAlias> user_aliases,
   1049     CompilationResult* result) {
   1050   VLOG(1) << "Executing graph symbolically to populate XlaBuilder.";
   1051 
   1052   TF_RETURN_IF_ERROR(PropagateConstIntoFunctionalNodes(
   1053       graph.get(), options_.flib_def, local_flib_def_.get()));
   1054   if (VLOG_IS_ON(2)) {
   1055     VLOG(2) << "XlaCompiler::CompileGraph: "
   1056             << DumpGraphToFile(absl::StrCat("xla_compile_graph_", name), *graph,
   1057                                flib_runtime_->GetFunctionLibraryDefinition());
   1058   }
   1059 
   1060   // Report the error here if initialization failed.
   1061   TF_RETURN_IF_ERROR(initialization_status_);
   1062 
   1063   // Detect invalid nodes.
   1064   // FunctionalizeControlFlow may remove some nodes from the graph.
   1065   TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def,
   1066                                    options_.device_type, name));
   1067 
   1068   xla::XlaBuilder builder(name);
   1069   XlaContext* context = new XlaContext(this, &builder);
   1070   core::ScopedUnref context_unref(context);
   1071 
   1072   std::vector<XlaCompiler::Argument> real_args(args.begin(), args.end());
   1073   int token_input_index = -1;
   1074   std::unique_ptr<xla::XlaOp> token_output;
   1075   if (options.add_token_input_output) {
   1076     // Add extra token input.
   1077     token_input_index = real_args.size();
   1078 
   1079     XlaCompiler::Argument token_arg;
   1080     token_arg.kind = XlaCompiler::Argument::kToken;
   1081     real_args.push_back(token_arg);
   1082   }
   1083 
   1084   std::map<int, int> arg_cores;
   1085   std::map<int, int> retval_cores;
   1086   TF_ASSIGN_OR_RETURN(std::tie(arg_cores, retval_cores),
   1087                       ComputeArgAndRetvalCores(*graph));
   1088 
   1089   std::vector<XlaExpression> arg_expressions;
   1090   TF_RETURN_IF_ERROR(BuildArguments(
   1091       *graph, real_args, options.use_tuple_arg, &builder, context, arg_cores,
   1092       &arg_expressions, &result->input_mapping, &result->xla_input_shapes,
   1093       options.is_entry_computation));
   1094   context->set_args(std::move(arg_expressions));
   1095 
   1096   // Propagate any aliases given to us by the user.
   1097   for (const xla::XlaBuilder::InputOutputAlias& alias : user_aliases) {
   1098     builder.SetUpAlias(alias.output_index, alias.param_number,
   1099                        alias.param_index);
   1100   }
   1101 
   1102   PushNodeTokenMapping();
   1103   // Use std::set instead of std::unordered_set to ensure determinism.
   1104   std::set<std::string> output_node_token_inputs;
   1105   if (token_input_index != -1) {
   1106     // Original token comes from input.
   1107     auto arg_expression = context->args()[token_input_index];
   1108     TF_RETURN_IF_ERROR(
   1109         SetNodeToken(kXlaTokenArgNodeName, arg_expression.handle()));
   1110 
   1111     // Calculate token inputs for output token.
   1112     output_node_token_inputs = CalculateTokenInputsForOutputToken(*graph);
   1113 
   1114     // If there's no side-effecting op in the graph, use token input as token
   1115     // output.
   1116     if (output_node_token_inputs.empty()) {
   1117       output_node_token_inputs.insert(kXlaTokenArgNodeName);
   1118     }
   1119   } else if (options.is_entry_computation) {
   1120     // Original token is manually created.
   1121     if (HasSideEffectingNodes(*graph)) {
   1122       TF_RETURN_IF_ERROR(
   1123           SetNodeToken(kXlaTokenArgNodeName, xla::CreateToken(&builder)));
   1124     }
   1125   }
   1126 
   1127   TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_,
   1128                                   flib_runtime_, NextStepId()));
   1129   if (token_input_index != -1) {
   1130     // Add extra token output.
   1131     std::vector<xla::XlaOp> token_inputs;
   1132     for (const auto& node_name : output_node_token_inputs) {
   1133       auto token_or = GetNodeToken(node_name);
   1134       TF_RETURN_IF_ERROR(token_or.status());
   1135       token_inputs.push_back(token_or.ValueOrDie());
   1136     }
   1137     token_output.reset(new xla::XlaOp(xla::AfterAll(&builder, token_inputs)));
   1138   }
   1139   TF_RETURN_IF_ERROR(PopNodeTokenMapping());
   1140 
   1141   int num_nonconst_outputs;
   1142   int num_computation_outputs;
   1143   result->computation = std::make_shared<xla::XlaComputation>();
   1144   result->outputs.resize(context->retvals().size());
   1145   std::vector<XlaExpression> retvals = context->retvals();
   1146   if (options.resolve_compile_time_constants) {
   1147     Status status = ResolveConstantExpressionsToConstants(
   1148         client(), absl::Span<XlaExpression>(retvals));
   1149 
   1150     // If the HloEvaluator has not implemented an expression, just evaluate it
   1151     // at runtime.
   1152     if (status.code() == error::UNIMPLEMENTED) {
   1153       ConvertConstantsToExpressions(&builder,
   1154                                     absl::Span<XlaExpression>(retvals));
   1155     } else {
   1156       TF_RETURN_IF_ERROR(status);
   1157     }
   1158   } else {
   1159     ConvertConstantsToExpressions(&builder, absl::Span<XlaExpression>(retvals));
   1160   }
   1161   TF_RETURN_IF_ERROR(BuildComputation(
   1162       real_args, retvals, arg_cores, retval_cores, context->resources(),
   1163       std::move(token_output),
   1164       options.is_entry_computation ? options_.shape_representation_fn
   1165                                    : ShapeRepresentationFn{},
   1166       options.return_updated_values_for_all_resources,
   1167       options.always_return_tuple, &builder, result->computation.get(),
   1168       &num_computation_outputs, &num_nonconst_outputs, &result->outputs,
   1169       &result->resource_updates, &result->xla_output_shape));
   1170 
   1171   VLOG(2) << "Outputs: total: " << context->retvals().size()
   1172           << " nonconstant: " << num_nonconst_outputs;
   1173   VLOG(2) << "XLA output shape: "
   1174           << xla::ShapeUtil::HumanStringWithLayout(result->xla_output_shape);
   1175   return Status::OK();
   1176 }
   1177 
   1178 Status XlaCompiler::GetChannelHandle(const string& key,
   1179                                      xla::ChannelHandle* channel) {
   1180   auto result = channels_.emplace(key, xla::ChannelHandle());
   1181   if (result.second) {
   1182     TF_ASSIGN_OR_RETURN(result.first->second, client()->CreateChannelHandle());
   1183   }
   1184   *channel = result.first->second;
   1185   VLOG(1) << "Channel: " << key << " " << channel->DebugString();
   1186   return Status::OK();
   1187 }
   1188 
   1189 Status XlaCompiler::GetHostToDeviceChannelHandle(const string& key,
   1190                                                  xla::ChannelHandle* channel) {
   1191   auto result = channels_.emplace(key, xla::ChannelHandle());
   1192   if (result.second) {
   1193     TF_ASSIGN_OR_RETURN(result.first->second,
   1194                         client()->CreateHostToDeviceChannelHandle());
   1195   }
   1196   *channel = result.first->second;
   1197   VLOG(1) << "Host to device channel: " << key << " " << channel->DebugString();
   1198   return Status::OK();
   1199 }
   1200 
   1201 Status XlaCompiler::GetDeviceToHostChannelHandle(const string& key,
   1202                                                  xla::ChannelHandle* channel) {
   1203   auto result = channels_.emplace(key, xla::ChannelHandle());
   1204   if (result.second) {
   1205     TF_ASSIGN_OR_RETURN(result.first->second,
   1206                         client()->CreateDeviceToHostChannelHandle());
   1207   }
   1208   *channel = result.first->second;
   1209   VLOG(1) << "Device to host channel: " << key << " " << channel->DebugString();
   1210   return Status::OK();
   1211 }
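        // Illustrative usage sketch (not part of the original file): each of the
        // channel getters above memoizes one xla::ChannelHandle per key, so two
        // computations that agree on a key share the same channel. Given an
        // XlaCompiler* `compiler` and a hypothetical key:
        //
        //   xla::ChannelHandle channel;
        //   TF_RETURN_IF_ERROR(
        //       compiler->GetHostToDeviceChannelHandle("example_key", &channel));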
   1212 
   1213 namespace {
   1214 
   1215 void SetTransfer(const string& key, absl::Span<const DataType> types,
   1216                  absl::Span<const TensorShape> shapes,
   1217                  tf2xla::HostTransferMetadata* transfer) {
   1218   transfer->set_key(key);
   1219   CHECK_EQ(types.size(), shapes.size());
   1220   for (int i = 0; i < types.size(); ++i) {
   1221     tf2xla::TensorMetadata* metadata = transfer->add_metadata();
   1222     metadata->set_type(types[i]);
   1223     shapes[i].AsProto(metadata->mutable_shape());
   1224   }
   1225 }
   1226 
   1227 }  // namespace
   1228 
   1229 Status XlaCompiler::SetDeviceToHostMetadata(
   1230     const string& key, absl::Span<const DataType> types,
   1231     absl::Span<const TensorShape> shapes) {
   1232   if (host_compute_sends_.find(key) != host_compute_sends_.end()) {
   1233     return errors::InvalidArgument(
   1234         "Duplicate calls to SetDeviceToHostMetadata with key ", key);
   1235   }
   1236   tf2xla::HostTransferMetadata& transfer = host_compute_sends_[key];
   1237   SetTransfer(key, types, shapes, &transfer);
   1238   return Status::OK();
   1239 }
   1240 
   1241 Status XlaCompiler::GetDeviceToHostShapes(
   1242     const string& key, std::vector<TensorShape>* shapes) const {
   1243   const auto iter = host_compute_sends_.find(key);
   1244   if (iter == host_compute_sends_.end()) {
   1245     return errors::InvalidArgument(
   1246         "No host compute send shapes registered for key ", key);
   1247   }
   1248   shapes->clear();
   1249   for (int i = 0; i < iter->second.metadata_size(); ++i) {
   1250     TensorShape shape(iter->second.metadata(i).shape());
   1251     shapes->push_back(shape);
   1252   }
   1253   return Status::OK();
   1254 }
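        // Illustrative usage sketch (not part of the original file), given an
        // XlaCompiler* `compiler` and a hypothetical key: metadata registered via
        // SetDeviceToHostMetadata can later be read back as shapes:
        //
        //   TF_RETURN_IF_ERROR(compiler->SetDeviceToHostMetadata(
        //       "host_compute_0", {DT_FLOAT}, {TensorShape({2, 2})}));
        //   std::vector<TensorShape> shapes;
        //   TF_RETURN_IF_ERROR(
        //       compiler->GetDeviceToHostShapes("host_compute_0", &shapes));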
   1255 
   1256 Status XlaCompiler::SetHostToDeviceMetadata(
   1257     const string& key, absl::Span<const DataType> types,
   1258     absl::Span<const TensorShape> shapes) {
   1259   if (host_compute_recvs_.find(key) != host_compute_recvs_.end()) {
   1260     return errors::InvalidArgument(
   1261         "Duplicate calls to SetHostToDeviceMetadata with key ", key);
   1262   }
   1263   tf2xla::HostTransferMetadata& transfer = host_compute_recvs_[key];
   1264   SetTransfer(key, types, shapes, &transfer);
   1265   return Status::OK();
   1266 }
   1267 
   1268 Status XlaCompiler::GetHostComputeControlDependency(
   1269     const string& host_compute_name, xla::XlaOp* handle) {
   1270   const auto iter = host_compute_control_output_.find(host_compute_name);
   1271   if (iter == host_compute_control_output_.end()) {
   1272     return errors::InvalidArgument(
   1273         "No registered control handle for host compute Op '", host_compute_name,
   1274         "'");
   1275   } else {
   1276     *handle = iter->second;
   1277   }
   1278   return Status::OK();
   1279 }
   1280 
   1281 Status XlaCompiler::SetHostComputeControlDependency(
   1282     const string& host_compute_name, const xla::XlaOp& handle) {
   1283   if (host_compute_control_output_.find(host_compute_name) !=
   1284       host_compute_control_output_.end()) {
   1285     return errors::InvalidArgument(
   1286         "Duplicate control handles registered for for host compute Op ",
   1287         host_compute_name);
   1288   }
   1289   host_compute_control_output_[host_compute_name] = handle;
   1290   return Status::OK();
   1291 }
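        // Illustrative usage sketch (not part of the original file), given an
        // XlaCompiler* `compiler`, a hypothetical op name, and a placeholder
        // xla::XlaOp `send_op`: the pair above lets the code lowering one host
        // compute op publish a control handle that later ops can depend on:
        //
        //   TF_RETURN_IF_ERROR(compiler->SetHostComputeControlDependency(
        //       "host_compute_0", send_op));
        //   xla::XlaOp control;
        //   TF_RETURN_IF_ERROR(compiler->GetHostComputeControlDependency(
        //       "host_compute_0", &control));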
   1292 
   1293 void XlaCompiler::PushNodeTokenMapping() {
   1294   node_token_mapping_stack_.emplace(std::map<string, xla::XlaOp>{});
   1295 }
   1296 
   1297 Status XlaCompiler::PopNodeTokenMapping() {
   1298   if (node_token_mapping_stack_.empty()) {
   1299     return errors::FailedPrecondition(
   1300         "Calling PopNodeTokenMapping() when node_token_mapping_stack_ is "
   1301         "empty.");
   1302   }
   1303   node_token_mapping_stack_.pop();
   1304   return Status::OK();
   1305 }
   1306 
   1307 Status XlaCompiler::SetNodeToken(const string& node_name,
   1308                                  const xla::XlaOp& op) {
   1309   if (node_token_mapping_stack_.empty()) {
   1310     return errors::FailedPrecondition(
   1311         "Calling SetNodeToken() when node_token_mapping_stack_ is "
   1312         "empty.");
   1313   }
   1314   auto insert_result = node_token_mapping_stack_.top().insert({node_name, op});
   1315   if (!insert_result.second) {
   1316     return errors::FailedPrecondition("Token mapping already exists for node ",
   1317                                       node_name);
   1318   }
   1319   return Status::OK();
   1320 }
   1321 
   1322 xla::StatusOr<xla::XlaOp> XlaCompiler::GetNodeToken(const string& node_name) {
   1323   if (node_token_mapping_stack_.empty()) {
   1324     return errors::FailedPrecondition(
   1325         "Calling GetNodeToken() when node_token_mapping_stack_ is "
   1326         "empty.");
   1327   }
   1328   auto iter = node_token_mapping_stack_.top().find(node_name);
   1329   if (iter == node_token_mapping_stack_.top().end()) {
   1330     return errors::FailedPrecondition("Cannot find token mapping for node ",
   1331                                       node_name);
   1332   }
   1333   return iter->second;
   1334 }
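        // Illustrative usage sketch (not part of the original file): CompileGraph
        // above brackets graph compilation with a token scope,
        //
        //   PushNodeTokenMapping();
        //   TF_RETURN_IF_ERROR(SetNodeToken(kXlaTokenArgNodeName, token));
        //   ... compile the graph; kernels look up tokens via GetNodeToken() ...
        //   TF_RETURN_IF_ERROR(PopNodeTokenMapping());
        //
        // so that token mappings created for one graph do not leak into callers.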
   1335 
   1336 }  // namespace tensorflow
   1337