/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/computation_builder.h"

#include <stddef.h>
#include <array>
#include <numeric>
#include <set>
#include <vector>

#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"

namespace xla {

ComputationBuilder::ComputationBuilder(Client* client,
                                       const string& computation_name)
    : name_(computation_name), client_(client) {}

ComputationBuilder::~ComputationBuilder() {}

void ComputationBuilder::NoteError(const Status& error) {
  if (die_immediately_on_error_) {
    LOG(FATAL) << "error building computation: " << error;
  }

  if (first_error_.ok()) {
    first_error_ = error;
    first_error_backtrace_.CreateCurrent(/*skip_count=*/1);
  }
}

std::unique_ptr<ComputationBuilder> ComputationBuilder::CreateSubBuilder(
    const string& computation_name) {
  auto sub_builder = MakeUnique<ComputationBuilder>(client_, computation_name);
  sub_builder->parent_builder_ = this;
  sub_builder->die_immediately_on_error_ = die_immediately_on_error_;
  return sub_builder;
}
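
// Illustrative sketch (not part of the original source): sub-builders are how
// callers construct the small scalar computations consumed by ops such as
// Reduce and Map. Assuming a valid `client` and the Build() method declared
// in computation_builder.h:
//
//   ComputationBuilder b(client, "main");
//   auto sub = b.CreateSubBuilder("add");
//   auto x = sub->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
//   auto y = sub->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
//   sub->Add(x, y);
//   Computation add = sub->Build().ConsumeValueOrDie();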

Status ComputationBuilder::PrepareComputation() {
  TF_RETURN_IF_ERROR(first_error_);

  if (!computation_.IsNull()) {
    return Status::OK();
  }

  ComputationRequest request;
  request.set_name(name_);
  ComputationResponse response;

  VLOG(2) << "making computation request";
  Status s = client_->stub()->Computation(&request, &response);
  VLOG(2) << "done with computation request";

  if (!s.ok()) {
    NoteError(s);
    return first_error_;
  }

  computation_ = Computation(client_->stub(), response.computation());
  return Status::OK();
}

Status ComputationBuilder::RunOp(OpRequest* op_request,
                                 OpResponse* op_response) {
  TF_RETURN_IF_ERROR(first_error_);
  TF_RETURN_IF_ERROR(PrepareComputation());

  // Fill in fields that are set on every OpRequest.
  *op_request->mutable_computation() = computation_.handle();
  *op_request->mutable_metadata() = metadata_;
  if (sharding_) {
    *op_request->mutable_sharding() = *sharding_;
  }

  const string& op_name =
      OpRequest::descriptor()->FindFieldByNumber(op_request->op_case())->name();
  VLOG(2) << "running op request: " << op_name;
  Status status = client_->stub()->Op(op_request, op_response);
  VLOG(2) << "done with op request: " << op_name;
  return status;
}

void ComputationBuilder::RunOpAndNoteError(OpRequest* op_request) {
  OpResponse op_response;
  Status status = RunOp(op_request, &op_response);
  if (!status.ok()) {
    NoteError(status);
  }
}

ComputationDataHandle ComputationBuilder::RunOpAndParseResponse(
    OpRequest* op_request) {
  OpResponse op_response;
  Status status = RunOp(op_request, &op_response);
  if (!status.ok()) {
    NoteError(status);
    return ComputationDataHandle();
  }
  if (op_response.output().handle() == 0) {
    NoteError(InternalError("No output handle"));
    return ComputationDataHandle();
  }
  return op_response.output();
}

bool ComputationBuilder::MakeWindow(
    tensorflow::gtl::ArraySlice<int64> window_dimensions,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
    tensorflow::gtl::ArraySlice<int64> lhs_dilation,
    tensorflow::gtl::ArraySlice<int64> rhs_dilation, Window* window) {
  const auto verify_size = [&](const size_t x, const char* x_name) {
    if (x == 0 || x == window_dimensions.size()) {
      return true;
    } else {
      NoteError(InvalidArgument(
          "%s", tensorflow::strings::StrCat(
                    "Window has different number of window dimensions than of ",
                    x_name, "\nNumber of window dimensions: ",
                    window_dimensions.size(), "\nNumber of ", x_name, ": ", x,
                    "\n")
                    .c_str()));
      return false;
    }
  };
  if (!verify_size(window_strides.size(), "window strides") ||
      !verify_size(padding.size(), "padding entries") ||
      !verify_size(lhs_dilation.size(), "lhs dilation factors") ||
      !verify_size(rhs_dilation.size(), "rhs dilation factors")) {
    return false;
  }

  window->Clear();
  for (size_t i = 0; i < window_dimensions.size(); i++) {
    auto dim = window->add_dimensions();
    dim->set_size(window_dimensions[i]);
    if (!window_strides.empty()) {
      dim->set_stride(window_strides[i]);
    } else {
      dim->set_stride(1);
    }
    if (!padding.empty()) {
      dim->set_padding_low(padding[i].first);
      dim->set_padding_high(padding[i].second);
    } else {
      dim->set_padding_low(0);
      dim->set_padding_high(0);
    }
    if (!lhs_dilation.empty()) {
      dim->set_base_dilation(lhs_dilation[i]);
    } else {
      dim->set_base_dilation(1);
    }
    if (!rhs_dilation.empty()) {
      dim->set_window_dilation(rhs_dilation[i]);
    } else {
      dim->set_window_dilation(1);
    }
    dim->set_window_reversal(false);
  }
  return true;
}
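
// Illustrative note (not part of the original source): each per-parameter
// slice passed to MakeWindow may be empty, in which case the identity value
// is filled in for every dimension. For example:
//
//   Window w;
//   MakeWindow(/*window_dimensions=*/{3, 3}, /*window_strides=*/{},
//              /*padding=*/{}, /*lhs_dilation=*/{}, /*rhs_dilation=*/{}, &w);
//   // Each of w's two dimensions now has size 3, stride 1, padding_low and
//   // padding_high 0, base/window dilation 1, and no window reversal.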

ComputationDataHandle ComputationBuilder::ConstantLiteral(
    const Literal& literal) {
  OpRequest op_request;
  ConstantRequest* request = op_request.mutable_constant_request();
  *request->mutable_literal() = literal.ToProto();
  VLOG(3) << "created constant: " << request->literal().ShortDebugString();
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Parameter(int64 parameter_number,
                                                    const Shape& shape,
                                                    const string& name) {
  OpRequest op_request;
  ParameterRequest* request = op_request.mutable_parameter_request();
  *request->mutable_shape() = shape;
  request->set_parameter(parameter_number);
  request->set_name(name);
  return RunOpAndParseResponse(&op_request);
}

StatusOr<std::unique_ptr<Shape>> ComputationBuilder::GetShapeWithoutNoteError(
    const ComputationDataHandle& operand) {
  GetLocalShapeRequest request;
  *request.mutable_computation() = computation_.handle();
  *request.mutable_operand() = operand;
  GetLocalShapeResponse response;

  VLOG(2) << "making get-shape request";
  TF_RETURN_IF_ERROR(client_->stub()->GetLocalShape(&request, &response));
  VLOG(2) << "done with request";

  TF_RET_CHECK(response.has_shape());
  std::unique_ptr<Shape> shape = WrapUnique(response.release_shape());
  TF_RET_CHECK(shape != nullptr);
  return std::move(shape);
}

StatusOr<std::unique_ptr<Shape>> ComputationBuilder::GetShape(
    const ComputationDataHandle& operand) {
  TF_RETURN_IF_ERROR(first_error_);

  auto status_or_shape = GetShapeWithoutNoteError(operand);
  if (!status_or_shape.ok()) {
    NoteError(status_or_shape.status());
    return first_error_;
  }
  return status_or_shape;
}

StatusOr<ProgramShape> ComputationBuilder::GetProgramShape() {
  TF_RETURN_IF_ERROR(first_error_);

  GetComputationShapeRequest request;
  *request.mutable_computation() = computation_.handle();
  GetComputationShapeResponse response;

  VLOG(2) << "making get-program-shape-request";
  Status status = client_->stub()->GetComputationShape(&request, &response);
  VLOG(2) << "done with get-program-shape-request";

  if (!status.ok()) {
    first_error_ = status;
    return status;
  }

  TF_RET_CHECK(response.has_program_shape());
  return std::move(*response.mutable_program_shape());
}

ComputationDataHandle ComputationBuilder::CheckShape(
    const ComputationDataHandle& operand, const Shape& expected_shape) {
  std::unique_ptr<Shape> actual_shape = GetShape(operand).ConsumeValueOrDie();
  CHECK(ShapeUtil::Equal(expected_shape, *actual_shape))
      << "want " << ShapeUtil::HumanString(expected_shape) << " got "
      << ShapeUtil::HumanString(*actual_shape);
  return operand;
}

void ComputationBuilder::CheckSameShape(const ComputationDataHandle& lhs,
                                        const ComputationDataHandle& rhs) {
  std::unique_ptr<Shape> lhs_shape = GetShape(lhs).ConsumeValueOrDie();
  std::unique_ptr<Shape> rhs_shape = GetShape(rhs).ConsumeValueOrDie();
  VLOG(2) << "checking " << ShapeUtil::HumanString(*lhs_shape) << " equals "
          << ShapeUtil::HumanString(*rhs_shape);
  CHECK(ShapeUtil::Equal(*lhs_shape, *rhs_shape))
      << "lhs " << ShapeUtil::HumanString(*lhs_shape) << " rhs "
      << ShapeUtil::HumanString(*rhs_shape);
}

ComputationDataHandle ComputationBuilder::Slice(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> start_indices,
    tensorflow::gtl::ArraySlice<int64> limit_indices,
    tensorflow::gtl::ArraySlice<int64> strides) {
  OpRequest op_request;
  SliceRequest* request = op_request.mutable_slice_request();
  *request->mutable_operand() = operand;
  for (int64 index : start_indices) {
    request->add_start_indices(index);
  }
  for (int64 index : limit_indices) {
    request->add_limit_indices(index);
  }
  for (int64 index : strides) {
    request->add_strides(index);
  }
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::SliceInDim(
    const ComputationDataHandle& operand, int64 start_index, int64 limit_index,
    int64 stride, int64 dimno) {
  StatusOr<std::unique_ptr<Shape>> shape_status = GetShape(operand);
  if (!shape_status.ok()) {
    NoteError(shape_status.status());
    return ComputationDataHandle{};
  }
  const Shape& shape = *shape_status.ValueOrDie();
  std::vector<int64> starts(ShapeUtil::Rank(shape), 0);
  std::vector<int64> limits(shape.dimensions().begin(),
                            shape.dimensions().end());
  std::vector<int64> strides(ShapeUtil::Rank(shape), 1);
  starts[dimno] = start_index;
  limits[dimno] = limit_index;
  strides[dimno] = stride;
  return Slice(operand, starts, limits, strides);
}
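
// Worked example (illustrative, not part of the original source): for an
// f32[5,7] operand, SliceInDim(operand, /*start_index=*/2, /*limit_index=*/5,
// /*stride=*/1, /*dimno=*/1) expands to
// Slice(operand, {0, 2}, {5, 5}, {1, 1}) and yields an f32[5,3] result;
// every dimension other than `dimno` is taken in full.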

ComputationDataHandle ComputationBuilder::DynamicSlice(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& start_indices,
    tensorflow::gtl::ArraySlice<int64> slice_sizes) {
  OpRequest op_request;
  DynamicSliceRequest* request = op_request.mutable_dynamic_slice_request();
  *request->mutable_operand() = operand;
  *request->mutable_start_indices() = start_indices;
  for (int64 index : slice_sizes) {
    request->add_slice_sizes(index);
  }
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::DynamicUpdateSlice(
    const ComputationDataHandle& operand, const ComputationDataHandle& update,
    const ComputationDataHandle& start_indices) {
  OpRequest op_request;
  DynamicUpdateSliceRequest* request =
      op_request.mutable_dynamic_update_slice_request();
  *request->mutable_operand() = operand;
  *request->mutable_update() = update;
  *request->mutable_start_indices() = start_indices;
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::ConcatInDim(
    tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
    int64 dimension) {
  OpRequest op_request;
  ConcatenateRequest* request = op_request.mutable_concatenate_request();
  for (const ComputationDataHandle& operand : operands) {
    *request->add_operands() = operand;
  }
  request->set_dimension(dimension);
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Broadcast(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
  OpRequest op_request;
  BroadcastRequest* request = op_request.mutable_broadcast_request();
  *request->mutable_operand() = operand;
  for (int64 size : broadcast_sizes) {
    request->add_broadcast_sizes(size);
  }
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Pad(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& padding_value,
    const PaddingConfig& padding_config) {
  OpRequest op_request;
  PadRequest* request = op_request.mutable_pad_request();
  *request->mutable_operand() = operand;
  *request->mutable_padding_value() = padding_value;
  *request->mutable_padding_config() = padding_config;
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Reshape(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> dimensions,
    tensorflow::gtl::ArraySlice<int64> new_sizes) {
  OpRequest op_request;
  ReshapeRequest* request = op_request.mutable_reshape_request();
  *request->mutable_operand() = operand;
  for (int64 dimension : dimensions) {
    request->add_dimensions(dimension);
  }
  for (int64 new_size : new_sizes) {
    request->add_new_sizes(new_size);
  }
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Reshape(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> new_sizes) {
  if (!first_error_.ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> shape = GetShape(operand);
  if (!shape.ok()) {
    return ComputationDataHandle();
  }
  std::vector<int64> dimensions(shape.ValueOrDie()->dimensions().size());
  std::iota(dimensions.begin(), dimensions.end(), 0);
  return Reshape(operand, dimensions, new_sizes);
}
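
// Worked example (illustrative, not part of the original source): this
// overload keeps the operand's dimension order. For an f32[2,3] operand,
// Reshape(operand, {6}) is equivalent to
// Reshape(operand, /*dimensions=*/{0, 1}, /*new_sizes=*/{6}).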

ComputationDataHandle ComputationBuilder::Collapse(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> dims_to_collapse) {
  if (!first_error_.ok()) {
    return ComputationDataHandle();
  }

  // Out-of-order collapse is not supported.
  // Check that the collapsed dimensions are in order and consecutive.
  for (tensorflow::gtl::ArraySlice<int64>::size_type i = 1;
       i < dims_to_collapse.size(); ++i) {
    if (dims_to_collapse[i] - 1 != dims_to_collapse[i - 1]) {
      NoteError(InvalidArgument(
          "Collapsed dimensions are not in order and consecutive."));
      return ComputationDataHandle();
    }
  }

  // Create a new sizes vector from the old shape, replacing the collapsed
  // dimensions by the product of their sizes.
  StatusOr<std::unique_ptr<Shape>> shape_or_status = GetShape(operand);
  if (!shape_or_status.ok()) {
    return ComputationDataHandle();
  }
  std::unique_ptr<Shape> original_shape = shape_or_status.ConsumeValueOrDie();

  VLOG(3) << "original shape: " << ShapeUtil::HumanString(*original_shape);
  VLOG(3) << "dims to collapse: "
          << tensorflow::str_util::Join(dims_to_collapse, ",");

  if (dims_to_collapse.size() <= 1) {
    // Nothing to collapse: return the operand rather than enqueueing a
    // trivial reshape.
    return operand;
  }

  std::vector<int64> new_sizes;
  for (int i = 0; i < ShapeUtil::Rank(*original_shape); ++i) {
    if (i <= dims_to_collapse.front() || i > dims_to_collapse.back()) {
      new_sizes.push_back(original_shape->dimensions(i));
    } else {
      new_sizes.back() *= original_shape->dimensions(i);
    }
  }

  VLOG(3) << "new sizes: [" << tensorflow::str_util::Join(new_sizes, ",")
          << "]";

  return Reshape(operand, new_sizes);
}
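
// Worked example (illustrative, not part of the original source): for an
// f32[2,3,4] operand, Collapse(operand, {1, 2}) builds
// new_sizes = {2, 3 * 4} and reshapes to f32[2,12]. A call such as
// Collapse(operand, {0, 2}) is rejected because the dimensions are not
// consecutive.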

void ComputationBuilder::Trace(const string& tag,
                               const ComputationDataHandle& operand) {
  OpRequest op_request;
  TraceRequest* request = op_request.mutable_trace_request();
  request->set_tag(tag);
  *request->mutable_operand() = operand;
  RunOpAndNoteError(&op_request);
}

ComputationDataHandle ComputationBuilder::Select(
    const ComputationDataHandle& pred, const ComputationDataHandle& on_true,
    const ComputationDataHandle& on_false) {
  return TernaryOp(TRIOP_SELECT, pred, on_true, on_false);
}

ComputationDataHandle ComputationBuilder::Tuple(
    tensorflow::gtl::ArraySlice<ComputationDataHandle> elements) {
  OpRequest op_request;
  VariadicOpRequest* request = op_request.mutable_variadic_op_request();
  request->set_varop(VAROP_TUPLE);
  for (const ComputationDataHandle& operand : elements) {
    *request->add_operands() = operand;
  }
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::GetTupleElement(
    const ComputationDataHandle& tuple_data, int64 index) {
  OpRequest op_request;
  GetTupleElementRequest* request =
      op_request.mutable_get_tuple_element_request();
  *request->mutable_operand() = tuple_data;
  request->set_index(index);
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Eq(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_EQ, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Ne(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_NE, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Ge(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_GE, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Gt(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_GT, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Le(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_LE, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Lt(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_LT, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Dot(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) {
  StatusOr<std::unique_ptr<Shape>> lhs_shape_or_status = GetShape(lhs);
  if (!lhs_shape_or_status.ok()) {
    return ComputationDataHandle();
  }
  std::unique_ptr<Shape> lhs_shape = lhs_shape_or_status.ConsumeValueOrDie();

  DotDimensionNumbers dimension_numbers;
  dimension_numbers.add_lhs_contracting_dimensions(
      lhs_shape->dimensions_size() == 1 ? 0 : 1);
  dimension_numbers.add_rhs_contracting_dimensions(0);
  return DotGeneral(lhs, rhs, dimension_numbers);
}
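
// Worked example (illustrative, not part of the original source): Dot
// contracts the last dimension of `lhs` (dimension 0 for a rank-1 lhs,
// dimension 1 otherwise) with dimension 0 of `rhs`. So an f32[2,3] lhs
// dotted with an f32[3,5] rhs produces an f32[2,5] result, and an f32[3]
// lhs dotted with an f32[3,5] rhs produces an f32[5] result.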

ComputationDataHandle ComputationBuilder::DotGeneral(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    const DotDimensionNumbers& dimension_numbers) {
  OpRequest op_request;
  DotRequest* request = op_request.mutable_dot_request();
  *request->mutable_lhs() = lhs;
  *request->mutable_rhs() = rhs;
  *request->mutable_dimension_numbers() = dimension_numbers;
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Conv(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
  return ConvWithGeneralDimensions(
      lhs, rhs, window_strides, padding,
      CreateDefaultConvDimensionNumbers(window_strides.size()));
}
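
// Illustrative sketch (not part of the original source), assuming the default
// convolution dimension numbers place the batch dimension at 0, the feature
// dimension at 1, and spatial dimensions thereafter: a 2-D convolution of an
// f32[16,3,28,28] input with an f32[8,3,5,5] kernel might be enqueued as
//
//   auto out = b.Conv(input, kernel, /*window_strides=*/{1, 1},
//                     Padding::kSame);  // result: f32[16,8,28,28]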

ComputationDataHandle ComputationBuilder::ConvWithGeneralPadding(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
  return ConvGeneral(lhs, rhs, window_strides, padding,
                     CreateDefaultConvDimensionNumbers(window_strides.size()));
}

bool ComputationBuilder::VerifyConvolution(
    const Shape& lhs_shape, const Shape& rhs_shape,
    const ConvolutionDimensionNumbers& dimension_numbers) {
  if (ShapeUtil::Rank(lhs_shape) != ShapeUtil::Rank(rhs_shape)) {
    NoteError(
        InvalidArgument("Convolution arguments must have same number of "
                        "dimensions. Got: %s and %s",
                        ShapeUtil::HumanString(lhs_shape).c_str(),
                        ShapeUtil::HumanString(rhs_shape).c_str()));
    return false;
  }
  int num_dims = ShapeUtil::Rank(lhs_shape);
  if (num_dims < 2) {
    NoteError(InvalidArgument(
        "Convolution expects argument arrays with >= 2 dimensions. "
        "Got: %s and %s",
        ShapeUtil::HumanString(lhs_shape).c_str(),
        ShapeUtil::HumanString(rhs_shape).c_str()));
    return false;
  }
  int num_spatial_dims = num_dims - 2;

  const auto check_spatial_dimensions =
      [&](const char* const field_name,
          const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>&
              numbers) {
        if (numbers.size() != num_spatial_dims) {
          NoteError(InvalidArgument("Expected %d elements for %s, but got %d.",
                                    num_spatial_dims, field_name,
                                    numbers.size()));
          return false;
        }
        for (int i = 0; i < numbers.size(); ++i) {
          if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) {
            NoteError(
                InvalidArgument("Convolution %s[%d] is out of bounds: %lld",
                                field_name, i, numbers.Get(i)));
            return false;
          }
        }
        return true;
      };
  return check_spatial_dimensions(
             "input_spatial_dimensions",
             dimension_numbers.input_spatial_dimensions()) &&
         check_spatial_dimensions(
             "kernel_spatial_dimensions",
             dimension_numbers.kernel_spatial_dimensions()) &&
         check_spatial_dimensions(
             "output_spatial_dimensions",
             dimension_numbers.output_spatial_dimensions());
}

ComputationDataHandle ComputationBuilder::ConvWithGeneralDimensions(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
    const ConvolutionDimensionNumbers& dimension_numbers) {
  if (!first_error_.ok() || !PrepareComputation().ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> lhs_shape_or_status = GetShape(lhs);
  if (!lhs_shape_or_status.ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> rhs_shape_or_status = GetShape(rhs);
  if (!rhs_shape_or_status.ok()) {
    return ComputationDataHandle();
  }

  std::unique_ptr<Shape> lhs_shape = lhs_shape_or_status.ConsumeValueOrDie();
  std::unique_ptr<Shape> rhs_shape = rhs_shape_or_status.ConsumeValueOrDie();

  if (!VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers)) {
    NoteError(InternalError("failed to verify convolution"));
    return ComputationDataHandle();
  }

  std::vector<int64> base_area_dimensions(
      dimension_numbers.input_spatial_dimensions_size());
  for (std::vector<int64>::size_type i = 0; i < base_area_dimensions.size();
       ++i) {
    base_area_dimensions[i] =
        lhs_shape->dimensions(dimension_numbers.input_spatial_dimensions(i));
  }

  std::vector<int64> window_dimensions(
      dimension_numbers.kernel_spatial_dimensions_size());
  for (std::vector<int64>::size_type i = 0; i < window_dimensions.size(); ++i) {
    window_dimensions[i] =
        rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
  }

  return ConvGeneral(lhs, rhs, window_strides,
                     MakePadding(base_area_dimensions, window_dimensions,
                                 window_strides, padding),
                     dimension_numbers);
}

ComputationDataHandle ComputationBuilder::ConvGeneral(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
    const ConvolutionDimensionNumbers& dimension_numbers) {
  return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {},
                            dimension_numbers);
}

ComputationDataHandle ComputationBuilder::ConvGeneralDilated(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
    tensorflow::gtl::ArraySlice<int64> lhs_dilation,
    tensorflow::gtl::ArraySlice<int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers) {
  if (!first_error_.ok() || !PrepareComputation().ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> lhs_shape_or_status = GetShape(lhs);
  if (!lhs_shape_or_status.ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> rhs_shape_or_status = GetShape(rhs);
  if (!rhs_shape_or_status.ok()) {
    return ComputationDataHandle();
  }

  std::unique_ptr<Shape> lhs_shape = lhs_shape_or_status.ConsumeValueOrDie();
  std::unique_ptr<Shape> rhs_shape = rhs_shape_or_status.ConsumeValueOrDie();
  if (!VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers)) {
    // Error is recorded in VerifyConvolution.
    return ComputationDataHandle();
  }

  std::vector<int64> window_dimensions(
      dimension_numbers.kernel_spatial_dimensions_size());
  for (std::vector<int64>::size_type i = 0; i < window_dimensions.size(); ++i) {
    window_dimensions[i] =
        rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
  }

  OpRequest op_request;
  ConvolveRequest* request = op_request.mutable_convolve_request();
  *request->mutable_lhs() = lhs;
  *request->mutable_rhs() = rhs;
  *request->mutable_dimension_numbers() = dimension_numbers;

  if (!MakeWindow(window_dimensions, window_strides, padding, lhs_dilation,
                  rhs_dilation, request->mutable_window())) {
    // Error is recorded in MakeWindow.
    return ComputationDataHandle();
  }

  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Fft(
    const ComputationDataHandle& operand, const FftType fft_type,
    const tensorflow::gtl::ArraySlice<int64> fft_length) {
  OpRequest op_request;
  FftRequest* request = op_request.mutable_fft_request();
  *request->mutable_operand() = operand;
  request->set_fft_type(fft_type);
  for (int64 dim_len : fft_length) {
    request->add_fft_length(dim_len);
  }
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape,
                                                 const string& config) {
  OpRequest op_request;
  InfeedRequest* request = op_request.mutable_infeed_request();
  *request->mutable_shape() = shape;
  *request->mutable_config() = config;
  return RunOpAndParseResponse(&op_request);
}

void ComputationBuilder::Outfeed(const ComputationDataHandle& operand,
                                 const Shape& shape,
                                 const string& outfeed_config) {
  OpRequest op_request;
  OutfeedRequest* request = op_request.mutable_outfeed_request();
  request->set_outfeed_config(outfeed_config);
  *request->mutable_operand() = operand;
  *request->mutable_shape() = shape;
  RunOpAndNoteError(&op_request);
}

ComputationDataHandle ComputationBuilder::Call(
    const Computation& computation,
    tensorflow::gtl::ArraySlice<ComputationDataHandle> operands) {
  OpRequest op_request;
  CallRequest* request = op_request.mutable_call_request();
  *request->mutable_to_apply() = computation.handle();
  for (const ComputationDataHandle& operand : operands) {
    *request->add_operands() = operand;
  }
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::CustomCall(
    const string& call_target_name,
    tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
    const Shape& shape) {
  OpRequest op_request;
  CustomCallRequest* request = op_request.mutable_custom_call_request();
  request->set_call_target_name(call_target_name);
  for (const ComputationDataHandle& operand : operands) {
    *request->add_operands() = operand;
  }
  *request->mutable_shape() = shape;
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::HostCompute(
    tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
    const string& channel_name, int64 cost_estimate_ns, const Shape& shape) {
  OpRequest op_request;
  HostComputeRequest* request = op_request.mutable_host_compute_request();
  for (const ComputationDataHandle& operand : operands) {
    *request->add_operands() = operand;
  }
  *request->mutable_shape() = shape;
  request->set_channel_name(channel_name);
  request->set_cost_estimate_ns(cost_estimate_ns);
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Complex(
    const ComputationDataHandle& real, const ComputationDataHandle& imag,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_COMPLEX, real, imag, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Conj(
    const ComputationDataHandle& operand) {
  return Complex(Real(operand), Neg(Imag(operand)));
}

ComputationDataHandle ComputationBuilder::Add(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_ADD, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Sub(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_SUB, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Mul(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_MUL, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Div(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_DIV, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Rem(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_REM, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Max(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_MAX, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Min(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_MIN, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::And(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_AND, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Or(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_OR, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Not(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_NOT, operand);
}

ComputationDataHandle ComputationBuilder::ShiftLeft(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_SHIFT_LEFT, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::ShiftRightArithmetic(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_SHIFT_RIGHT_ARITHMETIC, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::ShiftRightLogical(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_SHIFT_RIGHT_LOGICAL, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Abs(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_ABS, operand);
}

ComputationDataHandle ComputationBuilder::Atan2(
    const ComputationDataHandle& y, const ComputationDataHandle& x,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_ATAN2, y, x, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Exp(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_EXP, operand);
}

ComputationDataHandle ComputationBuilder::Floor(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_FLOOR, operand);
}

ComputationDataHandle ComputationBuilder::Ceil(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_CEIL, operand);
}

ComputationDataHandle ComputationBuilder::Round(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_ROUND_NEAREST_AFZ, operand);
}

ComputationDataHandle ComputationBuilder::Log(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_LOG, operand);
}

ComputationDataHandle ComputationBuilder::Sign(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_SIGN, operand);
}

ComputationDataHandle ComputationBuilder::Cos(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_COS, operand);
}

ComputationDataHandle ComputationBuilder::Sin(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_SIN, operand);
}

ComputationDataHandle ComputationBuilder::Tanh(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_TANH, operand);
}

ComputationDataHandle ComputationBuilder::Real(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_REAL, operand);
}

ComputationDataHandle ComputationBuilder::Imag(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_IMAG, operand);
}

ComputationDataHandle ComputationBuilder::IsFinite(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_IS_FINITE, operand);
}

ComputationDataHandle ComputationBuilder::Transpose(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> permutation) {
  OpRequest op_request;
  TransposeRequest* request = op_request.mutable_transpose_request();
  *request->mutable_operand() = operand;
  for (int64 dimension : permutation) {
    request->add_dimensions(dimension);
  }
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Rev(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> dimensions) {
  OpRequest op_request;
  ReverseRequest* request = op_request.mutable_reverse_request();
  *request->mutable_operand() = operand;
  for (int64 dimension : dimensions) {
    request->add_dimensions(dimension);
  }
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Sort(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_SORT, operand);
}

ComputationDataHandle ComputationBuilder::SqrtF32(
    const ComputationDataHandle& operand) {
  return BinaryOp(BINOP_POW, operand, ConstantR0<float>(0.5),
                  /*broadcast_dimensions=*/{});
}

ComputationDataHandle ComputationBuilder::Pow(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_POW, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::ConvertElementType(
    const ComputationDataHandle& operand, PrimitiveType new_element_type) {
  if (!first_error_.ok() || !PrepareComputation().ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> shape_status = GetShape(operand);
  if (!shape_status.ok()) {
    return ComputationDataHandle();
  }
  std::unique_ptr<Shape> original = shape_status.ConsumeValueOrDie();

  OpRequest op_request;
  ConvertRequest* request = op_request.mutable_convert_request();
  *request->mutable_operand() = operand;
  request->set_new_element_type(new_element_type);
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::BitcastConvertType(
    const ComputationDataHandle& operand, PrimitiveType new_element_type) {
  if (!first_error_.ok() || !PrepareComputation().ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> shape_status = GetShape(operand);
  if (!shape_status.ok()) {
    return ComputationDataHandle();
  }
  std::unique_ptr<Shape> original = shape_status.ConsumeValueOrDie();

  OpRequest op_request;
  ConvertRequest* request = op_request.mutable_bitcast_convert_request();
  *request->mutable_operand() = operand;
  request->set_new_element_type(new_element_type);
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::SquareF32(
    const ComputationDataHandle& operand) {
  return BinaryOp(BINOP_POW, operand, ConstantR0<float>(2.0),
                  /*broadcast_dimensions=*/{});
}

ComputationDataHandle ComputationBuilder::ReciprocalF32(
    const ComputationDataHandle& operand) {
  return BinaryOp(BINOP_POW, operand, ConstantR0<float>(-1.0),
                  /*broadcast_dimensions=*/{});
}

ComputationDataHandle ComputationBuilder::Neg(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_NEGATE, operand);
}

ComputationDataHandle ComputationBuilder::Clamp(
    const ComputationDataHandle& min, const ComputationDataHandle& operand,
    const ComputationDataHandle& max) {
  return TernaryOp(TRIOP_CLAMP, min, operand, max);
}

ComputationDataHandle ComputationBuilder::UnaryOp(
    UnaryOperation unop, const ComputationDataHandle& operand) {
  OpRequest op_request;
  UnaryOpRequest* request = op_request.mutable_unary_op_request();
  request->set_unop(unop);
  *request->mutable_operand() = operand;
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::BinaryOp(
    BinaryOperation binop, const ComputationDataHandle& lhs,
    const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  OpRequest op_request;
  BinaryOpRequest* request = op_request.mutable_binary_op_request();
  request->set_binop(binop);
  *request->mutable_lhs() = lhs;
  *request->mutable_rhs() = rhs;
  for (int64 dimension : broadcast_dimensions) {
    request->add_broadcast_dimensions(dimension);
  }
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::RngOp(
    RandomDistribution distribution,
    tensorflow::gtl::ArraySlice<ComputationDataHandle> parameters,
    const Shape& shape) {
  OpRequest op_request;
  RngRequest* request = op_request.mutable_rng_request();
  request->set_distribution(distribution);
  for (const ComputationDataHandle& param : parameters) {
    *request->add_parameter() = param;
  }
  *request->mutable_shape() = shape;
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::TernaryOp(
    TernaryOperation triop, const ComputationDataHandle& lhs,
    const ComputationDataHandle& rhs, const ComputationDataHandle& ehs) {
  OpRequest op_request;
  TernaryOpRequest* request = op_request.mutable_ternary_op_request();
  request->set_triop(triop);
  *request->mutable_lhs() = lhs;
  *request->mutable_rhs() = rhs;
  *request->mutable_ehs() = ehs;
  return RunOpAndParseResponse(&op_request);
}

Status ComputationBuilder::SetReturnValue(
    const ComputationDataHandle& operand) {
  TF_RETURN_IF_ERROR(first_error_);

  SetReturnValueRequest request;
  *request.mutable_computation() = computation_.handle();
  *request.mutable_operand() = operand;

  SetReturnValueResponse response;

  VLOG(2) << "making set-handle-to-execute request";
  Status s = client_->stub()->SetReturnValue(&request, &response);
  VLOG(2) << "done with request";

  if (!s.ok()) {
    NoteError(s);
    return first_error_;
  }

  return Status::OK();
}

StatusOr<bool> ComputationBuilder::IsConstant(
    const ComputationDataHandle& operand, int64 num_parameters) {
  TF_RETURN_IF_ERROR(first_error_);

  IsConstantRequest request;
  *request.mutable_computation() = computation_.handle();
  *request.mutable_operand() = operand;
  request.set_num_parameters(num_parameters);
  IsConstantResponse response;

  VLOG(2) << "making IsConstant request";
  Status s = client_->stub()->IsConstant(&request, &response);
  VLOG(2) << "done with request";

  if (!s.ok()) {
    return s;
  }
  return response.is_constant();
}

StatusOr<std::unique_ptr<Literal>> ComputationBuilder::ComputeConstant(
    const ComputationDataHandle& operand, const Layout* output_layout,
    tensorflow::gtl::ArraySlice<Literal> parameters) {
  TF_RETURN_IF_ERROR(first_error_);

  ComputeConstantRequest request;
  *request.mutable_computation() = computation_.handle();
  *request.mutable_operand() = operand;
  if (output_layout != nullptr) {
    *request.mutable_output_layout() = *output_layout;
  }
  for (const auto& param : parameters) {
    *request.add_parameters() = param.ToProto();
  }

  ComputeConstantResponse response;

  VLOG(2) << "making compute-constant request";
  Status s = client_->stub()->ComputeConstant(&request, &response);
  VLOG(2) << "done with request";

  if (!s.ok()) {
    return s;
  }

  VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}";

  if (!response.has_literal()) {
    return InternalError(
        "no computed literal in the response to the ComputeConstant request");
  }
  return Literal::CreateFromProto(response.literal());
}
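
// Illustrative sketch (not part of the original source): ComputeConstant lets
// a client fold an expression that does not depend on unbound parameters,
// e.g.:
//
//   auto c = b.Add(b.ConstantR0<float>(1.0f), b.ConstantR0<float>(2.0f));
//   StatusOr<std::unique_ptr<Literal>> lit =
//       b.ComputeConstant(c, /*output_layout=*/nullptr);
//   // On success, *lit.ValueOrDie() holds the f32[] value 3.0.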

ComputationDataHandle ComputationBuilder::Map(
    tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
    const Computation& computation,
    tensorflow::gtl::ArraySlice<int64> dimensions,
    tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands) {
  OpRequest op_request;
  MapRequest* request = op_request.mutable_map_request();
  for (const ComputationDataHandle& operand : operands) {
    *request->add_operands() = operand;
  }
  *request->mutable_to_apply() = computation.handle();
  for (int64 dimension : dimensions) {
    request->add_dimensions(dimension);
  }
  for (const ComputationDataHandle& sop : static_operands) {
    *request->add_static_operands() = sop;
  }
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::RngNormal(
    const ComputationDataHandle& mu, const ComputationDataHandle& sigma,
    const Shape& shape) {
  return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape);
}

ComputationDataHandle ComputationBuilder::RngUniform(
    const ComputationDataHandle& a, const ComputationDataHandle& b,
    const Shape& shape) {
  return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape);
}

ComputationDataHandle ComputationBuilder::While(
    const Computation& condition, const Computation& body,
    const ComputationDataHandle& init) {
  OpRequest op_request;
  WhileRequest* request = op_request.mutable_while_request();
  *request->mutable_condition() = condition.handle();
  *request->mutable_body() = body.handle();
  *request->mutable_init() = init;
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Gather(
    const ComputationDataHandle& input,
    const ComputationDataHandle& gather_indices,
    const GatherDimensionNumbers& dimension_numbers,
    tensorflow::gtl::ArraySlice<int64> window_bounds) {
  OpRequest op_request;
  GatherRequest* gather_request = op_request.mutable_gather_request();
  *gather_request->mutable_input() = input;
  *gather_request->mutable_gather_indices() = gather_indices;
  *gather_request->mutable_dimension_numbers() = dimension_numbers;
  for (int64 window_bound : window_bounds) {
    gather_request->add_window_bounds(window_bound);
  }
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Conditional(
    const ComputationDataHandle& predicate,
    const ComputationDataHandle& true_operand,
    const Computation& true_computation,
    const ComputationDataHandle& false_operand,
    const Computation& false_computation) {
  OpRequest op_request;
  ConditionalRequest* request = op_request.mutable_conditional_request();
  *request->mutable_predicate() = predicate;
  *request->mutable_true_operand() = true_operand;
  *request->mutable_true_computation() = true_computation.handle();
  *request->mutable_false_operand() = false_operand;
  *request->mutable_false_computation() = false_computation.handle();
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Reduce(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& init_value, const Computation& computation,
    tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
  OpRequest op_request;
  ReduceRequest* request = op_request.mutable_reduce_request();
  *request->mutable_operand() = operand;
  *request->mutable_init_value() = init_value;
  for (int64 dimension : dimensions_to_reduce) {
    request->add_dimensions(dimension);
  }
  *request->mutable_to_apply() = computation.handle();
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::ReduceAll(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& init_value, const Computation& computation) {
  if (!first_error_.ok() || !PrepareComputation().ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> shape = GetShape(operand);
  if (!shape.ok()) {
    return ComputationDataHandle();
  }

  std::vector<int64> all_dimnos(ShapeUtil::Rank(*shape.ValueOrDie()));
  std::iota(all_dimnos.begin(), all_dimnos.end(), 0);
  return Reduce(operand, init_value, computation, all_dimnos);
}
   1300 
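// Editorial note on ReduceAll: it is Reduce over every dimension of the
// operand (0 through rank-1, per the iota above), so reducing an f32[2,3]
// operand with a scalar Add computation yields a single f32 scalar.
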
ComputationDataHandle ComputationBuilder::ReduceWindow(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& init_value, const Computation& computation,
    tensorflow::gtl::ArraySlice<int64> window_dimensions,
    tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
  if (!first_error_.ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> shape = GetShape(operand);
  if (!shape.ok()) {
    return ComputationDataHandle();
  }

  Status padding_valid =
      ValidatePaddingValues(AsInt64Slice(shape.ValueOrDie()->dimensions()),
                            window_dimensions, window_strides);
  if (!padding_valid.ok()) {
    // Record the error through NoteError so the backtrace and
    // die_immediately_on_error_ handling stay consistent with other ops,
    // rather than assigning to first_error_ directly.
    NoteError(padding_valid);
    return ComputationDataHandle();
  }

  std::vector<std::pair<int64, int64>> padding_values =
      MakePadding(AsInt64Slice(shape.ValueOrDie()->dimensions()),
                  window_dimensions, window_strides, padding);
  return ReduceWindowWithGeneralPadding(operand, init_value, computation,
                                        window_dimensions, window_strides,
                                        padding_values);
}

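// Illustrative usage of ReduceWindow (editor's sketch, not part of the
// original source): 2x2 max pooling with stride 2 over a rank-2 F32 operand,
// where `max` is a Computation applying Max to two scalar parameters:
//
//   auto pooled = builder.ReduceWindow(
//       operand,
//       builder.ConstantR0<float>(-std::numeric_limits<float>::infinity()),
//       max, /*window_dimensions=*/{2, 2}, /*window_strides=*/{2, 2},
//       Padding::kValid);
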
ComputationDataHandle ComputationBuilder::ReduceWindowWithGeneralPadding(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& init_value, const Computation& computation,
    tensorflow::gtl::ArraySlice<int64> window_dimensions,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
  OpRequest op_request;
  ReduceWindowRequest* request = op_request.mutable_reduce_window_request();
  *request->mutable_operand() = operand;
  *request->mutable_to_apply() = computation.handle();
  *request->mutable_init_value() = init_value;

  if (!MakeWindow(window_dimensions, window_strides, padding, {}, {},
                  request->mutable_window())) {
    NoteError(InternalError("failed to make window"));
    return ComputationDataHandle();
  }

  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::BatchNormTraining(
    const ComputationDataHandle& operand, const ComputationDataHandle& scale,
    const ComputationDataHandle& offset, float epsilon, int64 feature_index) {
  OpRequest op_request;
  BatchNormTrainingRequest* request =
      op_request.mutable_batch_norm_training_request();
  *request->mutable_operand() = operand;
  *request->mutable_scale() = scale;
  *request->mutable_offset() = offset;
  request->set_epsilon(epsilon);
  request->set_feature_index(feature_index);
  return RunOpAndParseResponse(&op_request);
}

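// Editorial note on BatchNormTraining: the result is a tuple of (normalized
// output, batch mean, batch variance); `feature_index` names the operand
// dimension holding the feature (channel) axis being normalized over.
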
ComputationDataHandle ComputationBuilder::BatchNormInference(
    const ComputationDataHandle& operand, const ComputationDataHandle& scale,
    const ComputationDataHandle& offset, const ComputationDataHandle& mean,
    const ComputationDataHandle& variance, float epsilon, int64 feature_index) {
  OpRequest op_request;
  BatchNormInferenceRequest* request =
      op_request.mutable_batch_norm_inference_request();
  *request->mutable_operand() = operand;
  *request->mutable_scale() = scale;
  *request->mutable_offset() = offset;
  *request->mutable_mean() = mean;
  *request->mutable_variance() = variance;
  request->set_epsilon(epsilon);
  request->set_feature_index(feature_index);
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::BatchNormGrad(
    const ComputationDataHandle& operand, const ComputationDataHandle& scale,
    const ComputationDataHandle& mean, const ComputationDataHandle& var,
    const ComputationDataHandle& grad_output, float epsilon,
    int64 feature_index) {
  OpRequest op_request;
  BatchNormGradRequest* request = op_request.mutable_batch_norm_grad_request();
  *request->mutable_operand() = operand;
  *request->mutable_scale() = scale;
  *request->mutable_mean() = mean;
  *request->mutable_variance() = var;
  *request->mutable_grad_output() = grad_output;
  request->set_epsilon(epsilon);
  request->set_feature_index(feature_index);
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::CrossReplicaSum(
    const ComputationDataHandle& operand) {
  OpRequest op_request;
  CrossReplicaSumRequest* request =
      op_request.mutable_cross_replica_sum_request();
  *request->mutable_operand() = operand;
  return RunOpAndParseResponse(&op_request);
}

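// Editorial note on CrossReplicaSum: the operand is summed element-wise
// across all participating replicas, and every replica receives the same
// summed result.
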
ComputationDataHandle ComputationBuilder::SelectAndScatter(
    const ComputationDataHandle& operand, const Computation& select,
    tensorflow::gtl::ArraySlice<int64> window_dimensions,
    tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
    const ComputationDataHandle& source,
    const ComputationDataHandle& init_value, const Computation& scatter) {
  if (!first_error_.ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> shape = GetShape(operand);
  if (!shape.ok()) {
    return ComputationDataHandle();
  }
  return SelectAndScatterWithGeneralPadding(
      operand, select, window_dimensions, window_strides,
      MakePadding(AsInt64Slice(shape.ValueOrDie()->dimensions()),
                  window_dimensions, window_strides, padding),
      source, init_value, scatter);
}

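// Editorial note on SelectAndScatter: this is the dual of ReduceWindow. For
// each window, `select` picks one element (e.g. an argmax), the matching
// value from `source` is scattered to that position, and collisions are
// combined with `scatter`. It is the usual building block for max-pooling
// gradients.
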
ComputationDataHandle ComputationBuilder::SelectAndScatterWithGeneralPadding(
    const ComputationDataHandle& operand, const Computation& select,
    tensorflow::gtl::ArraySlice<int64> window_dimensions,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
    const ComputationDataHandle& source,
    const ComputationDataHandle& init_value, const Computation& scatter) {
  OpRequest op_request;
  SelectAndScatterRequest* request =
      op_request.mutable_select_and_scatter_request();
  *request->mutable_operand() = operand;
  *request->mutable_select() = select.handle();
  *request->mutable_source() = source;
  *request->mutable_init_value() = init_value;
  *request->mutable_scatter() = scatter.handle();

  if (!MakeWindow(window_dimensions, window_strides, padding, {}, {},
                  request->mutable_window())) {
    NoteError(InternalError("failed to make window"));
    return ComputationDataHandle();
  }

  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::ReducePrecision(
    const ComputationDataHandle& operand, const int exponent_bits,
    const int mantissa_bits) {
  OpRequest op_request;
  ReducePrecisionRequest* request =
      op_request.mutable_reduce_precision_request();
  *request->mutable_operand() = operand;
  request->set_exponent_bits(exponent_bits);
  request->set_mantissa_bits(mantissa_bits);
  return RunOpAndParseResponse(&op_request);
}

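// Illustrative usage of ReducePrecision (editor's sketch, not part of the
// original source): the value keeps its F32 type but is rounded as if stored
// with the given bit widths. For example, emulating bfloat16 rounding, which
// uses 8 exponent bits and 7 mantissa bits:
//
//   auto rounded = builder.ReducePrecision(operand, /*exponent_bits=*/8,
//                                          /*mantissa_bits=*/7);
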
void ComputationBuilder::Send(const ComputationDataHandle& operand,
                              const ChannelHandle& handle) {
  OpRequest op_request;
  SendRequest* request = op_request.mutable_send_request();
  *request->mutable_operand() = operand;
  *request->mutable_channel_handle() = handle;
  *op_request.mutable_computation() = computation_.handle();
  RunOpAndNoteError(&op_request);
}

ComputationDataHandle ComputationBuilder::Recv(const Shape& shape,
                                               const ChannelHandle& handle) {
  OpRequest op_request;
  RecvRequest* request = op_request.mutable_recv_request();
  *request->mutable_shape() = shape;
  *request->mutable_channel_handle() = handle;
  return RunOpAndParseResponse(&op_request);
}

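// Editorial note on Send/Recv: the two ops form a matched pair keyed by the
// ChannelHandle. Send enqueues `operand` on the channel, and a Recv with the
// same handle (typically in another computation) produces a value of the
// requested `shape`.
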
Computation ComputationBuilder::BuildAndNoteError() {
  DCHECK(parent_builder_ != nullptr);
  auto build_status = Build();
  if (!build_status.ok()) {
    parent_builder_->NoteError(
        AddStatus(build_status.status(),
                  tensorflow::strings::StrCat("error from: ", name_)));
    return Computation();
  }
  return build_status.ConsumeValueOrDie();
}

StatusOr<Computation> ComputationBuilder::Build() {
  if (!first_error_.ok()) {
    string backtrace;
    first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
    return AppendStatus(first_error_, backtrace);
  }

  if (computation_.IsNull()) {
    return FailedPrecondition("no computation was built");
  }

  return {std::move(computation_)};
}

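// Illustrative usage of Build (editor's sketch, not part of the original
// source): a complete construction of a computation that adds one to its
// scalar parameter, assuming `client` is a connected xla::Client:
//
//   ComputationBuilder builder(client, "add_one");
//   auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
//   builder.Add(x, builder.ConstantR0<float>(1.0f));
//   StatusOr<Computation> computation = builder.Build();
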
/* static */ ConvolutionDimensionNumbers
ComputationBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
  ConvolutionDimensionNumbers dimension_numbers;
  dimension_numbers.set_input_batch_dimension(kConvBatchDimension);
  dimension_numbers.set_input_feature_dimension(kConvFeatureDimension);
  dimension_numbers.set_output_batch_dimension(kConvBatchDimension);
  dimension_numbers.set_output_feature_dimension(kConvFeatureDimension);
  dimension_numbers.set_kernel_output_feature_dimension(
      kConvKernelOutputDimension);
  dimension_numbers.set_kernel_input_feature_dimension(
      kConvKernelInputDimension);
  for (int i = 0; i < num_spatial_dims; ++i) {
    dimension_numbers.add_input_spatial_dimensions(i + 2);
    dimension_numbers.add_kernel_spatial_dimensions(i + 2);
    dimension_numbers.add_output_spatial_dimensions(i + 2);
  }
  return dimension_numbers;
}

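// Editorial note on CreateDefaultConvDimensionNumbers: spatial dimensions
// start at index 2 (see the loop above); assuming the batch and feature
// constants are 0 and 1, the 2D defaults correspond to an NCHW activation
// layout and an OIHW kernel layout.
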
/* static */ StatusOr<ConvolutionDimensionNumbers>
ComputationBuilder::CreateConvDimensionNumbers(
    int64 input_batch, int64 input_feature, int64 input_first_spatial,
    int64 input_second_spatial, int64 output_batch, int64 output_feature,
    int64 output_first_spatial, int64 output_second_spatial,
    int64 kernel_output_feature, int64 kernel_input_feature,
    int64 kernel_first_spatial, int64 kernel_second_spatial) {
  if (std::set<int64>({input_batch, input_feature, input_first_spatial,
                       input_second_spatial})
          .size() != 4) {
    return FailedPrecondition(
        "dimension numbers for the input are not unique: (%lld, %lld, %lld, "
        "%lld)",
        input_batch, input_feature, input_first_spatial, input_second_spatial);
  }
  if (std::set<int64>({kernel_output_feature, kernel_input_feature,
                       kernel_first_spatial, kernel_second_spatial})
          .size() != 4) {
    return FailedPrecondition(
        "dimension numbers for the weight are not unique: (%lld, %lld, %lld, "
        "%lld)",
        kernel_output_feature, kernel_input_feature, kernel_first_spatial,
        kernel_second_spatial);
  }
  if (std::set<int64>({output_batch, output_feature, output_first_spatial,
                       output_second_spatial})
          .size() != 4) {
    return FailedPrecondition(
        "dimension numbers for the output are not unique: (%lld, %lld, %lld, "
        "%lld)",
        output_batch, output_feature, output_first_spatial,
        output_second_spatial);
  }
  ConvolutionDimensionNumbers dimension_numbers;
  dimension_numbers.set_input_batch_dimension(input_batch);
  dimension_numbers.set_input_feature_dimension(input_feature);
  dimension_numbers.add_input_spatial_dimensions(input_first_spatial);
  dimension_numbers.add_input_spatial_dimensions(input_second_spatial);
  dimension_numbers.set_kernel_output_feature_dimension(kernel_output_feature);
  dimension_numbers.set_kernel_input_feature_dimension(kernel_input_feature);
  dimension_numbers.add_kernel_spatial_dimensions(kernel_first_spatial);
  dimension_numbers.add_kernel_spatial_dimensions(kernel_second_spatial);
  dimension_numbers.set_output_batch_dimension(output_batch);
  dimension_numbers.set_output_feature_dimension(output_feature);
  dimension_numbers.add_output_spatial_dimensions(output_first_spatial);
  dimension_numbers.add_output_spatial_dimensions(output_second_spatial);
  return dimension_numbers;
}

}  // namespace xla