/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/user_computation.h"

#include <algorithm>
#include <set>
#include <stack>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"

namespace xla {
namespace {

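// The helpers below translate the *Operation enums from the session protocol
// into HLO opcodes. Each switch covers only the values a client can legally
// send; any other value indicates an out-of-sync client/server pair and
// aborts via LOG(FATAL).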
HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) {
  switch (unop) {
    case UNOP_ABS:
      return HloOpcode::kAbs;
    case UNOP_CEIL:
      return HloOpcode::kCeil;
    case UNOP_COS:
      return HloOpcode::kCos;
    case UNOP_EXP:
      return HloOpcode::kExp;
    case UNOP_FLOOR:
      return HloOpcode::kFloor;
    case UNOP_IMAG:
      return HloOpcode::kImag;
    case UNOP_IS_FINITE:
      return HloOpcode::kIsFinite;
    case UNOP_LOG:
      return HloOpcode::kLog;
    case UNOP_NOT:
      return HloOpcode::kNot;
    case UNOP_NEGATE:
      return HloOpcode::kNegate;
    case UNOP_REAL:
      return HloOpcode::kReal;
    case UNOP_ROUND_NEAREST_AFZ:
      return HloOpcode::kRoundNearestAfz;
    case UNOP_SIGN:
      return HloOpcode::kSign;
    case UNOP_SIN:
      return HloOpcode::kSin;
    case UNOP_SORT:
      return HloOpcode::kSort;
    case UNOP_TANH:
      return HloOpcode::kTanh;
    default:
      LOG(FATAL) << "unhandled operation " << unop;
  }
}

HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) {
  switch (binop) {
    case BINOP_ATAN2:
      return HloOpcode::kAtan2;
    case BINOP_COMPLEX:
      return HloOpcode::kComplex;
    case BINOP_MUL:
      return HloOpcode::kMultiply;
    case BINOP_ADD:
      return HloOpcode::kAdd;
    case BINOP_SUB:
      return HloOpcode::kSubtract;
    case BINOP_DIV:
      return HloOpcode::kDivide;
    case BINOP_EQ:
      return HloOpcode::kEq;
    case BINOP_GE:
      return HloOpcode::kGe;
    case BINOP_GT:
      return HloOpcode::kGt;
    case BINOP_LE:
      return HloOpcode::kLe;
    case BINOP_LT:
      return HloOpcode::kLt;
    case BINOP_NE:
      return HloOpcode::kNe;
    case BINOP_MAX:
      return HloOpcode::kMaximum;
    case BINOP_MIN:
      return HloOpcode::kMinimum;
    case BINOP_POW:
      return HloOpcode::kPower;
    case BINOP_REM:
      return HloOpcode::kRemainder;
    case BINOP_OR:
      return HloOpcode::kOr;
    case BINOP_AND:
      return HloOpcode::kAnd;
    case BINOP_SHIFT_LEFT:
      return HloOpcode::kShiftLeft;
    case BINOP_SHIFT_RIGHT_ARITHMETIC:
      return HloOpcode::kShiftRightArithmetic;
    case BINOP_SHIFT_RIGHT_LOGICAL:
      return HloOpcode::kShiftRightLogical;
    default:
      LOG(FATAL) << "unhandled operation " << binop;
  }
}

HloOpcode TernaryOperationToHloOpcode(TernaryOperation triop) {
  switch (triop) {
    case TRIOP_CLAMP:
      return HloOpcode::kClamp;
    case TRIOP_SELECT:
      return HloOpcode::kSelect;
    default:
      LOG(FATAL) << "unhandled operation " << triop;
  }
}

HloOpcode VariadicOperationToHloOpcode(VariadicOperation varop) {
  switch (varop) {
    case VAROP_TUPLE:
      return HloOpcode::kTuple;
    default:
      LOG(FATAL) << "unhandled operation " << varop;
  }
}

}  // namespace

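// Recreates a UserComputation from a previously serialized SessionComputation
// (for example, when a computation snapshot is reloaded into the service).
// The next handle value is set to one past the largest handle already present
// in the request map, and handles of embedded computations are remapped via
// old_to_new.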
/* static */ StatusOr<std::unique_ptr<UserComputation>>
UserComputation::MakeWithRemapping(
    const SessionComputation& session_computation,
    const ComputationHandle& handle,
    const std::map<int64, ComputationHandle>& old_to_new) {
  auto user_computation =
      MakeUnique<UserComputation>(session_computation.name(), handle);
  {
    tensorflow::mutex_lock lock(user_computation->mutex_);
    user_computation->session_computation_ = session_computation;
    user_computation->next_handle_value_ =
        std::max_element(session_computation.requests().begin(),
                         session_computation.requests().end(),
                         [](const std::pair<int64, OperationRequest>& lhs,
                            const std::pair<int64, OperationRequest>& rhs) {
                           return lhs.first < rhs.first;
                         })
            ->first +
        1;
    TF_RETURN_IF_ERROR(user_computation->RemapEmbeddedComputations(old_to_new));
  }

  return std::move(user_computation);
}

UserComputation::UserComputation(const string& name,
                                 const ComputationHandle& handle)
    : name_(name), next_handle_value_(1) {
  *session_computation_.mutable_computation_handle() = handle;
  session_computation_.set_name(name);

  VLOG(1) << "New UserComputation \"" << name
          << "\", handle: " << handle.handle();
}

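// Allocates a fresh ComputationDataHandle for the next instruction added to
// this computation. Callers (the Add* methods below) already hold mutex_.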
ComputationDataHandle UserComputation::CreateComputationDataHandle() {
  ComputationDataHandle handle;
  handle.set_handle(next_handle_value_);
  // Handles are used as Version values and *must* be assigned consecutively for
  // computation versioning to work.
  next_handle_value_++;
  return handle;
}

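// Each Add*Instruction method below follows the same pattern: validate the
// operand handles against previously recorded requests, infer the result
// shape, allocate a new data handle, and record an OperationRequest in
// session_computation_. That proto map, keyed by handle value, is the sole
// record of the computation; HLO is constructed from it when the computation
// is lowered.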
StatusOr<ComputationDataHandle> UserComputation::AddParameterInstruction(
    const ParameterRequest& parameter_request) {
  tensorflow::mutex_lock lock(mutex_);

  int64 parameter_number = parameter_request.parameter();
  if (parameters_.count(parameter_number) != 0) {
    return InvalidArgument("parameter %lld already registered",
                           parameter_number);
  }
  ComputationDataHandle handle = CreateComputationDataHandle();

  const Shape& validated_shape = parameter_request.shape();
  TF_RETURN_IF_ERROR(
      ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape));

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = validated_shape;
  *request.mutable_request()->mutable_parameter_request() = parameter_request;

  parameters_[parameter_number] = &request;

  VLOG(1) << "AddParameterInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << parameter_request.ShortDebugString();
  return handle;
}

Status UserComputation::AddSendInstruction(const SendRequest& send_request) {
  tensorflow::mutex_lock lock(mutex_);

  // Check if the operand of the instruction is valid.
  TF_RETURN_IF_ERROR(LookUpRequest(send_request.operand()).status());

  // No handle is returned, but a handle must be assigned to this instruction
  // for computation versioning.
  ComputationDataHandle handle = CreateComputationDataHandle();
  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = ShapeUtil::MakeNil();
  *request.mutable_request()->mutable_send_request() = send_request;

  VLOG(1) << "AddSendInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << send_request.ShortDebugString();
  return Status::OK();
}

StatusOr<ComputationDataHandle> UserComputation::AddRecvInstruction(
    const RecvRequest& recv_request) {
  tensorflow::mutex_lock lock(mutex_);

  const Shape& shape = recv_request.shape();
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
  ComputationDataHandle handle = CreateComputationDataHandle();
  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = shape;
  *request.mutable_request()->mutable_recv_request() = recv_request;

  VLOG(1) << "AddRecvInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << recv_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddPadInstruction(
    const PadRequest& pad_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(pad_request.operand()));

  TF_ASSIGN_OR_RETURN(const OperationRequest* padding_value,
                      LookUpRequest(pad_request.padding_value()));

  TF_ASSIGN_OR_RETURN(Shape inferred_shape, ShapeInference::InferPadShape(
                                                operand->output_shape(),
                                                padding_value->output_shape(),
                                                pad_request.padding_config()));

  ComputationDataHandle handle = CreateComputationDataHandle();
  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = inferred_shape;
  *request.mutable_request()->mutable_pad_request() = pad_request;

  VLOG(1) << "AddPadInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << pad_request.ShortDebugString();
  return handle;
}

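// Note: unlike the other builders, the literal's shape is validated before
// taking mutex_; the validation is a pure check on the request and touches no
// mutable state.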
StatusOr<ComputationDataHandle> UserComputation::AddConstantInstruction(
    const ConstantRequest& constant_request) {
  const Shape& validated_shape = constant_request.literal().shape();
  TF_RETURN_IF_ERROR(
      ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape));

  tensorflow::mutex_lock lock(mutex_);

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = validated_shape;
  *request.mutable_request()->mutable_constant_request() = constant_request;

  VLOG(1) << "AddConstantInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddGatherInstruction(
    const GatherRequest& gather_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* input_request,
                      LookUpRequest(gather_request.input()));
  TF_ASSIGN_OR_RETURN(const OperationRequest* gather_indices_request,
                      LookUpRequest(gather_request.gather_indices()));

  TF_ASSIGN_OR_RETURN(
      Shape shape,
      ShapeInference::InferGatherShape(
          input_request->output_shape(), gather_indices_request->output_shape(),
          gather_request.dimension_numbers(),
          AsInt64Slice(gather_request.window_bounds())));

  const ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = shape;
  *request.mutable_request()->mutable_gather_request() = gather_request;

  VLOG(1) << "AddGatherInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << gather_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddGetTupleElementInstruction(
    const GetTupleElementRequest& get_tuple_element_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(get_tuple_element_request.operand()));
  if (!ShapeUtil::IsTuple(operand->output_shape())) {
    return InvalidArgument(
        "Operand to GetTupleElement() is not a tuple; got %s",
        ShapeUtil::HumanString(operand->output_shape()).c_str());
  }
  Shape element_shape = ShapeUtil::GetTupleElementShape(
      operand->output_shape(), get_tuple_element_request.index());

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = element_shape;
  *request.mutable_request()->mutable_get_tuple_element_request() =
      get_tuple_element_request;

  VLOG(1) << "AddGetTupleElementInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << get_tuple_element_request.ShortDebugString();
  return handle;
}

Status UserComputation::AddTraceInstruction(const TraceRequest& trace_request) {
  tensorflow::mutex_lock lock(mutex_);

  // Verify that the operand index is valid.
  TF_RETURN_IF_ERROR(LookUpRequest(trace_request.operand()).status());

  ComputationDataHandle handle = CreateComputationDataHandle();
  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = ShapeUtil::MakeNil();
  *request.mutable_request()->mutable_trace_request() = trace_request;

  VLOG(1) << "AddTraceInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << trace_request.ShortDebugString();
  return Status::OK();
}

StatusOr<ComputationDataHandle> UserComputation::AddRngInstruction(
    const RngRequest& rng_request) {
  tensorflow::mutex_lock lock(mutex_);

  // Check the number of parameters per RNG distribution.
  switch (rng_request.distribution()) {
    case RandomDistribution::RNG_NORMAL:
    case RandomDistribution::RNG_UNIFORM:
      if (rng_request.parameter_size() != 2) {
        return InvalidArgument(
            "RNG distribution (%s) expects 2 parameters, but got %d",
            RandomDistribution_Name(rng_request.distribution()).c_str(),
            rng_request.parameter_size());
      }
      break;
    default:
      LOG(FATAL) << "unhandled distribution " << rng_request.distribution();
  }

  // Verify that the parameter indices are valid.
  for (const ComputationDataHandle& param : rng_request.parameter()) {
    TF_RETURN_IF_ERROR(LookUpRequest(param).status());
  }
  const Shape& validated_shape = rng_request.shape();
  TF_RETURN_IF_ERROR(
      ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = validated_shape;
  *request.mutable_request()->mutable_rng_request() = rng_request;

  VLOG(1) << "AddRngInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << rng_request.ShortDebugString();
  return handle;
}

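// Map is the first builder here that embeds another computation. The version
// (latest handle value) of the applied computation is captured on the request
// via add_embedded_computation_versions, so that later additions to that
// computation do not retroactively change this operation.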
StatusOr<ComputationDataHandle> UserComputation::AddMapInstruction(
    const MapRequest& map_request,
    const UserComputation& to_apply_computation) {
  tensorflow::mutex_lock lock(mutex_);

  std::vector<const Shape*> operand_shapes;
  for (const ComputationDataHandle& handle : map_request.operands()) {
    TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle));
    operand_shapes.push_back(&operand->output_shape());
  }

  VersionedComputationHandle::Version to_apply_version =
      to_apply_computation.version();
  TF_ASSIGN_OR_RETURN(
      std::shared_ptr<const ProgramShape> to_apply_program_shape,
      to_apply_computation.ComputeProgramShape(to_apply_version));
  TF_ASSIGN_OR_RETURN(
      Shape inferred_shape,
      ShapeInference::InferMapShape(operand_shapes, *to_apply_program_shape,
                                    AsInt64Slice(map_request.dimensions())));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = inferred_shape;
  request.add_embedded_computation_versions(to_apply_version);
  *request.mutable_request()->mutable_map_request() = map_request;

  VLOG(1) << "AddMapInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << map_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddReduceInstruction(
    const ReduceRequest& reduce_request,
    const UserComputation& to_apply_computation) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(reduce_request.operand()));
  TF_ASSIGN_OR_RETURN(const OperationRequest* init_value,
                      LookUpRequest(reduce_request.init_value()));

  VersionedComputationHandle::Version to_apply_version =
      to_apply_computation.version();
  TF_ASSIGN_OR_RETURN(
      std::shared_ptr<const ProgramShape> to_apply_program_shape,
      to_apply_computation.ComputeProgramShape(to_apply_version));

  TF_ASSIGN_OR_RETURN(
      Shape inferred_shape,
      ShapeInference::InferReduceShape(
          operand->output_shape(), init_value->output_shape(),
          AsInt64Slice(reduce_request.dimensions()), *to_apply_program_shape));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = inferred_shape;
  request.add_embedded_computation_versions(to_apply_version);
  *request.mutable_request()->mutable_reduce_request() = reduce_request;

  VLOG(1) << "AddReduceInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << reduce_request.ShortDebugString();
  return handle;
}

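// The three batch-norm builders below differ only in which operands they
// validate and which shape-inference routine they call. Per their inference
// rules, BatchNormTraining and BatchNormGrad produce tuple-shaped results
// (the output plus batch statistics or per-operand gradients, respectively).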
StatusOr<ComputationDataHandle>
UserComputation::AddBatchNormTrainingInstruction(
    const BatchNormTrainingRequest& batch_norm_training_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(batch_norm_training_request.operand()));

  TF_ASSIGN_OR_RETURN(const OperationRequest* scale,
                      LookUpRequest(batch_norm_training_request.scale()));

  TF_ASSIGN_OR_RETURN(const OperationRequest* offset,
                      LookUpRequest(batch_norm_training_request.offset()));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];

  TF_ASSIGN_OR_RETURN(
      Shape inferred_shape,
      ShapeInference::InferBatchNormTrainingShape(
          operand->output_shape(), scale->output_shape(),
          offset->output_shape(), batch_norm_training_request.feature_index()));

  *request.mutable_output_shape() = inferred_shape;

  *request.mutable_output_handle() = handle;

  *request.mutable_request()->mutable_batch_norm_training_request() =
      batch_norm_training_request;

  VLOG(1) << "AddBatchNormTrainingInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << batch_norm_training_request.ShortDebugString();

  return handle;
}

StatusOr<ComputationDataHandle>
UserComputation::AddBatchNormInferenceInstruction(
    const BatchNormInferenceRequest& batch_norm_inference_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(batch_norm_inference_request.operand()));

  TF_ASSIGN_OR_RETURN(const OperationRequest* scale,
                      LookUpRequest(batch_norm_inference_request.scale()));

  TF_ASSIGN_OR_RETURN(const OperationRequest* offset,
                      LookUpRequest(batch_norm_inference_request.offset()));

  TF_ASSIGN_OR_RETURN(const OperationRequest* mean,
                      LookUpRequest(batch_norm_inference_request.mean()));

  TF_ASSIGN_OR_RETURN(const OperationRequest* variance,
                      LookUpRequest(batch_norm_inference_request.variance()));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];

  TF_ASSIGN_OR_RETURN(Shape inferred_shape,
                      ShapeInference::InferBatchNormInferenceShape(
                          operand->output_shape(), scale->output_shape(),
                          offset->output_shape(), mean->output_shape(),
                          variance->output_shape(),
                          batch_norm_inference_request.feature_index()));

  *request.mutable_output_shape() = inferred_shape;

  *request.mutable_output_handle() = handle;

  *request.mutable_request()->mutable_batch_norm_inference_request() =
      batch_norm_inference_request;

  VLOG(1) << "AddBatchNormInferenceInstruction ("
          << GetVersionedHandleInternal() << "), data handle "
          << handle.handle() << ": "
          << batch_norm_inference_request.ShortDebugString();

  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddBatchNormGradInstruction(
    const BatchNormGradRequest& batch_norm_grad_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(batch_norm_grad_request.operand()));

  TF_ASSIGN_OR_RETURN(const OperationRequest* scale,
                      LookUpRequest(batch_norm_grad_request.scale()));

  TF_ASSIGN_OR_RETURN(const OperationRequest* mean,
                      LookUpRequest(batch_norm_grad_request.mean()));

  TF_ASSIGN_OR_RETURN(const OperationRequest* variance,
                      LookUpRequest(batch_norm_grad_request.variance()));

  TF_ASSIGN_OR_RETURN(const OperationRequest* grad_output,
                      LookUpRequest(batch_norm_grad_request.grad_output()));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];

  TF_ASSIGN_OR_RETURN(
      Shape inferred_shape,
      ShapeInference::InferBatchNormGradShape(
          operand->output_shape(), scale->output_shape(), mean->output_shape(),
          variance->output_shape(), grad_output->output_shape(),
          batch_norm_grad_request.feature_index()));

  *request.mutable_output_shape() = inferred_shape;

  *request.mutable_output_handle() = handle;

  *request.mutable_request()->mutable_batch_norm_grad_request() =
      batch_norm_grad_request;

  VLOG(1) << "AddBatchNormGradInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << batch_norm_grad_request.ShortDebugString();

  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddReduceWindowInstruction(
    const ReduceWindowRequest& reduce_window_request,
    const UserComputation& to_apply_computation) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(reduce_window_request.operand()));
  TF_ASSIGN_OR_RETURN(const OperationRequest* init_value,
                      LookUpRequest(reduce_window_request.init_value()));

  VersionedComputationHandle::Version to_apply_version =
      to_apply_computation.version();
  TF_ASSIGN_OR_RETURN(
      std::shared_ptr<const ProgramShape> to_apply_program_shape,
      to_apply_computation.ComputeProgramShape(to_apply_version));

  TF_ASSIGN_OR_RETURN(
      Shape inferred_shape,
      ShapeInference::InferReduceWindowShape(
          operand->output_shape(), init_value->output_shape(),
          reduce_window_request.window(), *to_apply_program_shape));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = inferred_shape;
  request.add_embedded_computation_versions(to_apply_version);
  *request.mutable_request()->mutable_reduce_window_request() =
      reduce_window_request;

  VLOG(1) << "AddReduceWindowInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << reduce_window_request.ShortDebugString();
  return handle;
}

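// SelectAndScatter embeds two computations: the `select` predicate used to
// pick an element within each window and the `scatter` combiner applied to
// the source values. Both versions are recorded on the request.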
StatusOr<ComputationDataHandle> UserComputation::AddSelectAndScatterInstruction(
    const SelectAndScatterRequest& select_and_scatter_request,
    const UserComputation& select_computation,
    const UserComputation& scatter_computation) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(select_and_scatter_request.operand()));
  TF_ASSIGN_OR_RETURN(const OperationRequest* source,
                      LookUpRequest(select_and_scatter_request.source()));
  TF_ASSIGN_OR_RETURN(const OperationRequest* init_value,
                      LookUpRequest(select_and_scatter_request.init_value()));

  VersionedComputationHandle::Version select_version =
      select_computation.version();
  TF_ASSIGN_OR_RETURN(std::shared_ptr<const ProgramShape> select_program_shape,
                      select_computation.ComputeProgramShape(select_version));
  VersionedComputationHandle::Version scatter_version =
      scatter_computation.version();
  TF_ASSIGN_OR_RETURN(std::shared_ptr<const ProgramShape> scatter_program_shape,
                      scatter_computation.ComputeProgramShape(scatter_version));

  TF_ASSIGN_OR_RETURN(
      Shape inferred_shape,
      ShapeInference::InferSelectAndScatterShape(
          operand->output_shape(), *select_program_shape,
          select_and_scatter_request.window(), source->output_shape(),
          init_value->output_shape(), *scatter_program_shape));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = inferred_shape;
  request.add_embedded_computation_versions(select_version);
  request.add_embedded_computation_versions(scatter_version);
  *request.mutable_request()->mutable_select_and_scatter_request() =
      select_and_scatter_request;

  VLOG(1) << "AddSelectAndScatterInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << select_and_scatter_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddReverseInstruction(
    const ReverseRequest& reverse_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(reverse_request.operand()));
  TF_ASSIGN_OR_RETURN(
      Shape inferred_shape,
      ShapeInference::InferReverseShape(
          operand->output_shape(), AsInt64Slice(reverse_request.dimensions())));

  ComputationDataHandle handle = CreateComputationDataHandle();
  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = inferred_shape;
  *request.mutable_request()->mutable_reverse_request() = reverse_request;
  VLOG(1) << "AddReverseInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << reverse_request.ShortDebugString();
  return handle;
}

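// While embeds a `condition` and a `body` computation. InferWhileShape checks
// that condition maps the loop state to a scalar predicate and that body maps
// the loop state back onto the same shape; the result shape is that of init.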
StatusOr<ComputationDataHandle> UserComputation::AddWhileInstruction(
    const WhileRequest& while_request,
    const UserComputation& condition_computation,
    const UserComputation& body_computation) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* init,
                      LookUpRequest(while_request.init()));

  VersionedComputationHandle::Version condition_version =
      condition_computation.version();
  TF_ASSIGN_OR_RETURN(
      std::shared_ptr<const ProgramShape> condition_program_shape,
      condition_computation.ComputeProgramShape(condition_version));

  VersionedComputationHandle::Version body_version = body_computation.version();
  TF_ASSIGN_OR_RETURN(std::shared_ptr<const ProgramShape> body_program_shape,
                      body_computation.ComputeProgramShape(body_version));

  TF_ASSIGN_OR_RETURN(
      Shape inferred_shape,
      ShapeInference::InferWhileShape(
          *condition_program_shape, *body_program_shape, init->output_shape()));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = inferred_shape;
  request.add_embedded_computation_versions(condition_version);
  request.add_embedded_computation_versions(body_version);
  *request.mutable_request()->mutable_while_request() = while_request;

  VLOG(1) << "AddWhileInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << while_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddConditionalInstruction(
    const ConditionalRequest& conditional_request,
    const UserComputation& true_computation,
    const UserComputation& false_computation) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* pred,
                      LookUpRequest(conditional_request.predicate()));
  TF_ASSIGN_OR_RETURN(const OperationRequest* true_operand,
                      LookUpRequest(conditional_request.true_operand()));
  TF_ASSIGN_OR_RETURN(const OperationRequest* false_operand,
                      LookUpRequest(conditional_request.false_operand()));

  VersionedComputationHandle::Version true_computation_version =
      true_computation.version();
  TF_ASSIGN_OR_RETURN(
      std::shared_ptr<const ProgramShape> true_computation_shape,
      true_computation.ComputeProgramShape(true_computation_version));

  VersionedComputationHandle::Version false_computation_version =
      false_computation.version();
  TF_ASSIGN_OR_RETURN(
      std::shared_ptr<const ProgramShape> false_computation_shape,
      false_computation.ComputeProgramShape(false_computation_version));

  TF_ASSIGN_OR_RETURN(Shape inferred_shape,
                      ShapeInference::InferConditionalShape(
                          pred->output_shape(), true_operand->output_shape(),
                          false_operand->output_shape(),
                          *true_computation_shape, *false_computation_shape));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = inferred_shape;
  request.add_embedded_computation_versions(true_computation_version);
  request.add_embedded_computation_versions(false_computation_version);
  *request.mutable_request()->mutable_conditional_request() =
      conditional_request;

  VLOG(1) << "AddConditionalInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << conditional_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddBroadcastInstruction(
    const BroadcastRequest& broadcast_request) {
  tensorflow::mutex_lock lock(mutex_);

  // Fetches and validates the operand.
  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(broadcast_request.operand()));
  TF_ASSIGN_OR_RETURN(Shape inferred_shape,
                      ShapeInference::InferBroadcastShape(
                          operand->output_shape(),
                          AsInt64Slice(broadcast_request.broadcast_sizes())));

  ComputationDataHandle handle = CreateComputationDataHandle();
  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = inferred_shape;
  *request.mutable_request()->mutable_broadcast_request() = broadcast_request;

  VLOG(1) << "AddBroadcastInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << broadcast_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddReshapeInstruction(
    const ReshapeRequest& reshape_request) {
  tensorflow::mutex_lock lock(mutex_);

  // Fetches and validates the operand.
  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(reshape_request.operand()));

  TF_ASSIGN_OR_RETURN(
      Shape inferred_shape,
      ShapeInference::InferReshapeShape(
          operand->output_shape(), AsInt64Slice(reshape_request.dimensions()),
          AsInt64Slice(reshape_request.new_sizes())));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = inferred_shape;
  *request.mutable_request()->mutable_reshape_request() = reshape_request;

  VLOG(1) << "AddReshapeInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << reshape_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddTransposeInstruction(
    const TransposeRequest& transpose_request) {
  tensorflow::mutex_lock lock(mutex_);

  // Fetches and validates the operand.
  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(transpose_request.operand()));

  TF_ASSIGN_OR_RETURN(Shape inferred_shape,
                      ShapeInference::InferTransposeShape(
                          operand->output_shape(),
                          AsInt64Slice(transpose_request.dimensions())));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = inferred_shape;
  *request.mutable_request()->mutable_transpose_request() = transpose_request;

  VLOG(1) << "AddTransposeInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << transpose_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddSliceInstruction(
    const SliceRequest& slice_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(slice_request.operand()));

  TF_ASSIGN_OR_RETURN(
      Shape new_shape,
      ShapeInference::InferSliceShape(
          operand->output_shape(), AsInt64Slice(slice_request.start_indices()),
          AsInt64Slice(slice_request.limit_indices()),
          AsInt64Slice(slice_request.strides())));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = new_shape;
  *request.mutable_request()->mutable_slice_request() = slice_request;

  VLOG(1) << "AddSliceInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << slice_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddDynamicSliceInstruction(
    const DynamicSliceRequest& dynamic_slice_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(dynamic_slice_request.operand()));

  TF_ASSIGN_OR_RETURN(const OperationRequest* start_indices,
                      LookUpRequest(dynamic_slice_request.start_indices()));

  TF_ASSIGN_OR_RETURN(
      Shape new_shape,
      ShapeInference::InferDynamicSliceShape(
          operand->output_shape(), start_indices->output_shape(),
          AsInt64Slice(dynamic_slice_request.slice_sizes())));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = new_shape;
  *request.mutable_request()->mutable_dynamic_slice_request() =
      dynamic_slice_request;

  VLOG(1) << "AddDynamicSliceInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << dynamic_slice_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle>
UserComputation::AddDynamicUpdateSliceInstruction(
    const DynamicUpdateSliceRequest& dynamic_update_slice_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(dynamic_update_slice_request.operand()));

  TF_ASSIGN_OR_RETURN(const OperationRequest* update,
                      LookUpRequest(dynamic_update_slice_request.update()));

  TF_ASSIGN_OR_RETURN(
      const OperationRequest* start_indices,
      LookUpRequest(dynamic_update_slice_request.start_indices()));

  TF_ASSIGN_OR_RETURN(Shape new_shape,
                      ShapeInference::InferDynamicUpdateSliceShape(
                          operand->output_shape(), update->output_shape(),
                          start_indices->output_shape()));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = new_shape;
  *request.mutable_request()->mutable_dynamic_update_slice_request() =
      dynamic_update_slice_request;

  VLOG(1) << "AddDynamicUpdateSliceInstruction ("
          << GetVersionedHandleInternal() << "), data handle "
          << handle.handle() << ": "
          << dynamic_update_slice_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddConcatenateInstruction(
    const ConcatenateRequest& concatenate_request) {
  tensorflow::mutex_lock lock(mutex_);

  std::vector<const Shape*> operand_shapes;
  for (const ComputationDataHandle& handle : concatenate_request.operands()) {
    TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle));
    operand_shapes.push_back(&operand->output_shape());
  }

  TF_ASSIGN_OR_RETURN(Shape new_shape,
                      ShapeInference::InferConcatOpShape(
                          operand_shapes, concatenate_request.dimension()));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = new_shape;
  *request.mutable_request()->mutable_concatenate_request() =
      concatenate_request;

  VLOG(1) << "AddConcatenateInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << concatenate_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddConvertInstruction(
    const ConvertRequest& convert_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(convert_request.operand()));

  TF_ASSIGN_OR_RETURN(Shape new_shape, ShapeInference::InferConvertShape(
                                           operand->output_shape(),
                                           convert_request.new_element_type()));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = new_shape;
  *request.mutable_request()->mutable_convert_request() = convert_request;

  VLOG(1) << "AddConvertInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << convert_request.ShortDebugString();
  return handle;
}

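// BitcastConvert reuses the ConvertRequest proto and the same shape inference
// as Convert; the only difference recorded here is which request field is
// populated, distinguishing a bit-pattern reinterpretation from a value
// conversion.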
StatusOr<ComputationDataHandle> UserComputation::AddBitcastConvertInstruction(
    const ConvertRequest& convert_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(convert_request.operand()));

  TF_ASSIGN_OR_RETURN(Shape new_shape, ShapeInference::InferConvertShape(
                                           operand->output_shape(),
                                           convert_request.new_element_type()));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = new_shape;
  *request.mutable_request()->mutable_bitcast_convert_request() =
      convert_request;

  VLOG(1) << "AddBitcastConvertInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << convert_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddReducePrecisionInstruction(
    const ReducePrecisionRequest& reduce_precision_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(reduce_precision_request.operand()));

  TF_ASSIGN_OR_RETURN(
      Shape new_shape,
      ShapeInference::InferReducePrecisionShape(
          operand->output_shape(), reduce_precision_request.exponent_bits(),
          reduce_precision_request.mantissa_bits()));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = new_shape;
  *request.mutable_request()->mutable_reduce_precision_request() =
      reduce_precision_request;

  VLOG(1) << "AddReducePrecisionInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << reduce_precision_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddConvolveInstruction(
    const ConvolveRequest& convolve_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* lhs,
                      LookUpRequest(convolve_request.lhs()));
  TF_ASSIGN_OR_RETURN(const OperationRequest* rhs,
                      LookUpRequest(convolve_request.rhs()));
  TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvolveShape(
                                       lhs->output_shape(), rhs->output_shape(),
                                       convolve_request.window(),
                                       convolve_request.dimension_numbers()));

  const ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = shape;
  *request.mutable_request()->mutable_convolve_request() = convolve_request;

  VLOG(1) << "AddConvolveInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << convolve_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddFftInstruction(
    const FftRequest& fft_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(fft_request.operand()));
  TF_ASSIGN_OR_RETURN(Shape shape,
                      ShapeInference::InferFftShape(
                          operand->output_shape(), fft_request.fft_type(),
                          AsInt64Slice(fft_request.fft_length())));

  const ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = shape;
  *request.mutable_request()->mutable_fft_request() = fft_request;

  VLOG(1) << "AddFftInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << fft_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddCrossReplicaSumInstruction(
    const CrossReplicaSumRequest& cross_replica_sum_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(cross_replica_sum_request.operand()));
  TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCrossReplicaSumShape(
                                       {&operand->output_shape()}));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = shape;
  *request.mutable_request()->mutable_cross_replica_sum_request() =
      cross_replica_sum_request;

  VLOG(1) << "AddCrossReplicaSumInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << cross_replica_sum_request.ShortDebugString();
  return handle;
}

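// Infeed and Outfeed shapes must carry an explicit layout: the bytes moved
// across the host/device boundary are interpreted according to it, so an
// unset layout would be ambiguous.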
StatusOr<ComputationDataHandle> UserComputation::AddInfeedInstruction(
    const InfeedRequest& infeed_request) {
  tensorflow::mutex_lock lock(mutex_);

  const Shape& shape = infeed_request.shape();
  if (!LayoutUtil::HasLayout(shape)) {
    return InvalidArgument("Given shape to Infeed must have a layout");
  }

  const ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = shape;
  *request.mutable_request()->mutable_infeed_request() = infeed_request;

  VLOG(1) << "AddInfeedInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << infeed_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddOutfeedInstruction(
    const OutfeedRequest& outfeed_request) {
  tensorflow::mutex_lock lock(mutex_);

  const Shape& shape = outfeed_request.shape();
  if (!LayoutUtil::HasLayout(shape)) {
    return InvalidArgument("Given shape to Outfeed must have a layout");
  }

  // Verify that operand is valid.
  TF_RETURN_IF_ERROR(LookUpRequest(outfeed_request.operand()).status());

  ComputationDataHandle handle = CreateComputationDataHandle();
  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = shape;
  *request.mutable_request()->mutable_outfeed_request() = outfeed_request;

  VLOG(1) << "AddOutfeedInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << outfeed_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddCallInstruction(
    const CallRequest& call_request,
    const UserComputation& to_apply_computation) {
  tensorflow::mutex_lock lock(mutex_);

  std::vector<const Shape*> operand_shapes;
  for (const ComputationDataHandle& handle : call_request.operands()) {
    TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle));
    operand_shapes.push_back(&operand->output_shape());
  }

  VersionedComputationHandle::Version to_apply_version =
      to_apply_computation.version();
  TF_ASSIGN_OR_RETURN(
      std::shared_ptr<const ProgramShape> to_apply_program_shape,
      to_apply_computation.ComputeProgramShape(to_apply_version));
  TF_ASSIGN_OR_RETURN(
      Shape inferred_shape,
      ShapeInference::InferCallShape(operand_shapes, *to_apply_program_shape));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = inferred_shape;
  request.add_embedded_computation_versions(to_apply_version);
  *request.mutable_request()->mutable_call_request() = call_request;

  VLOG(1) << "AddCallInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << call_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddCustomCallInstruction(
   1279     const CustomCallRequest& custom_call_request) {
   1280   tensorflow::mutex_lock lock(mutex_);
   1281 
   1282   for (const ComputationDataHandle& handle : custom_call_request.operands()) {
   1283     TF_RETURN_IF_ERROR(LookUpRequest(handle).status());
   1284   }
   1285 
   1286   if (tensorflow::StringPiece(custom_call_request.call_target_name())
   1287           .starts_with("$")) {
   1288     return InvalidArgument(
   1289         "Invalid custom_call_target \"%s\": Call targets that start with '$' "
   1290         "are reserved for internal use.",
   1291         custom_call_request.call_target_name().c_str());
   1292   }
   1293 
   1294   const ComputationDataHandle handle = CreateComputationDataHandle();
   1295 
   1296   OperationRequest& request =
   1297       (*session_computation_.mutable_requests())[handle.handle()];
   1298   *request.mutable_output_handle() = handle;
   1299   *request.mutable_output_shape() = custom_call_request.shape();
   1300   *request.mutable_request()->mutable_custom_call_request() =
   1301       custom_call_request;
   1302 
   1303   VLOG(1) << "AddCustomCallInstruction (" << GetVersionedHandleInternal()
   1304           << "), data handle " << handle.handle() << ": "
   1305           << custom_call_request.ShortDebugString();
   1306   return handle;
   1307 }
   1308 
   1309 StatusOr<ComputationDataHandle> UserComputation::AddHostComputeInstruction(
   1310     const HostComputeRequest& host_compute_request) {
   1311   tensorflow::mutex_lock lock(mutex_);
   1312 
   1313   for (const ComputationDataHandle& handle : host_compute_request.operands()) {
   1314     TF_RETURN_IF_ERROR(LookUpRequest(handle).status());
   1315   }
   1316 
   1317   ComputationDataHandle handle = CreateComputationDataHandle();
   1318   OperationRequest& request =
   1319       (*session_computation_.mutable_requests())[handle.handle()];
   1320   *request.mutable_output_handle() = handle;
   1321   *request.mutable_output_shape() = host_compute_request.shape();
   1322   *request.mutable_request()->mutable_host_compute_request() =
   1323       host_compute_request;
   1324 
   1325   VLOG(1) << "AddHostComputeInstruction (" << GetVersionedHandleInternal()
   1326           << "), data handle " << handle.handle() << ": "
   1327           << host_compute_request.ShortDebugString();
   1328   return handle;
   1329 }
   1330 
   1331 StatusOr<ComputationDataHandle> UserComputation::AddDotInstruction(
   1332     const DotRequest& dot_request) {
   1333   tensorflow::mutex_lock lock(mutex_);
   1334 
   1335   TF_ASSIGN_OR_RETURN(const OperationRequest* lhs,
   1336                       LookUpRequest(dot_request.lhs()));
   1337   TF_ASSIGN_OR_RETURN(const OperationRequest* rhs,
   1338                       LookUpRequest(dot_request.rhs()));
   1339 
   1340   TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDotOpShape(
   1341                                        lhs->output_shape(), rhs->output_shape(),
   1342                                        dot_request.dimension_numbers()));
   1343 
   1344   const ComputationDataHandle handle = CreateComputationDataHandle();
   1345 
   1346   OperationRequest& request =
   1347       (*session_computation_.mutable_requests())[handle.handle()];
   1348   *request.mutable_output_handle() = handle;
   1349   *request.mutable_output_shape() = shape;
   1350   *request.mutable_request()->mutable_dot_request() = dot_request;
   1351 
   1352   VLOG(1) << "AddDotInstruction (" << GetVersionedHandleInternal()
   1353           << "), data handle " << handle.handle() << ": "
   1354           << dot_request.ShortDebugString();
   1355   return handle;
   1356 }
   1357 
   1358 StatusOr<ComputationDataHandle> UserComputation::AddUnaryInstruction(
   1359     const UnaryOpRequest& unary_request) {
   1360   tensorflow::mutex_lock lock(mutex_);
   1361 
   1362   TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
   1363                       LookUpRequest(unary_request.operand()));
   1364   TF_ASSIGN_OR_RETURN(
   1365       Shape shape, ShapeInference::InferUnaryOpShape(unary_request.unop(),
   1366                                                      operand->output_shape()));
   1367 
   1368   ComputationDataHandle handle = CreateComputationDataHandle();
   1369 
   1370   OperationRequest& request =
   1371       (*session_computation_.mutable_requests())[handle.handle()];
   1372   *request.mutable_output_handle() = handle;
   1373   *request.mutable_output_shape() = shape;
   1374   *request.mutable_request()->mutable_unary_op_request() = unary_request;
   1375 
   1376   VLOG(1) << "AddUnaryInstruction (" << GetVersionedHandleInternal()
   1377           << "), data handle " << handle.handle() << ": "
   1378           << unary_request.ShortDebugString();
   1379   return handle;
   1380 }
   1381 
   1382 StatusOr<ComputationDataHandle> UserComputation::AddBinaryInstruction(
   1383     const BinaryOpRequest& binary_request) {
   1384   tensorflow::mutex_lock lock(mutex_);
   1385 
   1386   TF_ASSIGN_OR_RETURN(const OperationRequest* lhs,
   1387                       LookUpRequest(binary_request.lhs()));
   1388   TF_ASSIGN_OR_RETURN(const OperationRequest* rhs,
   1389                       LookUpRequest(binary_request.rhs()));
   1390   TF_ASSIGN_OR_RETURN(
   1391       Shape shape,
   1392       ShapeInference::InferBinaryOpShape(
   1393           binary_request.binop(), lhs->output_shape(), rhs->output_shape(),
   1394           AsInt64Slice(binary_request.broadcast_dimensions())));
   1395 
   1396   ComputationDataHandle handle = CreateComputationDataHandle();
   1397 
   1398   OperationRequest& request =
   1399       (*session_computation_.mutable_requests())[handle.handle()];
   1400   *request.mutable_output_handle() = handle;
   1401   *request.mutable_output_shape() = shape;
   1402   *request.mutable_request()->mutable_binary_op_request() = binary_request;
   1403 
   1404   VLOG(1) << "AddBinaryInstruction (" << GetVersionedHandleInternal()
   1405           << "), data handle " << handle.handle() << ": "
   1406           << binary_request.ShortDebugString();
   1407   return handle;
   1408 }
   1409 
   1410 StatusOr<ComputationDataHandle> UserComputation::AddTernaryInstruction(
   1411     const TernaryOpRequest& ternary_request) {
   1412   tensorflow::mutex_lock lock(mutex_);
   1413 
   1414   TF_ASSIGN_OR_RETURN(const OperationRequest* lhs,
   1415                       LookUpRequest(ternary_request.lhs()));
   1416   TF_ASSIGN_OR_RETURN(const OperationRequest* rhs,
   1417                       LookUpRequest(ternary_request.rhs()));
   1418   TF_ASSIGN_OR_RETURN(const OperationRequest* ehs,
   1419                       LookUpRequest(ternary_request.ehs()));
   1420   TF_ASSIGN_OR_RETURN(Shape shape,
   1421                       ShapeInference::InferTernaryOpShape(
   1422                           ternary_request.triop(), lhs->output_shape(),
   1423                           rhs->output_shape(), ehs->output_shape()));
   1424 
   1425   ComputationDataHandle handle = CreateComputationDataHandle();
   1426 
   1427   OperationRequest& request =
   1428       (*session_computation_.mutable_requests())[handle.handle()];
   1429   *request.mutable_output_handle() = handle;
   1430   *request.mutable_output_shape() = shape;
   1431   *request.mutable_request()->mutable_ternary_op_request() = ternary_request;
   1432 
   1433   VLOG(1) << "AddTernaryInstruction (" << GetVersionedHandleInternal()
   1434           << "), data handle " << handle.handle() << ": "
   1435           << ternary_request.ShortDebugString();
   1436   return handle;
   1437 }
   1438 
   1439 StatusOr<ComputationDataHandle> UserComputation::AddVariadicInstruction(
   1440     const VariadicOpRequest& variadic_request) {
   1441   tensorflow::mutex_lock lock(mutex_);
   1442 
   1443   std::vector<const Shape*> operand_shapes;
   1444   for (const ComputationDataHandle& handle : variadic_request.operands()) {
   1445     TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle));
   1446     operand_shapes.push_back(&operand->output_shape());
   1447   }
   1448 
   1449   TF_ASSIGN_OR_RETURN(Shape shape,
   1450                       ShapeInference::InferVariadicOpShape(
   1451                           variadic_request.varop(), operand_shapes));
   1452 
   1453   ComputationDataHandle handle = CreateComputationDataHandle();
   1454 
   1455   OperationRequest& request =
   1456       (*session_computation_.mutable_requests())[handle.handle()];
   1457   *request.mutable_output_handle() = handle;
   1458   *request.mutable_output_shape() = shape;
   1459   *request.mutable_request()->mutable_variadic_op_request() = variadic_request;
   1460 
   1461   VLOG(1) << "AddVariadicInstruction (" << GetVersionedHandleInternal()
   1462           << "), data handle " << handle.handle() << ": "
   1463           << variadic_request.ShortDebugString();
   1464   return handle;
   1465 }
   1466 
   1467 StatusOr<Shape> UserComputation::GetShape(const ComputationDataHandle& handle) {
   1468   tensorflow::mutex_lock lock(mutex_);
   1469 
   1470   TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle));
   1471   return operand->output_shape();
   1472 }
   1473 
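// Attaches the given OpMetadata (e.g. op name and source location) to the
// operation identified by `handle`.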
Status UserComputation::SetOpMetadata(const ComputationDataHandle& handle,
                                      const OpMetadata& metadata) {
  tensorflow::mutex_lock lock(mutex_);

  int64 handle_value = handle.handle();
  if (session_computation_.requests().count(handle_value) == 0) {
    return InvalidArgument("Invalid handle in SetOpMetadata (%lld)",
                           handle_value);
  }
  *session_computation_.mutable_requests()
       ->at(handle_value)
       .mutable_request()
       ->mutable_metadata() = metadata;
  return Status::OK();
}

Status UserComputation::SetOpSharding(const ComputationDataHandle& handle,
                                      const OpSharding& sharding) {
  tensorflow::mutex_lock lock(mutex_);

  int64 handle_value = handle.handle();
  if (session_computation_.requests().count(handle_value) == 0) {
    return InvalidArgument("Invalid handle in SetOpSharding (%lld)",
                           handle_value);
  }
  *session_computation_.mutable_requests()
       ->at(handle_value)
       .mutable_request()
       ->mutable_sharding() = sharding;
  return Status::OK();
}

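// Designates the operation identified by `handle` as the computation's return
// value; GetVersionedHandle() then reports that handle as the version.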
Status UserComputation::SetReturnValue(const ComputationDataHandle& handle) {
  tensorflow::mutex_lock lock(mutex_);

  if (!(handle.handle() > 0 && handle.handle() < next_handle_value_)) {
    return InvalidArgument("Invalid handle in SetReturnValue");
  }

  handle_to_return_ = handle;

  VLOG(1) << "SetReturnValue of computation \"" << name() << "\" fixed to "
          << GetVersionedHandleInternal();

  return Status::OK();
}

VersionedComputationHandle UserComputation::GetVersionedHandle() const {
  tensorflow::mutex_lock lock(mutex_);
  return GetVersionedHandleInternal();
}

VersionedComputationHandle UserComputation::GetVersionedHandleInternal() const {
  VersionedComputationHandle versioned_handle;
  versioned_handle.handle = session_computation_.computation_handle();

  if (handle_to_return_.handle() > 0) {
    // A specific handle has been requested for the result of the computation.
    versioned_handle.version = handle_to_return_.handle();
  } else {
    // The version value is simply the most recently assigned
    // ComputationDataHandle value, i.e. the handle value of the root of the
    // computation.
    versioned_handle.version = next_handle_value_ - 1;
  }
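  // For example, if handles 1 through 5 have been assigned and SetReturnValue
  // was never called, the version is 5; if SetReturnValue was called with
  // handle 3, the version is 3.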

  return versioned_handle;
}

VersionedComputationHandle UserComputation::GetVersionedHandleAtOperation(
    const ComputationDataHandle& operation) const {
  tensorflow::mutex_lock lock(mutex_);

  // The version at which an operation was added is simply the handle value of
  // the ComputationDataHandle.
  VersionedComputationHandle versioned_handle;
  versioned_handle.handle = session_computation_.computation_handle();
  versioned_handle.version = operation.handle();
  return versioned_handle;
}

VersionedComputationHandle::Version UserComputation::version() const {
  return GetVersionedHandle().version;
}

namespace {

// Returns true if the operation type corresponding to the given op case can
// be the root of the computation.
bool CanBeRoot(const OpRequest::OpCase& op_case) {
  switch (op_case) {
    case OpRequest::kTraceRequest:
    case OpRequest::kSendRequest:
    case OpRequest::kOutfeedRequest:
      return false;
    default:
      return true;
  }
}

// Returns a pointer to the operation with the given data handle value in the
// given SessionComputation.
StatusOr<const OperationRequest*> LookUpRequest(
    int64 handle_value, const SessionComputation& session_computation) {
  if (session_computation.requests().count(handle_value) == 0) {
    return InvalidArgument("no ComputationDataHandle value %lld", handle_value);
  }
  return &session_computation.requests().at(handle_value);
}

// Returns the OperationRequest corresponding to the root (result) of the
// session computation.
StatusOr<const OperationRequest*> GetRoot(
    VersionedComputationHandle::Version version,
    const SessionComputation& session_computation) {
  TF_RET_CHECK(version > 0);
  // Not all instructions can be roots. Walk backwards from the operation
  // indicated by this version until a valid root is found.
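  // For example, if the operation at `version` is an Outfeed, the root is the
  // nearest earlier operation that is not a Trace, Send, or Outfeed.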
  const OperationRequest* root_request = nullptr;
  while (version > 0) {
    TF_ASSIGN_OR_RETURN(root_request,
                        LookUpRequest(version, session_computation));
    if (CanBeRoot(root_request->request().op_case())) {
      break;
    }
    version--;
  }
  if (version == 0) {
    return InternalError("Computation contains no root operation");
  }
  return root_request;
}

}  // namespace

StatusOr<std::shared_ptr<const ProgramShape>>
UserComputation::ComputeProgramShape(
    VersionedComputationHandle::Version version) const {
  tensorflow::mutex_lock lock(mutex_);

  TF_RET_CHECK(version > 0 && version < next_handle_value_);

  if (program_shape_ == nullptr || program_shape_version_ != version) {
    // The ProgramShape has not been computed yet, or is for a different
    // version. Compute it now.
    TF_RETURN_IF_ERROR(CheckParametersAreContiguous(version));

    auto program_shape = MakeUnique<ProgramShape>();
    for (int64 request_num = 1; request_num <= version; ++request_num) {
      const OperationRequest& request =
          session_computation_.requests().at(request_num);
      if (request.request().op_case() == OpRequest::kParameterRequest) {
        const ParameterRequest& parameter_request =
            request.request().parameter_request();
        int64 param_no = parameter_request.parameter();
        // Parameters may be out of order so expand ProgramShape parameters
        // until it is at least large enough to hold the current parameter
        // number.
        while (program_shape->parameters_size() <= param_no) {
          program_shape->add_parameters();
          program_shape->add_parameter_names();
        }
        *program_shape->mutable_parameters(param_no) = request.output_shape();
        *program_shape->mutable_parameter_names(param_no) =
            parameter_request.name();
      }
    }
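    // For example, if parameter requests were added in the order {1, 0}, the
    // loop above still produces a ProgramShape listing parameter 0 first.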

    // The root determines the output shape.
    TF_ASSIGN_OR_RETURN(const OperationRequest* root_request,
                        GetRoot(version, session_computation_));
    *program_shape->mutable_result() = root_request->output_shape();
    if (ShapeUtil::IsOpaque(program_shape->result())) {
      return Unimplemented("Computation results cannot be opaque");
    }

    program_shape_ = std::move(program_shape);
    program_shape_version_ = version;
  }

  return program_shape_;
}

namespace {

// A visitor which checks whether an operation is purely functional, meaning
// that it doesn't depend on any parameter with an index of num_parameters or
// higher. The visitor walks the computation starting at a given operation and
// sets is_functional to false if such a parameter, an RNG operation, or any
// other side-effecting operation (e.g. Infeed, Send) is encountered.
void PureFunctionalVisitor(const SessionComputation& session_computation,
                           const ComputationDataHandle& handle,
                           int64 num_parameters, std::set<int64>* visited,
                           bool* is_functional) {
  if (visited->count(handle.handle()) != 0 || !*is_functional) {
    return;
  }

  const OperationRequest& request =
      session_computation.requests().at(handle.handle());
  switch (request.request().op_case()) {
    case OpRequest::kRngRequest:
      *is_functional = false;
      break;

    case OpRequest::kConstantRequest:
      break;

    case OpRequest::kGetTupleElementRequest: {
      const GetTupleElementRequest& get_tuple_element_request =
          request.request().get_tuple_element_request();
      PureFunctionalVisitor(session_computation,
                            get_tuple_element_request.operand(), num_parameters,
                            visited, is_functional);
      break;
    }

    case OpRequest::kSliceRequest: {
      const SliceRequest& slice_request = request.request().slice_request();
      PureFunctionalVisitor(session_computation, slice_request.operand(),
                            num_parameters, visited, is_functional);
      break;
    }

    case OpRequest::kDynamicSliceRequest: {
      const DynamicSliceRequest& dynamic_slice_request =
          request.request().dynamic_slice_request();
      PureFunctionalVisitor(session_computation,
                            dynamic_slice_request.operand(), num_parameters,
                            visited, is_functional);
      PureFunctionalVisitor(session_computation,
                            dynamic_slice_request.start_indices(),
                            num_parameters, visited, is_functional);
      break;
    }

    case OpRequest::kDynamicUpdateSliceRequest: {
      const DynamicUpdateSliceRequest& dynamic_update_slice_request =
          request.request().dynamic_update_slice_request();
      PureFunctionalVisitor(session_computation,
                            dynamic_update_slice_request.operand(),
                            num_parameters, visited, is_functional);
      PureFunctionalVisitor(session_computation,
                            dynamic_update_slice_request.update(),
                            num_parameters, visited, is_functional);
      PureFunctionalVisitor(session_computation,
                            dynamic_update_slice_request.start_indices(),
                            num_parameters, visited, is_functional);
      break;
    }

    case OpRequest::kConcatenateRequest: {
      const ConcatenateRequest& concatenate_request =
          request.request().concatenate_request();
      for (const ComputationDataHandle& handle :
           concatenate_request.operands()) {
        PureFunctionalVisitor(session_computation, handle, num_parameters,
                              visited, is_functional);
      }
      break;
    }

    case OpRequest::kConvolveRequest: {
      const ConvolveRequest& convolve_request =
          request.request().convolve_request();
      PureFunctionalVisitor(session_computation, convolve_request.lhs(),
                            num_parameters, visited, is_functional);
      PureFunctionalVisitor(session_computation, convolve_request.rhs(),
                            num_parameters, visited, is_functional);
      break;
    }

    case OpRequest::kFftRequest: {
      const FftRequest& fft_request = request.request().fft_request();
      PureFunctionalVisitor(session_computation, fft_request.operand(),
                            num_parameters, visited, is_functional);
      break;
    }

    case OpRequest::kCrossReplicaSumRequest: {
      // TODO(b/33009255): Implement constant folding for cross replica sum.
      *is_functional = false;
      break;
    }

    case OpRequest::kInfeedRequest: {
      *is_functional = false;
      break;
    }

    case OpRequest::kOutfeedRequest: {
      *is_functional = false;
      break;
    }

    case OpRequest::kHostComputeRequest: {
      *is_functional = false;
      break;
    }

    case OpRequest::kCallRequest: {
      const CallRequest& call_request = request.request().call_request();
      for (const ComputationDataHandle& handle : call_request.operands()) {
        PureFunctionalVisitor(session_computation, handle, num_parameters,
                              visited, is_functional);
      }
      // TODO(b/32495713): We aren't checking the to_apply computation itself,
      // so we conservatively say that computations containing the Call op
      // cannot be constant.  We cannot set is_functional=false in other similar
      // cases since we're already relying on IsConstant to return true.
      *is_functional = false;
      break;
    }

    case OpRequest::kCustomCallRequest: {
      *is_functional = false;
      break;
    }

    case OpRequest::kDotRequest: {
      const DotRequest& dot_request = request.request().dot_request();
      PureFunctionalVisitor(session_computation, dot_request.lhs(),
                            num_parameters, visited, is_functional);
      PureFunctionalVisitor(session_computation, dot_request.rhs(),
                            num_parameters, visited, is_functional);
      break;
    }

    case OpRequest::kSendRequest: {
      *is_functional = false;
      break;
    }

    case OpRequest::kRecvRequest: {
      *is_functional = false;
      break;
    }

    case OpRequest::kMapRequest: {
      const MapRequest& map_request = request.request().map_request();
      for (const ComputationDataHandle& handle : map_request.operands()) {
        PureFunctionalVisitor(session_computation, handle, num_parameters,
                              visited, is_functional);
      }
      // TODO(b/32495713): We aren't checking the to_apply computation itself.
      break;
    }

    case OpRequest::kReduceRequest: {
      const ReduceRequest& reduce_request = request.request().reduce_request();
      PureFunctionalVisitor(session_computation, reduce_request.operand(),
                            num_parameters, visited, is_functional);
      PureFunctionalVisitor(session_computation, reduce_request.init_value(),
                            num_parameters, visited, is_functional);
      // TODO(b/32495713): We aren't checking the to_apply computation itself.
      break;
    }

    case OpRequest::kReduceWindowRequest: {
      const ReduceWindowRequest& reduce_window_request =
          request.request().reduce_window_request();
      PureFunctionalVisitor(session_computation,
                            reduce_window_request.operand(), num_parameters,
                            visited, is_functional);
      PureFunctionalVisitor(session_computation,
                            reduce_window_request.init_value(), num_parameters,
                            visited, is_functional);
      // TODO(b/32495713): We aren't checking the to_apply computation itself.
      break;
    }

    case OpRequest::kSelectAndScatterRequest: {
      const SelectAndScatterRequest& select_and_scatter_request =
          request.request().select_and_scatter_request();
      PureFunctionalVisitor(session_computation,
                            select_and_scatter_request.operand(),
                            num_parameters, visited, is_functional);
      PureFunctionalVisitor(session_computation,
                            select_and_scatter_request.source(), num_parameters,
                            visited, is_functional);
      PureFunctionalVisitor(session_computation,
                            select_and_scatter_request.init_value(),
                            num_parameters, visited, is_functional);
      // TODO(b/32495713): We aren't checking the select and scatter
      // computations themselves.
      break;
    }

    case OpRequest::kBroadcastRequest: {
      const BroadcastRequest& broadcast_request =
          request.request().broadcast_request();
      PureFunctionalVisitor(session_computation, broadcast_request.operand(),
                            num_parameters, visited, is_functional);
      break;
    }

    case OpRequest::kReshapeRequest: {
      const ReshapeRequest& reshape_request =
          request.request().reshape_request();
      PureFunctionalVisitor(session_computation, reshape_request.operand(),
                            num_parameters, visited, is_functional);
      break;
    }

    case OpRequest::kReverseRequest: {
      const ReverseRequest& reverse_request =
          request.request().reverse_request();
      PureFunctionalVisitor(session_computation, reverse_request.operand(),
                            num_parameters, visited, is_functional);
      break;
    }

    case OpRequest::kPadRequest: {
      const PadRequest& pad_request = request.request().pad_request();
      PureFunctionalVisitor(session_computation, pad_request.operand(),
                            num_parameters, visited, is_functional);
      PureFunctionalVisitor(session_computation, pad_request.padding_value(),
                            num_parameters, visited, is_functional);
      break;
    }

    case OpRequest::kParameterRequest: {
      const ParameterRequest& parameter_request =
          request.request().parameter_request();
      if (parameter_request.parameter() >= num_parameters) {
        *is_functional = false;
      }
      break;
    }

    case OpRequest::kConvertRequest: {
      const ConvertRequest& convert_request =
          request.request().convert_request();
      PureFunctionalVisitor(session_computation, convert_request.operand(),
                            num_parameters, visited, is_functional);
      break;
    }

    case OpRequest::kBitcastConvertRequest: {
      const ConvertRequest& convert_request =
          request.request().bitcast_convert_request();
      PureFunctionalVisitor(session_computation, convert_request.operand(),
                            num_parameters, visited, is_functional);
      break;
    }

    case OpRequest::kWhileRequest: {
      const WhileRequest& while_request = request.request().while_request();
      PureFunctionalVisitor(session_computation, while_request.init(),
                            num_parameters, visited, is_functional);
      // TODO(b/32495713): We aren't checking the condition and body
      // computations themselves.
      *is_functional = false;
      break;
    }

    case OpRequest::kConditionalRequest: {
      const ConditionalRequest& conditional_request =
          request.request().conditional_request();
      PureFunctionalVisitor(session_computation,
                            conditional_request.predicate(), num_parameters,
                            visited, is_functional);
      PureFunctionalVisitor(session_computation,
                            conditional_request.true_operand(), num_parameters,
                            visited, is_functional);
      PureFunctionalVisitor(session_computation,
                            conditional_request.false_operand(), num_parameters,
                            visited, is_functional);
      // TODO(b/32495713): We aren't checking the true and false computations
      // themselves.
      break;
    }

    case OpRequest::kTernaryOpRequest: {
      const TernaryOpRequest& ternary_op_request =
          request.request().ternary_op_request();
      PureFunctionalVisitor(session_computation, ternary_op_request.lhs(),
                            num_parameters, visited, is_functional);
      PureFunctionalVisitor(session_computation, ternary_op_request.rhs(),
                            num_parameters, visited, is_functional);
      PureFunctionalVisitor(session_computation, ternary_op_request.ehs(),
                            num_parameters, visited, is_functional);
      break;
    }

    case OpRequest::kTransposeRequest: {
      const TransposeRequest& transpose_request =
          request.request().transpose_request();
      PureFunctionalVisitor(session_computation, transpose_request.operand(),
                            num_parameters, visited, is_functional);
      break;
    }

    case OpRequest::kVariadicOpRequest: {
      const VariadicOpRequest& variadic_op_request =
          request.request().variadic_op_request();
      for (const ComputationDataHandle& handle :
           variadic_op_request.operands()) {
        PureFunctionalVisitor(session_computation, handle, num_parameters,
                              visited, is_functional);
      }
      break;
    }

    case OpRequest::kUnaryOpRequest: {
      const UnaryOpRequest& unary_op_request =
          request.request().unary_op_request();
      PureFunctionalVisitor(session_computation, unary_op_request.operand(),
                            num_parameters, visited, is_functional);
      break;
    }

    case OpRequest::kBatchNormTrainingRequest: {
      const BatchNormTrainingRequest& batch_norm_training_request =
          request.request().batch_norm_training_request();
      PureFunctionalVisitor(session_computation,
                            batch_norm_training_request.operand(),
                            num_parameters, visited, is_functional);
      PureFunctionalVisitor(session_computation,
                            batch_norm_training_request.scale(), num_parameters,
                            visited, is_functional);
      PureFunctionalVisitor(session_computation,
                            batch_norm_training_request.offset(),
                            num_parameters, visited, is_functional);
      break;
    }

    case OpRequest::kBatchNormInferenceRequest: {
      const BatchNormInferenceRequest& batch_norm_inference_request =
          request.request().batch_norm_inference_request();
      PureFunctionalVisitor(session_computation,
                            batch_norm_inference_request.operand(),
                            num_parameters, visited, is_functional);
      PureFunctionalVisitor(session_computation,
                            batch_norm_inference_request.scale(),
                            num_parameters, visited, is_functional);
      PureFunctionalVisitor(session_computation,
                            batch_norm_inference_request.offset(),
                            num_parameters, visited, is_functional);
      PureFunctionalVisitor(session_computation,
                            batch_norm_inference_request.mean(), num_parameters,
                            visited, is_functional);
      PureFunctionalVisitor(session_computation,
                            batch_norm_inference_request.variance(),
                            num_parameters, visited, is_functional);
      break;
    }

    case OpRequest::kBatchNormGradRequest: {
      const BatchNormGradRequest& batch_norm_grad_request =
          request.request().batch_norm_grad_request();
      PureFunctionalVisitor(session_computation,
                            batch_norm_grad_request.operand(), num_parameters,
                            visited, is_functional);
      PureFunctionalVisitor(session_computation,
                            batch_norm_grad_request.scale(), num_parameters,
                            visited, is_functional);
      PureFunctionalVisitor(session_computation, batch_norm_grad_request.mean(),
                            num_parameters, visited, is_functional);
      PureFunctionalVisitor(session_computation,
                            batch_norm_grad_request.variance(), num_parameters,
                            visited, is_functional);
      PureFunctionalVisitor(session_computation,
                            batch_norm_grad_request.grad_output(),
                            num_parameters, visited, is_functional);
      break;
    }

    case OpRequest::kBinaryOpRequest: {
      const BinaryOpRequest& binary_op_request =
          request.request().binary_op_request();
      PureFunctionalVisitor(session_computation, binary_op_request.lhs(),
                            num_parameters, visited, is_functional);
      PureFunctionalVisitor(session_computation, binary_op_request.rhs(),
                            num_parameters, visited, is_functional);
      break;
    }

    case OpRequest::kGatherRequest: {
      PureFunctionalVisitor(session_computation,
                            request.request().gather_request().input(),
                            num_parameters, visited, is_functional);
      PureFunctionalVisitor(session_computation,
                            request.request().gather_request().gather_indices(),
                            num_parameters, visited, is_functional);
      break;
    }

    case OpRequest::OP_NOT_SET:
      LOG(FATAL) << "OperationRequest doesn't contain a request";

    default:
      LOG(FATAL) << "Unexpected request type: " << request.request().op_case();
  }
  if (!*is_functional) {
    VLOG(1) << "Non-functional: " << request.request().DebugString();
  }
  visited->insert(handle.handle());
}

}  // namespace

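// Example (hypothetical handles): for a computation built as
//   h1 = Parameter(0); h2 = Constant(...); h3 = Add(h1, h2)
// IsConstant(h3, /*num_parameters=*/0) is false because h3 transitively
// depends on parameter 0, whereas IsConstant(h2, /*num_parameters=*/0) is
// true.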
StatusOr<bool> UserComputation::IsConstant(const ComputationDataHandle& handle,
                                           int64 num_parameters) {
  tensorflow::mutex_lock lock(mutex_);

  // Verify that the handle is valid.
  auto operation_status = LookUpRequest(handle);
  if (!operation_status.ok()) {
    return operation_status.status();
  }

  bool is_constant = true;
  std::set<int64> visited;
  PureFunctionalVisitor(session_computation_, handle, num_parameters, &visited,
                        &is_constant);

  return is_constant;
}

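// Returns the versioned handles of all computations embedded via Call, Map,
// Reduce, ReduceWindow, SelectAndScatter, While, or Conditional operations at
// or below the given version, in increasing handle order.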
std::vector<VersionedComputationHandle>
UserComputation::GetEmbeddedComputations(
    VersionedComputationHandle::Version version) const {
  tensorflow::mutex_lock lock(mutex_);

  VLOG(1)
      << "GetEmbeddedComputations(" << name() << " "
      << VersionedComputationHandle{session_computation_.computation_handle(),
                                    version}
      << ")";
  XLA_VLOG_LINES(3, session_computation_.DebugString());

  std::vector<VersionedComputationHandle> computations;
  std::vector<int64> sorted_handles;
  for (const auto& handle_request : session_computation_.requests()) {
    sorted_handles.push_back(handle_request.first);
  }
  std::sort(sorted_handles.begin(), sorted_handles.end());
  for (int64 handle : sorted_handles) {
    const auto& handle_request = session_computation_.requests().find(handle);
    CHECK(handle_request != session_computation_.requests().end());
    int64 handle_value = handle_request->first;
    if (handle_value <= version) {
      const OperationRequest& request = handle_request->second;
      switch (request.request().op_case()) {
        case OpRequest::kCallRequest: {
          CHECK_EQ(1, request.embedded_computation_versions_size());
          const CallRequest& call_request = request.request().call_request();
          const VersionedComputationHandle versioned_handle = {
              call_request.to_apply(),
              request.embedded_computation_versions(0)};
          computations.push_back(versioned_handle);
          break;
        }

        case OpRequest::kMapRequest: {
          CHECK_EQ(1, request.embedded_computation_versions_size());
          const MapRequest& map_request = request.request().map_request();
          const VersionedComputationHandle versioned_handle = {
              map_request.to_apply(), request.embedded_computation_versions(0)};
          computations.push_back(versioned_handle);
          break;
        }

        case OpRequest::kReduceRequest: {
          CHECK_EQ(1, request.embedded_computation_versions_size());
          const ReduceRequest& reduce_request =
              request.request().reduce_request();
          const VersionedComputationHandle versioned_handle = {
              reduce_request.to_apply(),
              request.embedded_computation_versions(0)};
          computations.push_back(versioned_handle);
          break;
        }

        case OpRequest::kReduceWindowRequest: {
          CHECK_EQ(1, request.embedded_computation_versions_size());
          const ReduceWindowRequest& reduce_window_request =
              request.request().reduce_window_request();
          const VersionedComputationHandle versioned_handle = {
              reduce_window_request.to_apply(),
              request.embedded_computation_versions(0)};
          computations.push_back(versioned_handle);
          break;
        }

        case OpRequest::kSelectAndScatterRequest: {
          CHECK_EQ(2, request.embedded_computation_versions_size());
          const SelectAndScatterRequest& select_and_scatter_request =
              request.request().select_and_scatter_request();
          const VersionedComputationHandle select_versioned_handle = {
              select_and_scatter_request.select(),
              request.embedded_computation_versions(0)};
          computations.push_back(select_versioned_handle);
          const VersionedComputationHandle scatter_versioned_handle = {
              select_and_scatter_request.scatter(),
              request.embedded_computation_versions(1)};
          computations.push_back(scatter_versioned_handle);
          break;
        }

        case OpRequest::kWhileRequest: {
          CHECK_EQ(2, request.embedded_computation_versions_size());
          const WhileRequest& while_request = request.request().while_request();
          const VersionedComputationHandle condition_versioned_handle = {
              while_request.condition(),
              request.embedded_computation_versions(0)};
          computations.push_back(condition_versioned_handle);
          const VersionedComputationHandle body_versioned_handle = {
              while_request.body(), request.embedded_computation_versions(1)};
          computations.push_back(body_versioned_handle);
          break;
        }

        case OpRequest::kConditionalRequest: {
          CHECK_EQ(2, request.embedded_computation_versions_size());
          const ConditionalRequest& conditional_request =
              request.request().conditional_request();
          const VersionedComputationHandle true_computation_versioned_handle = {
              conditional_request.true_computation(),
              request.embedded_computation_versions(0)};
          computations.push_back(true_computation_versioned_handle);
          const VersionedComputationHandle false_computation_versioned_handle =
              {conditional_request.false_computation(),
               request.embedded_computation_versions(1)};
          computations.push_back(false_computation_versioned_handle);
          break;
        }

        default:
          // No embedded computation.
          break;
      }
    }
  }
  VLOG(2) << "Embedded computations: "
          << tensorflow::str_util::Join(
                 computations, ", ",
                 [](string* out, const VersionedComputationHandle& h) {
                   out->append(h.ToString());
                 });
  return computations;
}

StatusOr<const OperationRequest*>
UserComputation::LookUpRequestForErrorReporting(
    const ComputationDataHandle& handle) const {
  tensorflow::mutex_lock lock(mutex_);
  return LookUpRequest(handle);
}

tensorflow::gtl::optional<const OpMetadata*> UserComputation::ParameterMetadata(
    int parameter_number) const {
  tensorflow::mutex_lock lock(mutex_);
  auto it = parameters_.find(parameter_number);
  if (it == parameters_.end()) {
    return tensorflow::gtl::nullopt;
  }
  OperationRequest* op = it->second;
  return &op->request().metadata();
}

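// Rewrites the computation's own handle and every embedded ComputationHandle
// (in Call, Map, Reduce, ReduceWindow, SelectAndScatter, While, and
// Conditional requests) through the old_to_new mapping, failing if a
// referenced handle is missing from the mapping.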
Status UserComputation::RemapEmbeddedComputations(
    const std::map<int64, ComputationHandle>& old_to_new) {
  auto update = [&old_to_new](ComputationHandle* to_update) -> Status {
    int64 old = to_update->handle();
    auto it = old_to_new.find(old);
    if (it == old_to_new.end()) {
      string mapping = tensorflow::str_util::Join(
          old_to_new, ", ",
          [](string* out, std::pair<int64, ComputationHandle> element) {
            tensorflow::strings::Appendf(out, "%lld:%lld", element.first,
                                         element.second.handle());
          });
      return NotFound(
          "could not find referenced (old) computation handle in mapping: "
          "%lld; mapping: {%s}",
          old, mapping.c_str());
    }
    VLOG(2) << "remapping " << old << " to " << it->second.handle();
    *to_update = it->second;
    return Status::OK();
  };
  TF_RETURN_IF_ERROR(update(session_computation_.mutable_computation_handle()));
  for (auto& handle_request : *session_computation_.mutable_requests()) {
    OperationRequest& request = handle_request.second;
    switch (request.request().op_case()) {
      case OpRequest::kCallRequest: {
        TF_RET_CHECK(1 == request.embedded_computation_versions_size());
        CallRequest* call_request =
            request.mutable_request()->mutable_call_request();
        TF_RETURN_IF_ERROR(update(call_request->mutable_to_apply()));
        break;
      }
      case OpRequest::kMapRequest: {
        TF_RET_CHECK(1 == request.embedded_computation_versions_size());
        MapRequest* map_request =
            request.mutable_request()->mutable_map_request();
        TF_RETURN_IF_ERROR(update(map_request->mutable_to_apply()));
        break;
      }
      case OpRequest::kReduceRequest: {
        TF_RET_CHECK(1 == request.embedded_computation_versions_size());
        ReduceRequest* reduce_request =
            request.mutable_request()->mutable_reduce_request();
        TF_RETURN_IF_ERROR(update(reduce_request->mutable_to_apply()));
        break;
      }
      case OpRequest::kReduceWindowRequest: {
        TF_RET_CHECK(1 == request.embedded_computation_versions_size());
        ReduceWindowRequest* reduce_window_request =
            request.mutable_request()->mutable_reduce_window_request();
        TF_RETURN_IF_ERROR(update(reduce_window_request->mutable_to_apply()));
        break;
      }
      case OpRequest::kSelectAndScatterRequest: {
        TF_RET_CHECK(2 == request.embedded_computation_versions_size());
        SelectAndScatterRequest* select_and_scatter_request =
            request.mutable_request()->mutable_select_and_scatter_request();
        TF_RETURN_IF_ERROR(
            update(select_and_scatter_request->mutable_select()));
        TF_RETURN_IF_ERROR(
            update(select_and_scatter_request->mutable_scatter()));
        break;
      }
      case OpRequest::kWhileRequest: {
        TF_RET_CHECK(2 == request.embedded_computation_versions_size());
        WhileRequest* while_request =
            request.mutable_request()->mutable_while_request();
        TF_RETURN_IF_ERROR(update(while_request->mutable_condition()));
        TF_RETURN_IF_ERROR(update(while_request->mutable_body()));
        break;
      }
      case OpRequest::kConditionalRequest: {
        TF_RET_CHECK(2 == request.embedded_computation_versions_size());
        ConditionalRequest* conditional_request =
            request.mutable_request()->mutable_conditional_request();
        TF_RETURN_IF_ERROR(
            update(conditional_request->mutable_true_computation()));
        TF_RETURN_IF_ERROR(
            update(conditional_request->mutable_false_computation()));
        break;
      }
      default:
        // No embedded computation.
        TF_RET_CHECK(0 == request.embedded_computation_versions_size());
        break;
    }
  }
  return Status::OK();
}

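// Returns a copy of the session computation truncated at the given version,
// i.e. with all requests whose handle value exceeds `version` removed.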
SessionComputation UserComputation::CloneSessionComputation(
    VersionedComputationHandle::Version version) const {
  tensorflow::mutex_lock lock(mutex_);
  SessionComputation result = session_computation_;
  // Erase all the requests that exceed the specified version. There is no
  // lower_bound method on tensorflow::protobuf::Map, so we iterate over all
  // the elements.
  auto it = result.mutable_requests()->begin();
  while (it != result.mutable_requests()->end()) {
    if (it->first > version) {
      it = result.mutable_requests()->erase(it);
    } else {
      ++it;
    }
  }
  return result;
}

StatusOr<const OperationRequest*> UserComputation::LookUpRequest(
    const ComputationDataHandle& handle) const {
  int64 handle_value = handle.handle();
  if (session_computation_.requests().count(handle_value) == 0) {
    return InvalidArgument("no ComputationDataHandle value %lld", handle_value);
  }
  return &session_computation_.requests().at(handle_value);
}

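// Verifies that parameter numbers 0..N-1 are all present among the first
// `version` requests. For example, parameters numbered {0, 1, 3} fail with
// "missing parameter 2".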
Status UserComputation::CheckParametersAreContiguous(
    VersionedComputationHandle::Version version) const {
  TF_RET_CHECK(version > 0 && version < next_handle_value_);

  // Determine number of parameter inputs at the given version.
  std::map<int64, const ParameterRequest*> parameter_requests;
  for (int64 request_num = 1; request_num <= version; ++request_num) {
    const OperationRequest& request =
        session_computation_.requests().at(request_num);

    if (request.request().op_case() == OpRequest::kParameterRequest) {
      const ParameterRequest& parameter_request =
          request.request().parameter_request();
      // Duplicate parameters should be checked when parameter requests are
      // added.
      TF_RET_CHECK(0 ==
                   parameter_requests.count(parameter_request.parameter()));
      parameter_requests[parameter_request.parameter()] = &parameter_request;
    }
  }

  for (int64 i = 0; i < parameter_requests.size(); ++i) {
    auto it = parameter_requests.find(i);
    if (it == parameter_requests.end()) {
      return FailedPrecondition(
          "computation %s does not have all its parameters populated "
          "sequentially, missing parameter %lld",
          name_.c_str(), i);
    }
  }

  return Status::OK();
}

namespace {

   2388 // Helper class which builds an HLO computation from a SessionComputation. To
   2389 // construct the HLO computation, the SessionComputation graph is walked in
   2390 // DFS postorder, lowering each OperationRequest to an HLO instruction.
   2391 class ComputationLowerer {
   2392  public:
   2393   static StatusOr<std::unique_ptr<HloComputation>> Lower(
   2394       const string& computation_name,
   2395       const SessionComputation& session_computation,
   2396       VersionedComputationHandle::Version version,
   2397       UserComputation::HloComputationResolver hlo_resolver,
   2398       const DebugOptions& debug_options,
   2399       bool include_unreachable_instructions) {
   2400     ComputationLowerer lowerer(computation_name, session_computation, version,
   2401                                std::move(hlo_resolver), debug_options,
   2402                                include_unreachable_instructions);
   2403     return lowerer.Lower();
   2404   }
   2405 
   2406  private:
   2407   ComputationLowerer(const string& computation_name,
   2408                      const SessionComputation& session_computation,
   2409                      VersionedComputationHandle::Version version,
   2410                      UserComputation::HloComputationResolver hlo_resolver,
   2411                      const DebugOptions& debug_options,
   2412                      bool include_unreachable_instructions)
   2413       : hlo_builder_(computation_name),
   2414         session_computation_(session_computation),
   2415         version_(version),
   2416         hlo_resolver_(std::move(hlo_resolver)),
   2417         debug_options_(debug_options),
   2418         include_unreachable_instructions_(include_unreachable_instructions) {}
   2419 
   2420   // Build an HLO computation from the SessionComputation at the given
   2421   // version.
   2422   StatusOr<std::unique_ptr<HloComputation>> Lower();
   2423 
   2425   // Traverses the computation 'root' using a DFS, calling 'visit' in postorder.
   2426   void TraversePostorder(
   2427       const ComputationDataHandle& root,
   2428       std::unordered_map<int64, HloInstruction*>* visited,
   2429       const std::function<void(const ComputationDataHandle&)>& visit);
   2430 
   2431   // DFS visitor of the UserComputation operations which lowers the operations
   2432   // to HLO instructions.
   2433   void Visit(const ComputationDataHandle& handle,
   2434              std::unordered_map<int64, HloInstruction*>* instructions);
   2435 
   2436   // Resolves a ComputationHandle and Version to a previously lowered
   2437   // HloComputation using the hlo_resolver_ function.
   2438   HloComputation* ResolveComputation(
   2439       const ComputationHandle& handle,
   2440       VersionedComputationHandle::Version version);
   2441 
   2442   // Takes an operand that is implicitly broadcast to an output shape and
   2443   // emits the kBroadcast instruction(s) needed to make the broadcast
   2444   // explicit.
   2445   HloInstruction* ImplicitBroadcastToExplicitBroadcast(
   2446       HloInstruction* operand, const Shape& output_shape);
   2447 
   2448   HloComputation::Builder hlo_builder_;
   2449   const SessionComputation& session_computation_;
   2450   const VersionedComputationHandle::Version version_;
   2451   const UserComputation::HloComputationResolver hlo_resolver_;
   2452   const DebugOptions& debug_options_;
   2453   const bool include_unreachable_instructions_;
   2454 };
   2455 
   2456 // Calls 'apply' on each operand of 'request'.
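        // The order in which 'apply' is called on the operands only affects the
        // DFS sibling order during lowering; every operand is still visited
        // before the request itself.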
   2457 static void ForEachOperand(
   2458     const OperationRequest& request,
   2459     const std::function<void(const ComputationDataHandle& param)>& apply) {
   2460   switch (request.request().op_case()) {
   2461     case OpRequest::kRngRequest: {
   2462       const RngRequest& rng_request = request.request().rng_request();
   2463       for (const ComputationDataHandle& param : rng_request.parameter()) {
   2464         apply(param);
   2465       }
   2466       break;
   2467     }
   2468 
   2469     case OpRequest::kConstantRequest:
   2470       break;
   2471     case OpRequest::kGetTupleElementRequest: {
   2472       const GetTupleElementRequest& get_tuple_element_request =
   2473           request.request().get_tuple_element_request();
   2474       apply(get_tuple_element_request.operand());
   2475       break;
   2476     }
   2477 
   2478     case OpRequest::kSliceRequest: {
   2479       const SliceRequest& slice_request = request.request().slice_request();
   2480       apply(slice_request.operand());
   2481       break;
   2482     }
   2483 
   2484     case OpRequest::kDynamicSliceRequest: {
   2485       const DynamicSliceRequest& dynamic_slice_request =
   2486           request.request().dynamic_slice_request();
   2487       apply(dynamic_slice_request.operand());
   2488       apply(dynamic_slice_request.start_indices());
   2489       break;
   2490     }
   2491 
   2492     case OpRequest::kDynamicUpdateSliceRequest: {
   2493       const DynamicUpdateSliceRequest& dynamic_update_slice_request =
   2494           request.request().dynamic_update_slice_request();
   2495       apply(dynamic_update_slice_request.operand());
   2496       apply(dynamic_update_slice_request.update());
   2497       apply(dynamic_update_slice_request.start_indices());
   2498       break;
   2499     }
   2500 
   2501     case OpRequest::kConcatenateRequest: {
   2502       const ConcatenateRequest& concatenate_request =
   2503           request.request().concatenate_request();
   2504       for (const ComputationDataHandle& handle :
   2505            concatenate_request.operands()) {
   2506         apply(handle);
   2507       }
   2508       break;
   2509     }
   2510 
   2511     case OpRequest::kConvolveRequest: {
   2512       const ConvolveRequest& convolve_request =
   2513           request.request().convolve_request();
   2514       apply(convolve_request.lhs());
   2515       apply(convolve_request.rhs());
   2516       break;
   2517     }
   2518 
   2519     case OpRequest::kFftRequest: {
   2520       const FftRequest& fft_request = request.request().fft_request();
   2521       apply(fft_request.operand());
   2522       break;
   2523     }
   2524 
   2525     case OpRequest::kBatchNormTrainingRequest: {
   2526       const BatchNormTrainingRequest& batch_norm_training_request =
   2527           request.request().batch_norm_training_request();
   2528 
   2529       apply(batch_norm_training_request.operand());
   2530       apply(batch_norm_training_request.scale());
   2531       apply(batch_norm_training_request.offset());
   2532       break;
   2533     }
   2534 
   2535     case OpRequest::kBatchNormInferenceRequest: {
   2536       const BatchNormInferenceRequest& batch_norm_inference_request =
   2537           request.request().batch_norm_inference_request();
   2538 
   2539       apply(batch_norm_inference_request.operand());
   2540       apply(batch_norm_inference_request.scale());
   2541       apply(batch_norm_inference_request.offset());
   2542       apply(batch_norm_inference_request.mean());
   2543       apply(batch_norm_inference_request.variance());
   2544       break;
   2545     }
   2546 
   2547     case OpRequest::kBatchNormGradRequest: {
   2548       const BatchNormGradRequest& batch_norm_grad_request =
   2549           request.request().batch_norm_grad_request();
   2550 
   2551       apply(batch_norm_grad_request.operand());
   2552       apply(batch_norm_grad_request.scale());
   2553       apply(batch_norm_grad_request.mean());
   2554       apply(batch_norm_grad_request.variance());
   2555       apply(batch_norm_grad_request.grad_output());
   2556       break;
   2557     }
   2558 
   2559     case OpRequest::kCrossReplicaSumRequest: {
   2560       const CrossReplicaSumRequest& cross_replica_sum_request =
   2561           request.request().cross_replica_sum_request();
   2562       apply(cross_replica_sum_request.operand());
   2563       break;
   2564     }
   2565 
   2566     case OpRequest::kInfeedRequest:
   2567       break;
   2568 
   2569     case OpRequest::kOutfeedRequest: {
   2570       const OutfeedRequest& outfeed_request =
   2571           request.request().outfeed_request();
   2572       apply(outfeed_request.operand());
   2573       break;
   2574     }
   2575 
   2576     case OpRequest::kMapRequest: {
   2577       const MapRequest& map_request = request.request().map_request();
   2578       for (const ComputationDataHandle& handle : map_request.operands()) {
   2579         apply(handle);
   2580       }
   2581       break;
   2582     }
   2583 
   2584     case OpRequest::kReduceRequest: {
   2585       const ReduceRequest& reduce_request = request.request().reduce_request();
   2586       apply(reduce_request.operand());
   2587       apply(reduce_request.init_value());
   2588       break;
   2589     }
   2590 
   2591     case OpRequest::kReduceWindowRequest: {
   2592       const ReduceWindowRequest& reduce_window_request =
   2593           request.request().reduce_window_request();
   2594       apply(reduce_window_request.operand());
   2595       apply(reduce_window_request.init_value());
   2596       break;
   2597     }
   2598 
   2599     case OpRequest::kSelectAndScatterRequest: {
   2600       const SelectAndScatterRequest& select_and_scatter_request =
   2601           request.request().select_and_scatter_request();
   2602       apply(select_and_scatter_request.operand());
   2603       apply(select_and_scatter_request.source());
   2604       apply(select_and_scatter_request.init_value());
   2605 
   2606       break;
   2607     }
   2608 
   2609     case OpRequest::kBroadcastRequest: {
   2610       const BroadcastRequest& broadcast_request =
   2611           request.request().broadcast_request();
   2612       apply(broadcast_request.operand());
   2613       break;
   2614     }
   2615 
   2616     case OpRequest::kReshapeRequest: {
   2617       const ReshapeRequest& reshape_request =
   2618           request.request().reshape_request();
   2619       apply(reshape_request.operand());
   2620       break;
   2621     }
   2622 
   2623     case OpRequest::kTransposeRequest: {
   2624       const TransposeRequest& transpose_request =
   2625           request.request().transpose_request();
   2626       apply(transpose_request.operand());
   2627       break;
   2628     }
   2629 
   2630     case OpRequest::kReverseRequest: {
   2631       const ReverseRequest& reverse_request =
   2632           request.request().reverse_request();
   2633       apply(reverse_request.operand());
   2634       break;
   2635     }
   2636 
   2637     case OpRequest::kPadRequest: {
   2638       const PadRequest& pad_request = request.request().pad_request();
   2639       apply(pad_request.operand());
   2640       apply(pad_request.padding_value());
   2641       break;
   2642     }
   2643 
   2644     case OpRequest::kRecvRequest:
   2645     case OpRequest::kParameterRequest:
   2646       break;
   2647 
   2648     case OpRequest::kConvertRequest: {
   2649       const ConvertRequest& convert_request =
   2650           request.request().convert_request();
   2651       apply(convert_request.operand());
   2652       break;
   2653     }
   2654 
   2655     case OpRequest::kBitcastConvertRequest: {
   2656       const ConvertRequest& convert_request =
   2657           request.request().bitcast_convert_request();
   2658       apply(convert_request.operand());
   2659       break;
   2660     }
   2661 
   2662     case OpRequest::kWhileRequest: {
   2663       const WhileRequest& while_request = request.request().while_request();
   2664       apply(while_request.init());
   2665       break;
   2666     }
   2667 
   2668     case OpRequest::kConditionalRequest: {
   2669       const ConditionalRequest& conditional_request =
   2670           request.request().conditional_request();
   2671       apply(conditional_request.predicate());
   2672       apply(conditional_request.true_operand());
   2673       apply(conditional_request.false_operand());
   2674       break;
   2675     }
   2676 
   2677     case OpRequest::kTernaryOpRequest: {
   2678       const TernaryOpRequest& ternary_op_request =
   2679           request.request().ternary_op_request();
   2680       apply(ternary_op_request.lhs());
   2681       apply(ternary_op_request.rhs());
   2682       apply(ternary_op_request.ehs());
   2683       break;
   2684     }
   2685 
   2686     case OpRequest::kVariadicOpRequest: {
   2687       const VariadicOpRequest& variadic_op_request =
   2688           request.request().variadic_op_request();
   2689       for (const ComputationDataHandle& handle :
   2690            variadic_op_request.operands()) {
   2691         apply(handle);
   2692       }
   2693       break;
   2694     }
   2695 
   2696     case OpRequest::kCallRequest: {
   2697       const CallRequest& call_request = request.request().call_request();
   2698       for (const ComputationDataHandle& handle : call_request.operands()) {
   2699         apply(handle);
   2700       }
   2701       break;
   2702     }
   2703 
   2704     case OpRequest::kCustomCallRequest: {
   2705       const CustomCallRequest& cc_request =
   2706           request.request().custom_call_request();
   2707       for (const ComputationDataHandle& operand : cc_request.operands()) {
   2708         apply(operand);
   2709       }
   2710       break;
   2711     }
   2712 
   2713     case OpRequest::kHostComputeRequest: {
   2714       const HostComputeRequest& hc_request =
   2715           request.request().host_compute_request();
   2716       for (const ComputationDataHandle& operand : hc_request.operands()) {
   2717         apply(operand);
   2718       }
   2719       break;
   2720     }
   2721 
   2722     case OpRequest::kDotRequest: {
   2723       const DotRequest& dot_request = request.request().dot_request();
   2724       apply(dot_request.rhs());
   2725       apply(dot_request.lhs());
   2726       break;
   2727     }
   2728 
   2729     case OpRequest::kUnaryOpRequest: {
   2730       const UnaryOpRequest& unary_op_request =
   2731           request.request().unary_op_request();
   2732       apply(unary_op_request.operand());
   2733       break;
   2734     }
   2735 
   2736     case OpRequest::kBinaryOpRequest: {
   2737       const BinaryOpRequest& binary_op_request =
   2738           request.request().binary_op_request();
   2739       apply(binary_op_request.rhs());
   2740       apply(binary_op_request.lhs());
   2741       break;
   2742     }
   2743 
   2744     case OpRequest::kReducePrecisionRequest: {
   2745       const ReducePrecisionRequest& reduce_precision_request =
   2746           request.request().reduce_precision_request();
   2747       apply(reduce_precision_request.operand());
   2748       break;
   2749     }
   2750 
   2751     case OpRequest::kTraceRequest: {
   2752       const TraceRequest& trace_request = request.request().trace_request();
   2753       apply(trace_request.operand());
   2754       break;
   2755     }
   2756 
   2757     case OpRequest::kSendRequest: {
   2758       const SendRequest& send_request = request.request().send_request();
   2759       apply(send_request.operand());
   2760       break;
   2761     }
   2762 
   2763     case OpRequest::kGatherRequest: {
   2764       const GatherRequest& gather_request = request.request().gather_request();
   2765       apply(gather_request.input());
   2766       apply(gather_request.gather_indices());
   2767       break;
   2768     }
   2769 
   2770     case OpRequest::OP_NOT_SET:
   2771       LOG(FATAL) << "OperationRequest doesn't contain a request";
   2772 
   2773     default:
   2774       LOG(FATAL) << "Unexpected request type: " << request.request().op_case();
   2775   }
   2776 }
   2777 
   2778 void ComputationLowerer::TraversePostorder(
   2779     const ComputationDataHandle& root,
   2780     std::unordered_map<int64, HloInstruction*>* visited,
   2781     const std::function<void(const ComputationDataHandle&)>& visit) {
   2782   // Stack containing {handle, enter} pairs. The 'enter' value describes whether
   2783   // we are entering or leaving 'handle'.
   2784   std::stack<std::pair<ComputationDataHandle, bool>> work;
   2785   work.push({root, true});
   2786   while (!work.empty()) {
   2787     ComputationDataHandle handle;
   2788     bool enter;
   2789     std::tie(handle, enter) = work.top();
   2790     work.pop();
   2791 
   2792     if (enter) {
   2793       // We are entering 'handle'. The first time we enter 'handle', we add it
   2794       // to 'visited' with a nullptr value. If 'handle' is already in 'visited',
   2795       // we do not visit it again. This algorithm only uses the presence of
   2796       // a handle in 'visited', but a map lets the same data structure also
   2797       // store the lowered HloInstructions.
   2798       if (visited->emplace(handle.handle(), nullptr).second) {
   2799         const OperationRequest& request =
   2800             session_computation_.requests().at(handle.handle());
   2801         // Push the corresponding 'leave' action onto the stack, followed by
   2802         // the operands.
   2803         work.push({handle, false});
   2804         ForEachOperand(request, [&work](const ComputationDataHandle& child) {
   2805           work.push({child, true});
   2806         });
   2807       }
   2808     } else {
   2809       // We are leaving 'handle'. We have visited the operands of 'handle', and
   2810       // can now visit 'handle' itself.
   2811       visit(handle);
   2812     }
   2813   }
   2814 }
   2815 
   2816 StatusOr<std::unique_ptr<HloComputation>> ComputationLowerer::Lower() {
   2817   // Map from ComputationDataHandle to HLO instruction. Serves as a record of
   2818   // which operations have been visited as well as a cache for looking up
   2819   // ComputationDataHandles as HloInstructions.
   2820   std::unordered_map<int64, HloInstruction*> instructions;
   2821 
   2822   TF_ASSIGN_OR_RETURN(const OperationRequest* root_request,
   2823                       GetRoot(version_, session_computation_));
   2824 
   2825   auto visit = [&](const ComputationDataHandle& handle) {
   2826     Visit(handle, &instructions);
   2827   };
   2828   TraversePostorder(root_request->output_handle(), &instructions, visit);
   2829   HloInstruction* hlo_root =
   2830       instructions.at(root_request->output_handle().handle());
   2831 
   2832   if (include_unreachable_instructions_) {
   2833     // Iterate through all computation data handles, and visit any unvisited
   2834     // operations.
   2835     for (int64 request_num = 1; request_num <= version_; ++request_num) {
   2836       TF_ASSIGN_OR_RETURN(const OperationRequest* request,
   2837                           LookUpRequest(request_num, session_computation_));
   2838       TraversePostorder(request->output_handle(), &instructions, visit);
   2839     }
   2840   }
   2841 
   2842   return hlo_builder_.Build(hlo_root);
   2843 }
   2844 
   2845 HloComputation* ComputationLowerer::ResolveComputation(
   2846     const ComputationHandle& handle,
   2847     VersionedComputationHandle::Version version) {
   2848   const VersionedComputationHandle checked_handle = {handle, version};
   2849   return hlo_resolver_(checked_handle);
   2850 }
   2851 
   2852 HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast(
   2853     HloInstruction* operand, const Shape& output_shape) {
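          // CreateBroadcastSequence may emit more than one instruction (e.g., a
          // kReshape followed by a kBroadcast when degenerate, size-1 dimensions
          // are involved), so it is handed the builder's add function.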
   2854   auto fadd = [this](std::unique_ptr<HloInstruction> x) {
   2855     return hlo_builder_.AddInstruction(std::move(x));
   2856   };
   2857   return fadd(
   2858       HloInstruction::CreateBroadcastSequence(output_shape, operand, fadd));
   2859 }
   2860 
   2861 void ComputationLowerer::Visit(
   2862     const ComputationDataHandle& handle,
   2863     std::unordered_map<int64, HloInstruction*>* instructions) {
   2864   CHECK_LE(handle.handle(), version_);
   2865   CHECK(instructions->at(handle.handle()) == nullptr);
   2866   const OperationRequest& request =
   2867       session_computation_.requests().at(handle.handle());
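          // Every instruction added below inherits the request's op metadata
          // and, if present, its sharding.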
   2868   auto add_instruction = [&](std::unique_ptr<HloInstruction> instruction) {
   2869     HloInstruction* hlo_instruction =
   2870         hlo_builder_.AddInstruction(std::move(instruction));
   2871     hlo_instruction->set_metadata(request.request().metadata());
   2872     if (request.request().has_sharding()) {
   2873       OpSharding op_sharding = request.request().sharding();
   2874       hlo_instruction->set_sharding(
   2875           HloSharding::FromProto(op_sharding).ValueOrDie());
   2876     }
   2877     return hlo_instruction;
   2878   };
   2879   auto lookup_instruction = [&](const ComputationDataHandle& handle) {
   2880     return instructions->at(handle.handle());
   2881   };
   2882   HloInstruction* hlo_instruction;
   2883   switch (request.request().op_case()) {
   2884     case OpRequest::kRngRequest: {
   2885       const RngRequest& rng_request = request.request().rng_request();
   2886       std::vector<HloInstruction*> parameters;
   2887       for (const ComputationDataHandle& param : rng_request.parameter()) {
   2888         parameters.push_back(lookup_instruction(param));
   2889       }
   2890       hlo_instruction = add_instruction(HloInstruction::CreateRng(
   2891           request.output_shape(), rng_request.distribution(), parameters));
   2892       break;
   2893     }
   2894 
   2895     case OpRequest::kConstantRequest: {
   2896       const ConstantRequest& constant_request =
   2897           request.request().constant_request();
   2898       hlo_instruction = add_instruction(HloInstruction::CreateConstant(
   2899           Literal::CreateFromProto(constant_request.literal())
   2900               .ConsumeValueOrDie()));
   2901       break;
   2902     }
   2903 
   2904     case OpRequest::kGetTupleElementRequest: {
   2905       const GetTupleElementRequest& get_tuple_element_request =
   2906           request.request().get_tuple_element_request();
   2907       HloInstruction* operand =
   2908           lookup_instruction(get_tuple_element_request.operand());
   2909       hlo_instruction = add_instruction(HloInstruction::CreateGetTupleElement(
   2910           request.output_shape(), operand, get_tuple_element_request.index()));
   2911       break;
   2912     }
   2913 
   2914     case OpRequest::kSliceRequest: {
   2915       const SliceRequest& slice_request = request.request().slice_request();
   2916       HloInstruction* operand = lookup_instruction(slice_request.operand());
   2917       hlo_instruction = add_instruction(HloInstruction::CreateSlice(
   2918           request.output_shape(), operand,
   2919           AsInt64Slice(slice_request.start_indices()),
   2920           AsInt64Slice(slice_request.limit_indices()),
   2921           AsInt64Slice(slice_request.strides())));
   2922       break;
   2923     }
   2924 
   2925     case OpRequest::kDynamicSliceRequest: {
   2926       const DynamicSliceRequest& dynamic_slice_request =
   2927           request.request().dynamic_slice_request();
   2928       HloInstruction* operand =
   2929           lookup_instruction(dynamic_slice_request.operand());
   2930       HloInstruction* start_indices =
   2931           lookup_instruction(dynamic_slice_request.start_indices());
   2932 
   2933       hlo_instruction = add_instruction(HloInstruction::CreateDynamicSlice(
   2934           request.output_shape(), operand, start_indices,
   2935           AsInt64Slice(dynamic_slice_request.slice_sizes())));
   2936       break;
   2937     }
   2938 
   2939     case OpRequest::kDynamicUpdateSliceRequest: {
   2940       const DynamicUpdateSliceRequest& dynamic_update_slice_request =
   2941           request.request().dynamic_update_slice_request();
   2942       HloInstruction* operand =
   2943           lookup_instruction(dynamic_update_slice_request.operand());
   2944       HloInstruction* update =
   2945           lookup_instruction(dynamic_update_slice_request.update());
   2946       HloInstruction* start_indices =
   2947           lookup_instruction(dynamic_update_slice_request.start_indices());
   2948       hlo_instruction =
   2949           add_instruction(HloInstruction::CreateDynamicUpdateSlice(
   2950               request.output_shape(), operand, update, start_indices));
   2951       break;
   2952     }
   2953 
   2954     case OpRequest::kConcatenateRequest: {
   2955       const ConcatenateRequest& concatenate_request =
   2956           request.request().concatenate_request();
   2957       std::vector<HloInstruction*> operands;
   2958       for (const ComputationDataHandle& handle :
   2959            concatenate_request.operands()) {
   2960         HloInstruction* operand = lookup_instruction(handle);
   2961         operands.push_back(operand);
   2962       }
   2963       hlo_instruction = add_instruction(HloInstruction::CreateConcatenate(
   2964           request.output_shape(), operands, concatenate_request.dimension()));
   2965       break;
   2966     }
   2967 
   2968     case OpRequest::kConvolveRequest: {
   2969       const ConvolveRequest& convolve_request =
   2970           request.request().convolve_request();
   2971       HloInstruction* lhs = lookup_instruction(convolve_request.lhs());
   2972       HloInstruction* rhs = lookup_instruction(convolve_request.rhs());
   2973       hlo_instruction = add_instruction(HloInstruction::CreateConvolve(
   2974           request.output_shape(), lhs, rhs, convolve_request.window(),
   2975           convolve_request.dimension_numbers()));
   2976       break;
   2977     }
   2978 
   2979     case OpRequest::kFftRequest: {
   2980       const FftRequest& fft_request = request.request().fft_request();
   2981       HloInstruction* operand = lookup_instruction(fft_request.operand());
   2982       hlo_instruction = add_instruction(HloInstruction::CreateFft(
   2983           request.output_shape(), operand, fft_request.fft_type(),
   2984           AsInt64Slice(fft_request.fft_length())));
   2985       break;
   2986     }
   2987 
   2988     case OpRequest::kDotRequest: {
   2989       const DotRequest& dot_request = request.request().dot_request();
   2990       HloInstruction* lhs = lookup_instruction(dot_request.lhs());
   2991       HloInstruction* rhs = lookup_instruction(dot_request.rhs());
   2992       hlo_instruction = add_instruction(HloInstruction::CreateDot(
   2993           request.output_shape(), lhs, rhs, dot_request.dimension_numbers()));
   2994       break;
   2995     }
   2996 
   2997     case OpRequest::kCrossReplicaSumRequest: {
   2998       const CrossReplicaSumRequest& cross_replica_sum_request =
   2999           request.request().cross_replica_sum_request();
   3000       HloInstruction* operand =
   3001           lookup_instruction(cross_replica_sum_request.operand());
   3002       hlo_instruction = add_instruction(HloInstruction::CreateCrossReplicaSum(
   3003           request.output_shape(), {operand}));
   3004       break;
   3005     }
   3006 
   3007     case OpRequest::kInfeedRequest: {
   3008       const InfeedRequest& infeed_request = request.request().infeed_request();
   3009       hlo_instruction = add_instruction(HloInstruction::CreateInfeed(
   3010           request.output_shape(), infeed_request.config()));
   3011       break;
   3012     }
   3013 
   3014     case OpRequest::kOutfeedRequest: {
   3015       const OutfeedRequest& outfeed_request =
   3016           request.request().outfeed_request();
   3017       HloInstruction* operand = lookup_instruction(outfeed_request.operand());
   3018       hlo_instruction = add_instruction(HloInstruction::CreateOutfeed(
   3019           outfeed_request.shape(), operand, outfeed_request.outfeed_config()));
   3020       break;
   3021     }
   3022 
   3023     case OpRequest::kMapRequest: {
   3024       const MapRequest& map_request = request.request().map_request();
   3025       std::vector<HloInstruction*> operands;
   3026       for (const ComputationDataHandle& handle : map_request.operands()) {
   3027         HloInstruction* operand = lookup_instruction(handle);
   3028         operands.push_back(operand);
   3029       }
   3030       CHECK_EQ(1, request.embedded_computation_versions_size());
   3031       VersionedComputationHandle::Version map_version =
   3032           request.embedded_computation_versions(0);
   3033       HloComputation* map_computation =
   3034           ResolveComputation(map_request.to_apply(), map_version);
   3035       hlo_instruction = add_instruction(HloInstruction::CreateMap(
   3036           request.output_shape(), operands, map_computation));
   3037       break;
   3038     }
   3039 
   3040     case OpRequest::kReduceRequest: {
   3041       const ReduceRequest& reduce_request = request.request().reduce_request();
   3042       HloInstruction* operand = lookup_instruction(reduce_request.operand());
   3043       HloInstruction* init_value =
   3044           lookup_instruction(reduce_request.init_value());
   3045       CHECK_EQ(1, request.embedded_computation_versions_size());
   3046       VersionedComputationHandle::Version reduce_version =
   3047           request.embedded_computation_versions(0);
   3048       HloComputation* reduce_computation =
   3049           ResolveComputation(reduce_request.to_apply(), reduce_version);
   3050       hlo_instruction = add_instruction(HloInstruction::CreateReduce(
   3051           request.output_shape(), operand, init_value,
   3052           AsInt64Slice(reduce_request.dimensions()), reduce_computation));
   3053       break;
   3054     }
   3055 
   3056     case OpRequest::kReduceWindowRequest: {
   3057       const ReduceWindowRequest& reduce_window_request =
   3058           request.request().reduce_window_request();
   3059       HloInstruction* operand =
   3060           lookup_instruction(reduce_window_request.operand());
   3061       HloInstruction* init_value =
   3062           lookup_instruction(reduce_window_request.init_value());
   3063       CHECK_EQ(1, request.embedded_computation_versions_size());
   3064       VersionedComputationHandle::Version reduce_window_version =
   3065           request.embedded_computation_versions(0);
   3066       HloComputation* reduce_window_computation = ResolveComputation(
   3067           reduce_window_request.to_apply(), reduce_window_version);
   3068       hlo_instruction = add_instruction(HloInstruction::CreateReduceWindow(
   3069           request.output_shape(), operand, init_value,
   3070           reduce_window_request.window(), reduce_window_computation));
   3071       break;
   3072     }
   3073 
   3074     case OpRequest::kSelectAndScatterRequest: {
   3075       const SelectAndScatterRequest& select_and_scatter_request =
   3076           request.request().select_and_scatter_request();
   3077       HloInstruction* operand =
   3078           lookup_instruction(select_and_scatter_request.operand());
   3079       HloInstruction* source =
   3080           lookup_instruction(select_and_scatter_request.source());
   3081       HloInstruction* init_value =
   3082           lookup_instruction(select_and_scatter_request.init_value());
   3083       CHECK_EQ(2, request.embedded_computation_versions_size());
   3084       VersionedComputationHandle::Version select_version =
   3085           request.embedded_computation_versions(0);
   3086       VersionedComputationHandle::Version scatter_version =
   3087           request.embedded_computation_versions(1);
   3088       HloComputation* select_computation = ResolveComputation(
   3089           select_and_scatter_request.select(), select_version);
   3090       HloComputation* scatter_computation = ResolveComputation(
   3091           select_and_scatter_request.scatter(), scatter_version);
   3092       hlo_instruction = add_instruction(HloInstruction::CreateSelectAndScatter(
   3093           request.output_shape(), operand, select_computation,
   3094           select_and_scatter_request.window(), source, init_value,
   3095           scatter_computation));
   3096       break;
   3097     }
   3098 
   3099     case OpRequest::kBatchNormTrainingRequest: {
   3100       const BatchNormTrainingRequest& batch_norm_training_request =
   3101           request.request().batch_norm_training_request();
   3102       HloInstruction* operand =
   3103           lookup_instruction(batch_norm_training_request.operand());
   3104       HloInstruction* scale =
   3105           lookup_instruction(batch_norm_training_request.scale());
   3106       HloInstruction* offset =
   3107           lookup_instruction(batch_norm_training_request.offset());
   3108 
   3109       hlo_instruction = add_instruction(HloInstruction::CreateBatchNormTraining(
   3110           request.output_shape(), operand, scale, offset,
   3111           batch_norm_training_request.epsilon(),
   3112           batch_norm_training_request.feature_index()));
   3113       break;
   3114     }
   3115 
   3116     case OpRequest::kBatchNormInferenceRequest: {
   3117       const BatchNormInferenceRequest& batch_norm_inference_request =
   3118           request.request().batch_norm_inference_request();
   3119       HloInstruction* operand =
   3120           lookup_instruction(batch_norm_inference_request.operand());
   3121       HloInstruction* scale =
   3122           lookup_instruction(batch_norm_inference_request.scale());
   3123       HloInstruction* offset =
   3124           lookup_instruction(batch_norm_inference_request.offset());
   3125       HloInstruction* mean =
   3126           lookup_instruction(batch_norm_inference_request.mean());
   3127       HloInstruction* variance =
   3128           lookup_instruction(batch_norm_inference_request.variance());
   3129 
   3130       hlo_instruction =
   3131           add_instruction(HloInstruction::CreateBatchNormInference(
   3132               request.output_shape(), operand, scale, offset, mean, variance,
   3133               batch_norm_inference_request.epsilon(),
   3134               batch_norm_inference_request.feature_index()));
   3135       break;
   3136     }
   3137 
   3138     case OpRequest::kBatchNormGradRequest: {
   3139       const BatchNormGradRequest& batch_norm_grad_request =
   3140           request.request().batch_norm_grad_request();
   3141 
   3142       HloInstruction* operand =
   3143           lookup_instruction(batch_norm_grad_request.operand());
   3144       HloInstruction* scale =
   3145           lookup_instruction(batch_norm_grad_request.scale());
   3146       HloInstruction* mean = lookup_instruction(batch_norm_grad_request.mean());
   3147       HloInstruction* variance =
   3148           lookup_instruction(batch_norm_grad_request.variance());
   3149       HloInstruction* grad_output =
   3150           lookup_instruction(batch_norm_grad_request.grad_output());
   3151 
   3152       hlo_instruction = add_instruction(HloInstruction::CreateBatchNormGrad(
   3153           request.output_shape(), operand, scale, mean, variance, grad_output,
   3154           batch_norm_grad_request.epsilon(),
   3155           batch_norm_grad_request.feature_index()));
   3156       break;
   3157     }
   3158 
   3159     case OpRequest::kBroadcastRequest: {
   3160       const BroadcastRequest& broadcast_request =
   3161           request.request().broadcast_request();
   3162       HloInstruction* operand = lookup_instruction(broadcast_request.operand());
   3163       std::vector<int64> broadcast_dimensions;
   3164       // The client-level broadcast instruction just appends dimensions on the
   3165       // left (adds lowest numbered dimensions). The HLO broadcast op is more
   3166       // flexible and can add new dimensions anywhere. The broadcast_dimensions
   3167       // maps operand dimensions to dimensions in the broadcast output, so
   3168       // to append dimensions on the left the broadcast_dimensions should just
   3169       // be the n highest dimension numbers of the output shape where n is
   3170       // the number of input dimensions.
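              // For example, broadcasting an f32[5,7] operand into an
              // f32[2,3,5,7] output uses broadcast_dimensions {2, 3}.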
   3171       broadcast_dimensions.reserve(ShapeUtil::Rank(operand->shape()));
   3172       for (int i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) {
   3173         broadcast_dimensions.push_back(i +
   3174                                        ShapeUtil::Rank(request.output_shape()) -
   3175                                        ShapeUtil::Rank(operand->shape()));
   3176       }
   3177       hlo_instruction = add_instruction(HloInstruction::CreateBroadcast(
   3178           request.output_shape(), operand, broadcast_dimensions));
   3179       break;
   3180     }
   3181 
   3182     case OpRequest::kReshapeRequest: {
   3183       const ReshapeRequest& reshape_request =
   3184           request.request().reshape_request();
   3185       HloInstruction* operand = lookup_instruction(reshape_request.operand());
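              // The client reshape carries a dimension permutation: lower it as
              // an explicit kTranspose (when the permutation is non-trivial)
              // followed by a plain kReshape to the output shape.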
   3186       HloInstruction* transposed;
   3187       if (IsIdentityPermutation(AsInt64Slice(reshape_request.dimensions()))) {
   3188         transposed = operand;
   3189       } else {
   3190         transposed = add_instruction(HloInstruction::CreateTranspose(
   3191             ShapeUtil::PermuteDimensions(
   3192                 InversePermutation(AsInt64Slice(reshape_request.dimensions())),
   3193                 operand->shape()),
   3194             operand, AsInt64Slice(reshape_request.dimensions())));
   3195       }
   3196       hlo_instruction = add_instruction(
   3197           HloInstruction::CreateReshape(request.output_shape(), transposed));
   3198       break;
   3199     }
   3200 
   3201     case OpRequest::kTransposeRequest: {
   3202       const TransposeRequest& transpose_request =
   3203           request.request().transpose_request();
   3204       HloInstruction* operand = lookup_instruction(transpose_request.operand());
   3205       hlo_instruction = add_instruction(HloInstruction::CreateTranspose(
   3206           ShapeUtil::PermuteDimensions(
   3207               InversePermutation(AsInt64Slice(transpose_request.dimensions())),
   3208               operand->shape()),
   3209           operand, AsInt64Slice(transpose_request.dimensions())));
   3210       break;
   3211     }
   3212 
   3213     case OpRequest::kReverseRequest: {
   3214       const ReverseRequest& reverse_request =
   3215           request.request().reverse_request();
   3216       HloInstruction* operand = lookup_instruction(reverse_request.operand());
   3217       hlo_instruction = add_instruction(HloInstruction::CreateReverse(
   3218           request.output_shape(), operand,
   3219           AsInt64Slice(reverse_request.dimensions())));
   3220       break;
   3221     }
   3222 
   3223     case OpRequest::kPadRequest: {
   3224       const PadRequest& pad_request = request.request().pad_request();
   3225       HloInstruction* operand = lookup_instruction(pad_request.operand());
   3226       HloInstruction* padding_value =
   3227           lookup_instruction(pad_request.padding_value());
   3228       hlo_instruction = add_instruction(HloInstruction::CreatePad(
   3229           request.output_shape(), operand, padding_value,
   3230           pad_request.padding_config()));
   3231       break;
   3232     }
   3233 
   3234     case OpRequest::kRecvRequest: {
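              // A client-level Recv lowers to a kRecv/kRecvDone pair; the
              // kRecvDone instruction produces the received data.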
   3235       const RecvRequest& recv_request = request.request().recv_request();
   3236       HloInstruction* recv = add_instruction(HloInstruction::CreateRecv(
   3237           request.output_shape(), recv_request.channel_handle().handle()));
   3238       hlo_instruction = add_instruction(HloInstruction::CreateRecvDone(recv));
   3239       break;
   3240     }
   3241 
   3242     case OpRequest::kParameterRequest: {
   3243       const ParameterRequest& parameter_request =
   3244           request.request().parameter_request();
   3245       hlo_instruction = add_instruction(HloInstruction::CreateParameter(
   3246           parameter_request.parameter(), request.output_shape(),
   3247           parameter_request.name()));
   3248       break;
   3249     }
   3250 
   3251     case OpRequest::kConvertRequest: {
   3252       const ConvertRequest& convert_request =
   3253           request.request().convert_request();
   3254       HloInstruction* operand = lookup_instruction(convert_request.operand());
   3255       hlo_instruction = add_instruction(
   3256           HloInstruction::CreateConvert(request.output_shape(), operand));
   3257       break;
   3258     }
   3259 
   3260     case OpRequest::kBitcastConvertRequest: {
   3261       const ConvertRequest& convert_request =
   3262           request.request().bitcast_convert_request();
   3263       HloInstruction* operand = lookup_instruction(convert_request.operand());
   3264       hlo_instruction = add_instruction(HloInstruction::CreateBitcastConvert(
   3265           request.output_shape(), operand));
   3266       break;
   3267     }
   3268 
   3269     case OpRequest::kWhileRequest: {
   3270       const WhileRequest& while_request = request.request().while_request();
   3271       CHECK_EQ(2, request.embedded_computation_versions_size());
   3272       VersionedComputationHandle::Version condition_version =
   3273           request.embedded_computation_versions(0);
   3274       HloComputation* condition =
   3275           ResolveComputation(while_request.condition(), condition_version);
   3276       VersionedComputationHandle::Version body_version =
   3277           request.embedded_computation_versions(1);
   3278       HloComputation* body =
   3279           ResolveComputation(while_request.body(), body_version);
   3280       HloInstruction* init = lookup_instruction(while_request.init());
   3281       hlo_instruction = add_instruction(HloInstruction::CreateWhile(
   3282           request.output_shape(), condition, body, init));
   3283       break;
   3284     }
   3285 
   3286     case OpRequest::kConditionalRequest: {
   3287       const ConditionalRequest& conditional_request =
   3288           request.request().conditional_request();
   3289       CHECK_EQ(2, request.embedded_computation_versions_size());
   3290       VersionedComputationHandle::Version true_computation_version =
   3291           request.embedded_computation_versions(0);
   3292       HloComputation* true_computation = ResolveComputation(
   3293           conditional_request.true_computation(), true_computation_version);
   3294       VersionedComputationHandle::Version false_computation_version =
   3295           request.embedded_computation_versions(1);
   3296       HloComputation* false_computation = ResolveComputation(
   3297           conditional_request.false_computation(), false_computation_version);
   3298       HloInstruction* predicate =
   3299           lookup_instruction(conditional_request.predicate());
   3300       HloInstruction* true_operand =
   3301           lookup_instruction(conditional_request.true_operand());
   3302       HloInstruction* false_operand =
   3303           lookup_instruction(conditional_request.false_operand());
   3304       hlo_instruction = add_instruction(HloInstruction::CreateConditional(
   3305           request.output_shape(), predicate, true_operand, true_computation,
   3306           false_operand, false_computation));
   3307       break;
   3308     }
   3309 
   3310     case OpRequest::kTernaryOpRequest: {
   3311       const TernaryOpRequest& ternary_op_request =
   3312           request.request().ternary_op_request();
   3313       HloInstruction* lhs = lookup_instruction(ternary_op_request.lhs());
   3314       HloInstruction* rhs = lookup_instruction(ternary_op_request.rhs());
   3315       HloInstruction* ehs = lookup_instruction(ternary_op_request.ehs());
   3316       auto hlo_opcode = TernaryOperationToHloOpcode(ternary_op_request.triop());
   3317 
   3318       if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) {
   3319         if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) {
   3320           // lhs side is being implicitly broadcast. Change to explicit.
   3321           lhs =
   3322               ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape());
   3323         }
   3324 
   3325         if (!ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) {
   3326           rhs =
   3327               ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape());
   3328         }
   3329 
   3330         if (!ShapeUtil::SameDimensions(request.output_shape(), ehs->shape())) {
   3331           ehs =
   3332               ImplicitBroadcastToExplicitBroadcast(ehs, request.output_shape());
   3333         }
   3334       }
   3335 
   3336       hlo_instruction = add_instruction(HloInstruction::CreateTernary(
   3337           request.output_shape(), hlo_opcode, lhs, rhs, ehs));
   3338       break;
   3339     }
   3340 
   3341     case OpRequest::kVariadicOpRequest: {
   3342       const VariadicOpRequest& variadic_op_request =
   3343           request.request().variadic_op_request();
   3344       std::vector<HloInstruction*> operands;
   3345       for (const ComputationDataHandle& handle :
   3346            variadic_op_request.operands()) {
   3347         HloInstruction* operand = lookup_instruction(handle);
   3348         operands.push_back(operand);
   3349       }
   3350       auto hlo_opcode =
   3351           VariadicOperationToHloOpcode(variadic_op_request.varop());
   3352       hlo_instruction = add_instruction(HloInstruction::CreateVariadic(
   3353           request.output_shape(), hlo_opcode, operands));
   3354       break;
   3355     }
   3356 
   3357     case OpRequest::kCallRequest: {
   3358       const CallRequest& call_request = request.request().call_request();
   3359       std::vector<HloInstruction*> operands;
   3360       for (const ComputationDataHandle& handle : call_request.operands()) {
   3361         operands.push_back(lookup_instruction(handle));
   3362       }
   3363       CHECK_EQ(1, request.embedded_computation_versions_size());
   3364       VersionedComputationHandle::Version call_version =
   3365           request.embedded_computation_versions(0);
   3366       HloComputation* call_computation =
   3367           ResolveComputation(call_request.to_apply(), call_version);
   3368       hlo_instruction = add_instruction(HloInstruction::CreateCall(
   3369           request.output_shape(), operands, call_computation));
   3370       break;
   3371     }
   3372 
   3373     case OpRequest::kCustomCallRequest: {
   3374       const CustomCallRequest& cc_request =
   3375           request.request().custom_call_request();
   3376       std::vector<HloInstruction*> operands;
   3377       for (const ComputationDataHandle& operand : cc_request.operands()) {
   3378         operands.push_back(lookup_instruction(operand));
   3379       }
   3380       hlo_instruction = add_instruction(HloInstruction::CreateCustomCall(
   3381           cc_request.shape(), operands, cc_request.call_target_name()));
   3382       break;
   3383     }
   3384 
   3385     case OpRequest::kHostComputeRequest: {
   3386       const HostComputeRequest& host_compute_request =
   3387           request.request().host_compute_request();
   3388       std::vector<HloInstruction*> operands;
   3389       for (const ComputationDataHandle& operand :
   3390            host_compute_request.operands()) {
   3391         operands.push_back(lookup_instruction(operand));
   3392       }
   3393       auto output_shape = host_compute_request.shape();
   3394       auto channel_name = host_compute_request.channel_name();
   3395       auto cost_estimate_ns = host_compute_request.cost_estimate_ns();
   3396       hlo_instruction = add_instruction(HloInstruction::CreateHostCompute(
   3397           output_shape, operands, channel_name, cost_estimate_ns));
   3398       break;
   3399     }
   3400 
   3401     case OpRequest::kUnaryOpRequest: {
   3402       const UnaryOpRequest& unary_op_request =
   3403           request.request().unary_op_request();
   3404       HloInstruction* operand = lookup_instruction(unary_op_request.operand());
   3405       auto hlo_opcode = UnaryOperationToHloOpcode(unary_op_request.unop());
   3406       hlo_instruction = add_instruction(HloInstruction::CreateUnary(
   3407           request.output_shape(), hlo_opcode, operand));
   3408       break;
   3409     }
   3410 
   3411     case OpRequest::kBinaryOpRequest: {
   3412       const BinaryOpRequest& binary_op_request =
   3413           request.request().binary_op_request();
   3414       HloInstruction* lhs = lookup_instruction(binary_op_request.lhs());
   3415       HloInstruction* rhs = lookup_instruction(binary_op_request.rhs());
   3416       auto hlo_opcode = BinaryOperationToHloOpcode(binary_op_request.binop());
   3417       if (binary_op_request.broadcast_dimensions_size() > 0 &&
   3418           ShapeUtil::Rank(lhs->shape()) != ShapeUtil::Rank(rhs->shape())) {
   3419         // Emit a broadcast instruction to perform the "broadcast in dimension"
   3420         // operation.
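                // For example, lowering f32[2,3] + f32[3] with
                // broadcast_dimensions {1} first broadcasts the rank-1 operand
                // to f32[2,3] along dimension 1, then emits a rank-matched
                // binary op.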
   3421         HloInstruction* operand_to_broadcast =
   3422             ShapeUtil::Rank(lhs->shape()) < ShapeUtil::Rank(rhs->shape()) ? lhs
   3423                                                                           : rhs;
   3424         CHECK_EQ(ShapeUtil::Rank(operand_to_broadcast->shape()),
   3425                  binary_op_request.broadcast_dimensions().size());
   3426 
   3427         // Construct the bounds of the shape of the kBroadcast instruction
   3428         // responsible for the in-dimension broadcast.
   3429         std::vector<int64> output_dimensions;
   3430         for (int64 size : request.output_shape().dimensions()) {
   3431           output_dimensions.push_back(size);
   3432         }
   3433         for (int64 operand_dim = 0;
   3434              operand_dim < ShapeUtil::Rank(operand_to_broadcast->shape());
   3435              ++operand_dim) {
   3436           int64 output_dim =
   3437               binary_op_request.broadcast_dimensions()[operand_dim];
   3438           output_dimensions[output_dim] =
   3439               operand_to_broadcast->shape().dimensions(operand_dim);
   3440         }
   3441 
   3442         Shape broadcast_shape = ShapeUtil::MakeShape(
   3443             operand_to_broadcast->shape().element_type(), output_dimensions);
   3444 
   3445         // The broadcast semantics of a client-level binary-op broadcast are
   3446         // identical to the HLO broadcast semantics, so the broadcast_dimensions
   3447         // field can be passed to the instruction builder unchanged.
   3448         HloInstruction* broadcasted_operand =
   3449             add_instruction(HloInstruction::CreateBroadcast(
   3450                 broadcast_shape, operand_to_broadcast,
   3451                 AsInt64Slice(binary_op_request.broadcast_dimensions())));
   3452 
   3453         lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs;
   3454         rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs;
   3455       }
   3456       if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) {
   3457         if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) {
   3458           // lhs side is being implicitly broadcast. Change to explicit.
   3459           lhs =
   3460               ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape());
   3461         }
   3462 
   3463         if (!ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) {
   3464           rhs =
   3465               ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape());
   3466         }
   3467       }
   3468       hlo_instruction = add_instruction(HloInstruction::CreateBinary(
   3469           request.output_shape(), hlo_opcode, lhs, rhs));
   3470       break;
   3471     }
   3472 
   3473     case OpRequest::kReducePrecisionRequest: {
   3474       const ReducePrecisionRequest& reduce_precision_request =
   3475           request.request().reduce_precision_request();
   3476       HloInstruction* operand =
   3477           lookup_instruction(reduce_precision_request.operand());
   3478       auto exponent_bits = reduce_precision_request.exponent_bits();
   3479       auto mantissa_bits = reduce_precision_request.mantissa_bits();
   3480       hlo_instruction = add_instruction(HloInstruction::CreateReducePrecision(
   3481           request.output_shape(), operand, exponent_bits, mantissa_bits));
   3482       break;
   3483     }
   3484 
   3485     case OpRequest::kTraceRequest: {
   3486       const TraceRequest& trace_request = request.request().trace_request();
   3487       HloInstruction* operand = lookup_instruction(trace_request.operand());
   3488       hlo_instruction = add_instruction(
   3489           HloInstruction::CreateTrace(trace_request.tag(), operand));
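              // Record the kTrace instruction on the traced operand.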
   3490       operand->set_tracing(hlo_instruction);
   3491       break;
   3492     }
   3493 
   3494     case OpRequest::kSendRequest: {
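              // Like Recv, a client-level Send lowers to a kSend/kSendDone pair.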
   3495       const SendRequest& send_request = request.request().send_request();
   3496       HloInstruction* operand = lookup_instruction(send_request.operand());
   3497       HloInstruction* send = add_instruction(HloInstruction::CreateSend(
   3498           operand, send_request.channel_handle().handle()));
   3499       hlo_instruction = add_instruction(HloInstruction::CreateSendDone(send));
   3500       break;
   3501     }
   3502 
   3503     case OpRequest::kGatherRequest: {
   3504       const GatherRequest& gather_request = request.request().gather_request();
   3505       HloInstruction* input_operand =
   3506           lookup_instruction(gather_request.input());
   3507       HloInstruction* gather_indices_operand =
   3508           lookup_instruction(gather_request.gather_indices());
   3509       std::vector<int64> window_bounds;
   3510       c_copy(gather_request.window_bounds(), std::back_inserter(window_bounds));
   3511       hlo_instruction = add_instruction(HloInstruction::CreateGather(
   3512           request.output_shape(), input_operand, gather_indices_operand,
   3513           gather_request.dimension_numbers(), window_bounds));
   3514       break;
   3515     }
   3516 
   3517     case OpRequest::OP_NOT_SET:
   3518       LOG(FATAL) << "OperationRequest doesn't contain a request";
   3519 
   3520     default:
   3521       LOG(FATAL) << "Unexpected request type: " << request.request().op_case();
   3522   }
   3523   (*instructions)[handle.handle()] = hlo_instruction;
   3524 }  // NOLINT(readability/fn_size)
   3525 
   3526 }  // namespace
   3527 
   3528 StatusOr<std::unique_ptr<HloComputation>> UserComputation::BuildHloComputation(
   3529     VersionedComputationHandle::Version version,
   3530     HloComputationResolver hlo_resolver, const DebugOptions& debug_options,
   3531     bool include_unreachable_instructions) const {
   3532   tensorflow::mutex_lock lock(mutex_);
   3533 
   3534   VLOG(2) << "Building HloComputation from UserComputation " << name_
   3535           << " at version " << version;
   3536   XLA_VLOG_LINES(3, session_computation_.DebugString());
   3537 
   3538   TF_ASSIGN_OR_RETURN(
   3539       std::unique_ptr<HloComputation> hlo_computation,
   3540       ComputationLowerer::Lower(
   3541           tensorflow::strings::StrCat(name(), ".v", version),
   3542           session_computation_, version, std::move(hlo_resolver), debug_options,
   3543           include_unreachable_instructions));
   3544 
   3545   return std::move(hlo_computation);
   3546 }
   3547 
   3548 }  // namespace xla
   3549