Home | History | Annotate | Download | only in client
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_
     17 #define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_
     18 
     19 #include <map>
     20 #include <string>
     21 #include <type_traits>
     22 #include <utility>
     23 
     24 #include "absl/container/flat_hash_map.h"
     25 #include "absl/container/flat_hash_set.h"
     26 #include "absl/strings/string_view.h"
     27 #include "absl/types/span.h"
     28 #include "tensorflow/compiler/xla/client/padding.h"
     29 #include "tensorflow/compiler/xla/client/xla_computation.h"
     30 #include "tensorflow/compiler/xla/comparison_util.h"
     31 #include "tensorflow/compiler/xla/literal.h"
     32 #include "tensorflow/compiler/xla/literal_util.h"
     33 #include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
     34 #include "tensorflow/compiler/xla/service/hlo.pb.h"
     35 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     36 #include "tensorflow/compiler/xla/shape_util.h"
     37 #include "tensorflow/compiler/xla/status_macros.h"
     38 #include "tensorflow/compiler/xla/statusor.h"
     39 #include "tensorflow/compiler/xla/types.h"
     40 #include "tensorflow/compiler/xla/xla_data.pb.h"
     41 #include "tensorflow/core/platform/macros.h"
     42 #include "tensorflow/core/platform/stacktrace.h"
     43 #include "tensorflow/core/platform/types.h"
     44 
     45 namespace xla {
     46 
     47 class XlaBuilder;
     48 
     49 // This represents an instruction that has been enqueued using the XlaBuilder.
     50 // This is used to pass to subsequent computations that depends upon the
     51 // instruction as an operand.
     52 class XlaOp {
     53  public:
     54   XlaOp() : handle_(-1), builder_(nullptr) {
     55     static_assert(std::is_trivially_destructible<XlaOp>::value,
     56                   "XlaOp should be trivially destructible");
     57   }
     58   ~XlaOp() = default;
     59 
     60   XlaOp(const XlaOp& other) = default;
     61   XlaOp& operator=(const XlaOp& other) = default;
     62 
     63   // Precondition: !IsUninitialized().
     64   //
     65   // It's very common to do foo.builder()->bar().  Without this precondition, if
     66   // foo.builder() is null, the call to bar will segfault at some point possibly
     67   // deep in the callstack when we finally dereference `this`.  The precondition
     68   // lets us avoid this tricky-to-debug problem.
     69   XlaBuilder* builder() const {
     70     CHECK(builder_ != nullptr);
     71     return builder_;
     72   }
     73 
     74   // Returns true if the XlaOp represents valid, non-erroneous value.
     75   bool valid() const { return handle_ >= 0; }
     76 
     77   // Returns true if the XlaOp was created by the XlaOp() constructor and
     78   // not returned by a builder.
     79   bool IsUninitialized() const { return builder_ == nullptr; }
     80 
     81   bool IsIdenticalTo(const XlaOp& rhs) const {
     82     return handle_ == rhs.handle_ && builder_ == rhs.builder_;
     83   }
     84 
     85   friend std::ostream& operator<<(std::ostream& out, const XlaOp& op) {
     86     out << op.handle();
     87     return out;
     88   }
     89 
     90  private:
     91   explicit XlaOp(XlaBuilder* builder) : handle_(-1), builder_(builder) {}
     92   XlaOp(int64 handle, XlaBuilder* builder)
     93       : handle_(handle), builder_(builder) {}
     94 
     95   int64 handle() const { return handle_; }
     96 
     97   friend class XlaBuilder;
     98 
     99   // < 0 means "invalid handle".
    100   int64 handle_;
    101 
    102   // Not owned. Non-null for any handle returned by XlaBuilder, even if the
    103   // handle is invalid.
    104   XlaBuilder* builder_;
    105 };
    106 
    107 // Arithmetic operator overloads for the XlaOp type.
    108 XlaOp operator-(const XlaOp& x);
    109 XlaOp operator+(const XlaOp& x, const XlaOp& y);
    110 XlaOp operator-(const XlaOp& x, const XlaOp& y);
    111 XlaOp operator*(const XlaOp& x, const XlaOp& y);
    112 XlaOp operator/(const XlaOp& x, const XlaOp& y);
    113 XlaOp operator%(const XlaOp& x, const XlaOp& y);
    114 
    115 // Bitwise operator overloads for the XlaOp type.
    116 XlaOp operator~(const XlaOp& x);
    117 XlaOp operator&(const XlaOp& x, const XlaOp& y);
    118 XlaOp operator|(const XlaOp& x, const XlaOp& y);
    119 XlaOp operator^(const XlaOp& x, const XlaOp& y);
    120 XlaOp operator<<(const XlaOp& x, const XlaOp& y);
    121 // Performs a right arithmetic shift if 'x' is a signed type, otherwise performs
    122 // a right logical shift.
    123 XlaOp operator>>(const XlaOp& x, const XlaOp& y);
    124 
// We don't overload the relational operators (==, !=, <, <=, >, >=) because the
// semantics might be surprising, since their result types are usually 'bool'.
// Further, programmers may expect == to be a structural equality.
// We also choose not to overload any of the mutating operators (e.g., +=, -=)
// because the semantics might be misleading: XLA computations are immutable.
    130 
    131 // A convenient interface for building up computations.
    132 //
    133 // Thread-compatible.
    134 class XlaBuilder {
    135  public:
    136   // computation_name: name to use for the built computation.
    137   XlaBuilder(const string& computation_name);
    138 
    139   XlaBuilder(const XlaBuilder&) = delete;
    140   XlaBuilder& operator=(const XlaBuilder&) = delete;
    141 
    142   ~XlaBuilder();
    143 
    144   // Returns the computation name.
    145   const string& name() const { return name_; }
    146 
    147   // Sets OpMetadata that will be added to all instructions until cleared.
    148   //
    149   // OpMetadata is often applied to a series of XLA HLO instructions. As a
    150   // result, OpMetadata is set on the Computation Builder. All subsequent
    151   // instructions generated via this Computation Builder will have the same
    152   // OpMetadata attached until a call to ClearOpMetadata.
    153   void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; }
    154 
    155   // Clears the HloMetadata state.
    156   void ClearOpMetadata() { metadata_.Clear(); }
    157 
    158   // Sets an OpSharding that will be attached to all instructions until cleared.
    159   void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }
    160 
    161   // Clears the sharding. Ops will be sharded according to the default placement
    162   // policy.
    163   void ClearSharding() { sharding_ = absl::nullopt; }
    164 
    165   // Returns the OpSharding that will be attached to all instructions.
    166   const absl::optional<OpSharding>& sharding() const { return sharding_; }
    167 
    168   // Sets the builder to a mode where it will die immediately when an error is
    169   // encountered, rather than producing it in a deferred fashion when Build() is
    170   // called (which is the default).
    171   void set_die_immediately_on_error(bool enabled) {
    172     die_immediately_on_error_ = enabled;
    173   }
    174 
    175   // Default dimension numbers used for a 2D convolution.
    176   static constexpr int64 kConvBatchDimension = 0;
    177   static constexpr int64 kConvFeatureDimension = 1;
    178   static constexpr int64 kConvFirstSpatialDimension = 2;
    179   static constexpr int64 kConvSecondSpatialDimension = 3;
    180   static constexpr int64 kConvKernelOutputDimension = 0;
    181   static constexpr int64 kConvKernelInputDimension = 1;
    182   static constexpr int64 kConvKernelFirstSpatialDimension = 2;
    183   static constexpr int64 kConvKernelSecondSpatialDimension = 3;
    184 
    185   // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for
    186   // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for
    187   // the kernel operand
    188   // {output_feature, input_feature, height, width} = {0, 1, 2, 3}.
    189   static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers(
    190       int num_spatial_dims = 2);
    191 
    192   // Returns an error if the convolution dimension numbers have conflicts.
    193   static Status Validate(const ConvolutionDimensionNumbers& dnum);
    194 
    195   // Returns a new XlaBuilder whose resultant Computation is used only by this
    196   // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error
    197   // behavior as the parent.
    198   std::unique_ptr<XlaBuilder> CreateSubBuilder(const string& computation_name);
    199 
    200   // Builds the computation with the requested operations, or returns a non-ok
    201   // status. Note that all ops that have been enqueued will be moved to the
    202   // computation being returned. The root of the computation will be the last
    203   // added operation.
    204   //
  // `remove_dynamic_dimensions` tells the builder whether to remove the
  // dynamic dimension information from all ops.
  //
  // TODO(b/121223198): Delete `remove_dynamic_dimensions` and keep the
  // dynamic dimension information once the XLA backend can handle dynamic
  // dimensions.
    211   StatusOr<XlaComputation> Build(bool remove_dynamic_dimensions = true);
    212 
    213   // Overload of Build which specifies a particular root instruction for the
    214   // computation.
    215   StatusOr<XlaComputation> Build(XlaOp root,
    216                                  bool remove_dynamic_dimensions = true);
    217 
    218   // Builds the computation with the requested operations, or notes an error in
    219   // the parent XlaBuilder and returns an empty computation if building failed.
    220   // This function is intended to be used where the returned XlaComputation is
    221   // only used by the parent XlaBuilder and hence further operation on the
    222   // returned XlaComputation will simply be error'ed out if an error occurred
    223   // while building this computation. If the built computation is to be used by
    224   // a XlaBuilder other than the parent XlaBuilder then Build() should be used
    225   // instead.
    226   XlaComputation BuildAndNoteError();
    227 
    228   // Returns a subgraph that roots on the given root. If the root is not a
    229   // compile-time constant (see `IsConstant`), returns an error.
    230   //
    231   // This will copy the needed ops/computations to the subgraph.
    232   StatusOr<XlaComputation> BuildConstantSubGraph(const XlaOp& root_op);
    233 
    234   // Returns the first error that was encountered while building the
    235   // computation. When an error is encountered, by default we return a vacuous
    236   // XlaOp and inform the user of the error that occurred while
    237   // building the computation when they make a final call to Build().
    238   //
    239   // See also set_die_immediately_on_error().
    240   Status first_error() const { return first_error_; }
    241 
    242   // Returns the current status of the builder, complete with the stack trace
    243   // information.
    244   Status GetCurrentStatus() const;
    245 
    246   // Returns the shape of the given op.
    247   StatusOr<Shape> GetShape(const XlaOp& op) const;
    248 
    249   // Returns the (inferred) result for the current computation's shape. This
    250   // assumes the root instruction is the last added instruction.
    251   StatusOr<ProgramShape> GetProgramShape() const;
    252 
    253   // Returns the (inferred) result for the current computation's shape using the
    254   // given operation as the root.
    255   StatusOr<ProgramShape> GetProgramShape(XlaOp root) const;
    256 
    257   // Reports an error to the builder, by
    258   // * storing it internally and capturing a backtrace if it's the first error
    259   //   (this deferred value will be produced on the call to
    260   //    Build()/GetShape()/...)
    261   // * dying if die_immediately_on_error_ is true.
    262   // Returns an XlaOp with an invalid handle but a valid builder. This value can
    263   // be returned in place of a value in APIs that return an XlaOp.
    264   XlaOp ReportError(const Status& error);
    265 
    266   // A helper function that converts a StatusOr<XlaOp> into an XlaOp.
    267   // If the Status was an error, reports the error to builder and returns an
    268   // invalid XlaOp handle.
    269   XlaOp ReportErrorOrReturn(const StatusOr<XlaOp>& op);
    270 
    271   // A helper function that runs a function that returns a StatusOr<XlaOp> and
    272   // returns an XlaOp.
    273   XlaOp ReportErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator);
    274 
    275   // Returns true if 'operand' is a compile-time constant. A compile-time
    276   // constant does not depend on any parameters, or on stateful operators such
    277   // as `RngNormal` or `Infeed`.
    278   //
    279   // This tests whether a computation is a compile-time constant without
    280   // evaluating the computation.
    281   StatusOr<bool> IsConstant(const XlaOp& operand) const;
    282 
    283   // Sets up binding which indicates that the `target_dim_num` in the subshape
    284   // `target_param_index` of parameter `target_param_num` is a dynamic dimension
    285   // and its real dynamic size is represented by `dynamic_param_index` in
    286   // parameter `dynamic_param_num`.
    287   //
    288   // Note that this should be called before the dynamic parameters are used to
    289   // create other operations, otherwise created operations won't have the
    290   // dynamic dimensions information.
    291   //
    292   // TODO(b/119520625): Remove this API once we have more dynamic shape infra
    293   // ready.
    294   Status SetDynamicBinding(int64 dynamic_size_param_num,
    295                            ShapeIndex dynamic_size_param_index,
    296                            int64 target_param_num,
    297                            ShapeIndex target_param_index, int64 target_dim_num);
    298 
    299   // Adds a new input/output alias. Since the input/ouput shape information are
    300   // not available until the computation is built, and eventual error in the
    301   // arguments of this API will be detected only at computation Build() time.
    302   void SetUpAlias(const ShapeIndex& output_index, int64 param_number,
    303                   const ShapeIndex& param_index) {
    304     input_output_aliases_.push_back({output_index, param_number, param_index});
    305   }
    306 
    307   // Describes an input/output alias as inserted by the SetUpAlias() API.
    308   struct InputOutputAlias {
    309     // Specifies the index of the aliased buffer in the result tuple.
    310     ShapeIndex output_index;
    311     // Specifies the parameter containing the buffer to be aliased.
    312     int64 param_number;
    313     // Specifies the index of the aliased buffer in the parameter
    314     ShapeIndex param_index;
    315   };
    316 
    317  private:
  // Build helper that takes the id of the root operation.
    319   StatusOr<XlaComputation> Build(int64 root_id, bool remove_dynamic_dimensions);
    320 
    321   // Description for the methods below can be found in the corresponding public
    322   // functions section in this file.
    323 
    324   XlaOp Parameter(int64 parameter_number, const Shape& shape,
    325                   const string& name);
    326 
    327   XlaOp ConstantLiteral(const LiteralSlice& literal);
    328 
    329   XlaOp Broadcast(const XlaOp& operand,
    330                   absl::Span<const int64> broadcast_sizes);
    331 
    332   XlaOp BroadcastInDim(const XlaOp& operand,
    333                        const absl::Span<const int64> out_dim_size,
    334                        const absl::Span<const int64> broadcast_dimensions);
    335 
    336   XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
    337             const PaddingConfig& padding_config);
    338 
    339   XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> dimensions,
    340                 absl::Span<const int64> new_sizes);
    341 
    342   XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> new_sizes);
    343 
    344   XlaOp Collapse(const XlaOp& operand, absl::Span<const int64> dimensions);
    345 
    346   XlaOp Slice(const XlaOp& operand, absl::Span<const int64> start_indices,
    347               absl::Span<const int64> limit_indices,
    348               absl::Span<const int64> strides);
    349 
    350   XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index,
    351                    int64 stride, int64 dimno);
    352 
    353   ABSL_DEPRECATED("Use span-of-indices form instead")
    354   XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
    355                      absl::Span<const int64> slice_sizes);
    356   XlaOp DynamicSlice(const XlaOp& operand,
    357                      absl::Span<const XlaOp> start_indices,
    358                      absl::Span<const int64> slice_sizes);
    359 
    360   ABSL_DEPRECATED("Use span-of-indices form instead")
    361   XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
    362                            const XlaOp& start_indices);
    363   XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
    364                            absl::Span<const XlaOp> start_indices);
    365 
    366   XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64 dimension);
    367 
    368   void Trace(const string& tag, const XlaOp& operand);
    369 
    370   XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false);
    371 
    372   XlaOp Tuple(absl::Span<const XlaOp> elements);
    373 
    374   XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
    375 
    376   XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
    377             const PrecisionConfig* precision_config = nullptr);
    378 
    379   XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
    380                    const DotDimensionNumbers& dimension_numbers,
    381                    const PrecisionConfig* precision_config = nullptr);
    382 
    383   XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
    384              absl::Span<const int64> window_strides, Padding padding,
    385              int64 feature_group_count = 1, int64 batch_group_count = 1,
    386              const PrecisionConfig* precision_config = nullptr);
    387 
    388   XlaOp ConvWithGeneralPadding(
    389       const XlaOp& lhs, const XlaOp& rhs,
    390       absl::Span<const int64> window_strides,
    391       absl::Span<const std::pair<int64, int64>> padding,
    392       int64 feature_group_count = 1, int64 batch_group_count = 1,
    393       const PrecisionConfig* precision_config = nullptr);
    394 
    395   XlaOp ConvWithGeneralDimensions(
    396       const XlaOp& lhs, const XlaOp& rhs,
    397       absl::Span<const int64> window_strides, Padding padding,
    398       const ConvolutionDimensionNumbers& dimension_numbers,
    399       int64 feature_group_count = 1, int64 batch_group_count = 1,
    400       const PrecisionConfig* precision_config = nullptr);
    401 
    402   XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
    403                     absl::Span<const int64> window_strides,
    404                     absl::Span<const std::pair<int64, int64>> padding,
    405                     const ConvolutionDimensionNumbers& dimension_numbers,
    406                     int64 feature_group_count = 1, int64 batch_group_count = 1,
    407                     const PrecisionConfig* precision_config = nullptr);
    408 
    409   XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
    410                            absl::Span<const int64> window_strides,
    411                            absl::Span<const std::pair<int64, int64>> padding,
    412                            absl::Span<const int64> lhs_dilation,
    413                            absl::Span<const int64> rhs_dilation,
    414                            const ConvolutionDimensionNumbers& dimension_numbers,
    415                            int64 feature_group_count = 1,
    416                            int64 batch_group_count = 1,
    417                            const PrecisionConfig* precision_config = nullptr);
    418 
    419   XlaOp Fft(const XlaOp& operand, FftType fft_type,
    420             absl::Span<const int64> fft_length);
    421 
    422   XlaOp Infeed(const Shape& shape, const string& config = "");
    423   XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
    424                         const string& config = "");
    425 
    426   void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
    427                const string& outfeed_config);
    428   XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
    429                          const Shape& shape_with_layout,
    430                          const string& outfeed_config);
    431 
    432   XlaOp Call(const XlaComputation& computation,
    433              absl::Span<const XlaOp> operands);
    434 
    435   XlaOp CustomCall(
    436       const string& call_target_name, absl::Span<const XlaOp> operands,
    437       const Shape& shape_with_layout, const string& opaque,
    438       absl::optional<absl::Span<const Shape>> operand_shapes_with_layout);
    439 
    440   XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
    441                const XlaComputation& computation,
    442                absl::Span<const int64> dimensions_to_reduce);
    443 
    444   XlaOp Reduce(absl::Span<const XlaOp> operands,
    445                absl::Span<const XlaOp> init_values,
    446                const XlaComputation& computation,
    447                absl::Span<const int64> dimensions_to_reduce);
    448 
    449   XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
    450                   const XlaComputation& computation);
    451 
    452   XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
    453                      const XlaComputation& computation,
    454                      absl::Span<const int64> window_dimensions,
    455                      absl::Span<const int64> window_strides, Padding padding);
    456 
    457   XlaOp ReduceWindowWithGeneralPadding(
    458       const XlaOp& operand, const XlaOp& init_value,
    459       const XlaComputation& computation,
    460       absl::Span<const int64> window_dimensions,
    461       absl::Span<const int64> window_strides,
    462       absl::Span<const int64> base_dilations,
    463       absl::Span<const int64> window_dilations,
    464       absl::Span<const std::pair<int64, int64>> padding);
    465 
    466   XlaOp CrossReplicaSum(const XlaOp& operand,
    467                         absl::Span<const ReplicaGroup> replica_groups = {});
    468 
    469   XlaOp CrossReplicaSum(
    470       const XlaOp& operand, const XlaComputation& computation,
    471       absl::Span<const ReplicaGroup> replica_groups = {},
    472       const absl::optional<ChannelHandle>& channel_id = absl::nullopt);
    473 
    474   XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
    475                  int64 concat_dimension, int64 split_count,
    476                  const std::vector<ReplicaGroup>& replica_groups);
    477 
    478   XlaOp CollectivePermute(
    479       const XlaOp& operand,
    480       const std::vector<std::pair<int64, int64>>& source_target_pairs);
    481 
    482   XlaOp ReplicaId();
    483 
    484   XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
    485                          absl::Span<const int64> window_dimensions,
    486                          absl::Span<const int64> window_strides,
    487                          Padding padding, const XlaOp& source,
    488                          const XlaOp& init_value,
    489                          const XlaComputation& scatter);
    490 
    491   XlaOp SelectAndScatterWithGeneralPadding(
    492       const XlaOp& operand, const XlaComputation& select,
    493       absl::Span<const int64> window_dimensions,
    494       absl::Span<const int64> window_strides,
    495       absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
    496       const XlaOp& init_value, const XlaComputation& scatter);
    497 
    498   XlaOp Iota(const Shape& shape, int64 iota_dimension);
    499 
    500   XlaOp Iota(PrimitiveType type, int64 size);
    501 
    502   XlaOp ConvertElementType(const XlaOp& operand,
    503                            PrimitiveType new_element_type);
    504 
    505   XlaOp BitcastConvertType(const XlaOp& operand,
    506                            PrimitiveType new_element_type);
    507 
    508   XlaOp Transpose(const XlaOp& operand, absl::Span<const int64> permutation);
    509 
    510   XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions);
    511 
    512   ABSL_DEPRECATED("Use form with comparator computation instead")
    513   XlaOp Sort(const XlaOp& keys, absl::Span<const XlaOp> values = {},
    514              int64 dimension = -1);
    515   XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
    516              int64 dimension = -1, bool is_stable = false);
    517 
    518   XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
    519 
    520   XlaOp Map(absl::Span<const XlaOp> operands, const XlaComputation& computation,
    521             absl::Span<const int64> dimensions,
    522             absl::Span<const XlaOp> static_operands = {});
    523 
    524   XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape);
    525 
    526   XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);
    527 
    528   XlaOp While(const XlaComputation& condition, const XlaComputation& body,
    529               const XlaOp& init);
    530 
    531   XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
    532                     const XlaComputation& true_computation,
    533                     const XlaOp& false_operand,
    534                     const XlaComputation& false_computation);
    535 
    536   XlaOp Conditional(const XlaOp& branch_index,
    537                     absl::Span<const XlaComputation* const> branch_computations,
    538                     absl::Span<const XlaOp> branch_operands);
    539 
    540   XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
    541                         const int mantissa_bits);
    542 
    543   XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
    544                const GatherDimensionNumbers& dimension_numbers,
    545                absl::Span<const int64> slice_sizes);
    546 
    547   XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
    548                 const XlaOp& updates, const XlaComputation& update_computation,
    549                 const ScatterDimensionNumbers& dimension_numbers);
    550 
    551   void Send(const XlaOp& operand, const ChannelHandle& handle);
    552   XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
    553                       const ChannelHandle& handle);
    554 
    555   XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
    556                    const Shape& shape_with_layout, const ChannelHandle& handle);
    557 
    558   XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
    559                      const ChannelHandle& handle);
    560 
    561   XlaOp CreateToken();
    562 
    563   XlaOp AfterAll(absl::Span<const XlaOp> tokens);
    564 
    565   XlaOp Recv(const Shape& shape, const ChannelHandle& handle);
    566   XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
    567                       const ChannelHandle& handle);
    568 
    569   XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
    570                           const XlaOp& offset, float epsilon,
    571                           int64 feature_index);
    572 
    573   XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
    574                            const XlaOp& offset, const XlaOp& mean,
    575                            const XlaOp& variance, float epsilon,
    576                            int64 feature_index);
    577 
    578   XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
    579                       const XlaOp& batch_mean, const XlaOp& batch_var,
    580                       const XlaOp& grad_output, float epsilon,
    581                       int64 feature_index);
    582 
    583   XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension);
    584 
    585   StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr, HloOpcode opcode,
    586                                  absl::Span<const XlaOp> operands = {});
    587 
    588   void AddCalledComputation(const XlaComputation& computation,
    589                             HloInstructionProto* instr);
    590 
    591   StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const;
    592   StatusOr<const HloInstructionProto*> LookUpInstructionByHandle(
    593       int64 handle) const;
    594 
    595   // Internal helper method that does the building for an arbitrary unary op.
    596   XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand);
    597 
    598   // Internal helper method that does the building for an arbitrary binary op.
    599   // broadcast_dimensions specifies which dimensions to use for broadcasting
    600   // when the operation is between tensors of different ranks. The direction is
    601   // only used if opcode is kCompare.
    602   XlaOp BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
    603                  absl::Span<const int64> broadcast_dimensions,
    604                  absl::optional<ComparisonDirection> direction = absl::nullopt);
    605 
    606   // Internal helper method that does the building for an arbitrary ternary op.
    607   XlaOp TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs,
    608                   const XlaOp& ehs);
    609 
    610   XlaOp RngOp(RandomDistribution distribution,
    611               absl::Span<const XlaOp> parameters, const Shape& shape);
    612 
  // Internal helper that broadcasts `operand` to `shape`, mapping operand
  // dimension i onto output dimension broadcast_dimensions[i].
  // NOTE(review): presumably the same mapping as the public BroadcastInDim but
  // with the output shape already known — confirm.
  StatusOr<XlaOp> InDimBroadcast(const Shape& shape, const XlaOp& operand,
                                 absl::Span<const int64> broadcast_dimensions);

  // Internal helper method that creates a sequence of instructions that
  // performs an explicit broadcast of the operand to the target shape.
  StatusOr<XlaOp> AddBroadcastSequence(const Shape& output_shape,
                                       const XlaOp& operand);

  // Internal helper method for creating a Reshape op with the already inferred
  // shape.
  StatusOr<XlaOp> Reshape(const Shape& shape, const XlaOp& operand);

  // Returns the (inferred) program shape of the computation, treating the
  // instruction with id `root_id` as its root.
  StatusOr<ProgramShape> GetProgramShape(int64 root_id) const;

  // Returns the shape of each of `operands` (presumably in the same order), or
  // an error status if any shape cannot be determined.
  StatusOr<std::vector<Shape>> GetOperandShapes(
      absl::Span<const XlaOp> operands) const;
    631 
  // A visitor which checks whether an operation is a compile-time constant,
  // meaning that it doesn't depend on any parameters, or on any stateful
  // operation such as `RngNormal` or `Infeed`. The visitor walks the
  // computation starting at a given operation and sets is_constant to false iff
  // a parameter or stateful operation is encountered.
  //
  // `op_handle`  : handle of the operation to start the walk from.
  // `visited`    : handles already walked (presumably so shared sub-expressions
  //                are only visited once).
  // `is_constant`: in/out flag that is only ever cleared, so callers presumably
  //                initialize it to true before the first call.
  void IsConstantVisitor(const int64 op_handle,
                         absl::flat_hash_set<int64>* visited,
                         bool* is_constant) const;

  // Checks bounds for convolution parameters: validates `dimension_numbers`
  // against the shapes of the convolution's inputs and returns an error status
  // on failure.
  Status VerifyConvolution(
      const Shape& lhs_shape, const Shape& rhs_shape,
      const ConvolutionDimensionNumbers& dimension_numbers) const;

  // Helper function for creating a Window proto from user-supplied data.
  // Returns error if the user-supplied data was invalid.
  //
  // NOTE(review): each argument presumably has one entry per window dimension,
  // with `padding` holding (low, high) pairs — confirm in the implementation.
  StatusOr<Window> MakeWindow(absl::Span<const int64> window_dimensions,
                              absl::Span<const int64> window_strides,
                              absl::Span<const std::pair<int64, int64>> padding,
                              absl::Span<const int64> lhs_dilation,
                              absl::Span<const int64> rhs_dilation) const;
    653 
    654   int64 GetNextId() { return ++next_id_; }
    655 
  // Populates the module with the input/output alias information stored within
  // the input_output_aliases vector.
  //
  // Static, so it touches no XlaBuilder state: everything it needs is passed in.
  // NOTE(review): `program_shape` is presumably used to validate the aliased
  // parameter/output shapes — confirm in the implementation.
  static Status PopulateInputOutputAlias(
      HloModuleProto* module, const ProgramShape& program_shape,
      const std::vector<InputOutputAlias>& input_output_aliases);
    661 
  string name_;  // Name to use for the built computation.

  // The next sequential ID for every instruction/computation contained within
  // this computation. GetNextId() increments it before returning it, so the
  // first id handed out is 1.
  int64 next_id_ = 0;

  // The first error encountered while building the computation.
  // Holds an OK status until the first error is encountered.
  Status first_error_;

  // The saved stack trace from the point at which the first error occurred.
  tensorflow::SavedStackTrace first_error_backtrace_;

  // The instructions of this computation.
  std::vector<HloInstructionProto> instructions_;

  // Dynamic parameter configuration of this computation.
  DynamicParameterBinding dynamic_parameter_binding_;

  // Holds the input/output alias information populated by the SetUpAlias() API.
  std::vector<InputOutputAlias> input_output_aliases_;

  // A map from XlaOp::Handle to the index in the instructions_ vector where the
  // instruction is held.
  absl::flat_hash_map<int64, int64> handle_to_index_;

  // The embedded computations used by this computation. Each computation was
  // the entry computation of some XlaComputation; the key is the unique id of
  // that XlaComputation (std::map keeps the entries ordered by that id).
  std::map<int64, HloComputationProto> embedded_;

  // The unique parameter numbers.
  absl::flat_hash_set<int64> parameter_numbers_;

  // The metadata to attach to each op. This is structured as a "modal"-like
  // operation, in order to simplify client code (and not sprinkle this metadata
  // throughout the TensorFlow op kernel implementations).
  OpMetadata metadata_;

  // Sharding for this operator. This is structured as a "modal"-like operation,
  // in order to simplify client code, similar to metadata_. An empty optional
  // means no sharding is set.
  absl::optional<OpSharding> sharding_;

  // Mode bit that indicates whether to die when a first error is encountered.
  bool die_immediately_on_error_ = false;

  // Presumably the builder this builder was derived from (e.g. for a
  // sub-computation); nullptr when there is none. Raw pointer — presumably
  // non-owning; TODO confirm.
  XlaBuilder* parent_builder_{nullptr};
    709 
  // The namespace-scope op builders declared after this class are friends so
  // that they can reach XlaBuilder's private members. Each friend signature
  // must match the corresponding namespace-scope declaration exactly: a
  // mismatched parameter list declares a separate overload, and the public
  // function would then lack private access.
  //
  // Parameters, constants and shape-manipulation ops.
  friend XlaOp Parameter(XlaBuilder* builder, int64 parameter_number,
                         const Shape& shape, const string& name);
  friend XlaOp ConstantLiteral(XlaBuilder* builder,
                               const LiteralSlice& literal);

  friend XlaOp Broadcast(const XlaOp& operand,
                         absl::Span<const int64> broadcast_sizes);

  friend XlaOp BroadcastInDim(
      const XlaOp& operand, const absl::Span<const int64> out_dim_size,
      const absl::Span<const int64> broadcast_dimensions);

  friend XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
                   const PaddingConfig& padding_config);

  friend XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> dimensions,
                       absl::Span<const int64> new_sizes);

  friend XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> new_sizes);

  friend XlaOp Collapse(const XlaOp& operand,
                        absl::Span<const int64> dimensions);

  // Static and dynamic slicing, and concatenation.
  friend XlaOp Slice(const XlaOp& operand,
                     absl::Span<const int64> start_indices,
                     absl::Span<const int64> limit_indices,
                     absl::Span<const int64> strides);

  friend XlaOp SliceInDim(const XlaOp& operand, int64 start_index,
                          int64 limit_index, int64 stride, int64 dimno);

  friend XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
                            absl::Span<const int64> slice_sizes);
  friend XlaOp DynamicSlice(const XlaOp& operand,
                            absl::Span<const XlaOp> start_indices,
                            absl::Span<const int64> slice_sizes);

  friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
                                  const XlaOp& start_indices);
  friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
                                  absl::Span<const XlaOp> start_indices);

  friend XlaOp ConcatInDim(XlaBuilder* builder,
                           absl::Span<const XlaOp> operands, int64 dimension);
    754 
  // Friend op builders: tracing, selection, tuples, comparisons, dot products
  // and convolutions. Each signature must match its namespace-scope
  // declaration exactly, or the friend grant applies to a different overload.
  friend void Trace(const string& tag, const XlaOp& operand);

  friend XlaOp Select(const XlaOp& pred, const XlaOp& on_true,
                      const XlaOp& on_false);
  friend XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);
  friend XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
  // Comparisons.
  friend XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Compare(const XlaOp& lhs, const XlaOp& rhs,
                       absl::Span<const int64> broadcast_dimensions,
                       ComparisonDirection direction);
  // Dot products and convolutions.
  friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
                   const PrecisionConfig* precision_config);
  friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
                          const DotDimensionNumbers& dimension_number,
                          const PrecisionConfig* precision_config);
  friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
                    absl::Span<const int64> window_strides, Padding padding,
                    int64 feature_group_count, int64 batch_group_count,
                    const PrecisionConfig* precision_config);
  friend XlaOp ConvWithGeneralPadding(
      const XlaOp& lhs, const XlaOp& rhs,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config);
  friend XlaOp ConvWithGeneralDimensions(
      const XlaOp& lhs, const XlaOp& rhs,
      absl::Span<const int64> window_strides, Padding padding,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config);
  friend XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
                           absl::Span<const int64> window_strides,
                           absl::Span<const std::pair<int64, int64>> padding,
                           const ConvolutionDimensionNumbers& dimension_numbers,
                           int64 feature_group_count, int64 batch_group_count,
                           const PrecisionConfig* precision_config);
  friend XlaOp ConvGeneralDilated(
      const XlaOp& lhs, const XlaOp& rhs,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config);
  // Friend op builders: FFT, linear algebra, feeds, calls, and elementwise
  // binary arithmetic/logic ops. Each signature must match its namespace-scope
  // declaration exactly, or the friend grant applies to a different overload.
  friend XlaOp Fft(const XlaOp& operand, FftType fft_type,
                   absl::Span<const int64> fft_length);
  friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
                               bool unit_diagonal,
                               TriangularSolveOptions::Transpose transpose_a);
  friend XlaOp Cholesky(XlaOp a, bool lower);
  // Infeed/outfeed and calls into other computations.
  friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
                      const string& config);
  friend void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
                      const string& outfeed_config);
  friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
                    absl::Span<const XlaOp> operands);
  friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
                          absl::Span<const XlaOp> operands, const Shape& shape,
                          const string& opaque);
  friend XlaOp CustomCallWithLayout(
      XlaBuilder* builder, const string& call_target_name,
      absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
      absl::Span<const Shape> operand_shapes_with_layout, const string& opaque);
  // Complex-number and elementwise binary ops.
  friend XlaOp Complex(const XlaOp& real, const XlaOp& imag,
                       absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Conj(const XlaOp& operand);
  friend XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Not(const XlaOp& operand);
  friend XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
                         absl::Span<const int64> broadcast_dimensions);
  friend XlaOp ShiftRightArithmetic(
      const XlaOp& lhs, const XlaOp& rhs,
      absl::Span<const int64> broadcast_dimensions);
  friend XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs,
                                 absl::Span<const int64> broadcast_dimensions);
  // Friend op builders: reductions, cross-replica collectives and
  // select-and-scatter. Each signature must match its namespace-scope
  // declaration exactly, or the friend grant applies to a different overload.
  friend XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
                      const XlaComputation& computation,
                      absl::Span<const int64> dimensions_to_reduce);
  friend XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
                      absl::Span<const XlaOp> init_values,
                      const XlaComputation& computation,
                      absl::Span<const int64> dimensions_to_reduce);
  friend XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
                         const XlaComputation& computation);
  friend XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
                            const XlaComputation& computation,
                            absl::Span<const int64> window_dimensions,
                            absl::Span<const int64> window_strides,
                            Padding padding);
  friend XlaOp ReduceWindowWithGeneralPadding(
      const XlaOp& operand, const XlaOp& init_value,
      const XlaComputation& computation,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const int64> base_dilations,
      absl::Span<const int64> window_dilations,
      absl::Span<const std::pair<int64, int64>> padding);
  // Cross-replica collectives.
  friend XlaOp CrossReplicaSum(const XlaOp& operand,
                               absl::Span<const ReplicaGroup> replica_groups);
  friend XlaOp CrossReplicaSum(const XlaOp& operand,
                               const XlaComputation& computation,
                               absl::Span<const ReplicaGroup> replica_groups,
                               const absl::optional<ChannelHandle>& channel_id);
  friend XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
                        int64 concat_dimension, int64 split_count,
                        const std::vector<ReplicaGroup>& replica_groups);
  friend XlaOp CollectivePermute(
      const XlaOp& operand,
      const std::vector<std::pair<int64, int64>>& source_target_pairs);
  friend XlaOp ReplicaId(XlaBuilder* builder);
  // Select-and-scatter.
  friend XlaOp SelectAndScatter(const XlaOp& operand,
                                const XlaComputation& select,
                                absl::Span<const int64> window_dimensions,
                                absl::Span<const int64> window_strides,
                                Padding padding, const XlaOp& source,
                                const XlaOp& init_value,
                                const XlaComputation& scatter);
  friend XlaOp SelectAndScatterWithGeneralPadding(
      const XlaOp& operand, const XlaComputation& select,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
      const XlaOp& init_value, const XlaComputation& scatter);
  // Friend op builders: elementwise unary math, iota and type conversion,
  // transposition/sorting, mapping, RNG, control flow, and gather/scatter.
  // Each signature must match its namespace-scope declaration exactly, or the
  // friend grant applies to a different overload.
  friend XlaOp Abs(const XlaOp& operand);
  friend XlaOp Atan2(const XlaOp& y, const XlaOp& x,
                     absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Exp(const XlaOp& operand);
  friend XlaOp Expm1(const XlaOp& operand);
  friend XlaOp Floor(const XlaOp& operand);
  friend XlaOp Ceil(const XlaOp& operand);
  friend XlaOp Round(const XlaOp& operand);
  friend XlaOp Log(const XlaOp& operand);
  friend XlaOp Log1p(const XlaOp& operand);
  friend XlaOp Sign(const XlaOp& operand);
  friend XlaOp Clz(const XlaOp& operand);
  friend XlaOp Cos(const XlaOp& operand);
  friend XlaOp Sin(const XlaOp& operand);
  friend XlaOp Tanh(const XlaOp& operand);
  friend XlaOp Real(const XlaOp& operand);
  friend XlaOp Imag(const XlaOp& operand);
  friend XlaOp Sqrt(const XlaOp& operand);
  friend XlaOp Rsqrt(const XlaOp& operand);
  friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp IsFinite(const XlaOp& operand);
  // Iota and element-type conversion.
  friend XlaOp Iota(XlaBuilder* builder, const Shape& shape,
                    int64 iota_dimension);
  friend XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);
  friend XlaOp ConvertElementType(const XlaOp& operand,
                                  PrimitiveType new_element_type);
  friend XlaOp BitcastConvertType(const XlaOp& operand,
                                  PrimitiveType new_element_type);
  friend XlaOp Neg(const XlaOp& operand);
  // Layout-permuting and sorting ops.
  friend XlaOp Transpose(const XlaOp& operand,
                         absl::Span<const int64> permutation);
  friend XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions);
  friend XlaOp Sort(const XlaOp& keys, absl::Span<const XlaOp> values,
                    int64 dimension);
  friend XlaOp Sort(absl::Span<const XlaOp> operands,
                    const XlaComputation& comparator, int64 dimension,
                    bool is_stable);
  friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
  friend XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
                   const XlaComputation& computation,
                   absl::Span<const int64> dimensions,
                   absl::Span<const XlaOp> static_operands);
  // Random number generation.
  friend XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma,
                         const Shape& shape);
  friend XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);
  // Control flow.
  friend XlaOp While(const XlaComputation& condition,
                     const XlaComputation& body, const XlaOp& init);
  friend XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
                           const XlaComputation& true_computation,
                           const XlaOp& false_operand,
                           const XlaComputation& false_computation);
  friend XlaOp Conditional(
      const XlaOp& branch_index,
      absl::Span<const XlaComputation* const> branch_computations,
      absl::Span<const XlaOp> branch_operands);
  friend XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
                               const int mantissa_bits);
  // Gather and scatter.
  friend XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
                      const GatherDimensionNumbers& dimension_numbers,
                      absl::Span<const int64> slice_sizes);
  friend XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
                       const XlaOp& updates,
                       const XlaComputation& update_computation,
                       const ScatterDimensionNumbers& dimension_numbers);
  // Friend op builders: send/receive, batch normalization, token-threaded
  // host I/O, token creation, and dynamic dimension queries. Each signature
  // must match its namespace-scope declaration exactly, or the friend grant
  // applies to a different overload.
  friend void Send(const XlaOp& operand, const ChannelHandle& handle);
  friend XlaOp Recv(XlaBuilder* builder, const Shape& shape,
                    const ChannelHandle& handle);
  // Batch normalization.
  friend XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
                                 const XlaOp& offset, float epsilon,
                                 int64 feature_index);
  friend XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
                                  const XlaOp& offset, const XlaOp& mean,
                                  const XlaOp& variance, float epsilon,
                                  int64 feature_index);
  friend XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
                             const XlaOp& batch_mean, const XlaOp& batch_var,
                             const XlaOp& grad_output, float epsilon,
                             int64 feature_index);
  // Token-threaded send/receive and host I/O.
  friend XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
                             const ChannelHandle& handle);
  friend XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
                             const ChannelHandle& handle);
  friend XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
                          const Shape& shape_with_layout,
                          const ChannelHandle& handle);
  friend XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
                            const ChannelHandle& handle);
  friend XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
                               const string& config);
  friend XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
                                const Shape& shape_with_layout,
                                const string& outfeed_config);
  friend XlaOp CreateToken(XlaBuilder* builder);
  friend XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);

  friend XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension);
   1007 };
   1008 
   1009 // RAII-style object: sets the current sharding assignment in builder on
   1010 // construction, and sets back to the previous assignment on destruction.
   1011 class XlaScopedShardingAssignment {
   1012  public:
   1013   XlaScopedShardingAssignment(xla::XlaBuilder* builder,
   1014                               absl::optional<OpSharding> sharding)
   1015       : builder_(builder), prev_sharding_(builder->sharding()) {
   1016     SetSharding(sharding);
   1017   }
   1018 
   1019   XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete;
   1020   XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) =
   1021       delete;
   1022 
   1023   ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); }
   1024 
   1025  private:
   1026   void SetSharding(const absl::optional<OpSharding>& sharding) {
   1027     if (sharding.has_value()) {
   1028       builder_->SetSharding(sharding.value());
   1029     } else {
   1030       builder_->ClearSharding();
   1031     }
   1032   }
   1033 
   1034   xla::XlaBuilder* const builder_;
   1035   absl::optional<OpSharding> prev_sharding_;
   1036 };
   1037 
// Free functions for building XlaOps. The intention is that these will
// become the public API for building XlaOps, rather than calling methods on
// XlaBuilder directly.
   1042 
// Enqueues a "retrieve parameter value" instruction for a parameter that was
// passed to the computation.
//
// `parameter_number` identifies which of the computation's parameters is read,
// `shape` is that parameter's shape, and `name` is presumably used as the
// instruction's name.
// NOTE(review): parameter numbers are presumably required to be unique per
// builder (XlaBuilder tracks them in a set) — confirm in the implementation.
XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape,
                const string& name);

// Enqueues a constant with the value of the given literal onto the
// computation.
XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal);
   1051 
// Enqueues a constant onto the computation. Functions are templated on the
// native host type (NativeT) which corresponds to a specific XLA
// PrimitiveType as given in the following table:
//
//  Native Type   PrimitiveType
// -----------------------------
//   bool           PRED
//   int32          S32
//   int64          S64
//   uint32         U32
//   uint64         U64
//   float          F32
//   double         F64
//
// Note: not all primitive types defined in xla_data.proto have a
// corresponding native type yet.
//
// The *WithLayout variants take an explicit `layout` for the constant; the
// others presumably use a default layout — confirm in the implementation.
template <typename NativeT>
XlaOp ConstantR0(XlaBuilder* builder, NativeT value);
template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values);
// Non-template overload taking a tensorflow::core::Bitmap.
// NOTE(review): presumably builds a PRED vector with one element per bit —
// confirm in the implementation.
XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values);
template <typename NativeT>
XlaOp ConstantR2(XlaBuilder* builder,
                 std::initializer_list<std::initializer_list<NativeT>> values);
template <typename NativeT>
XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
                                  const Array<NativeT>& values,
                                  const Layout& layout);
template <typename NativeT>
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values);
template <typename NativeT>
XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
                                      const Array2D<NativeT>& values,
                                      const Layout& layout);
template <typename NativeT>
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
                            const Array2D<NativeT>& values);
template <typename NativeT>
XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
                                      const Array3D<NativeT>& values,
                                      const Layout& layout);
template <typename NativeT>
XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
                            const Array3D<NativeT>& values);
template <typename NativeT>
XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
                                      const Array4D<NativeT>& values,
                                      const Layout& layout);
template <typename NativeT>
XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
                            const Array4D<NativeT>& values);

// Enqueues a rank one constant (vector) onto the computation. The vector has
// size 'length' and every element has the value 'value'.
template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value);
   1109 
// Adds dimensions to an array by duplicating the data in the array.
//
// The new dimensions are inserted on the left, i.e. if
// broadcast_sizes has values {a0, ..., aN} and the operand shape
// has dimensions {b0, ..., bM} then the shape of the output has
// dimensions {a0, ..., aN, b0, ..., bM}.
//
// The new dimensions index into copies of the operand, i.e.
//
//   output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
XlaOp Broadcast(const XlaOp& operand, absl::Span<const int64> broadcast_sizes);

// This op broadcasts the `operand` to an output whose dimension sizes are given
// by `out_dim_size`. `broadcast_dimensions` lists the output dimensions that
// the operand's dimensions map onto, i.e., the i'th dimension of the operand is
// mapped to the broadcast_dimensions[i]'th dimension of the output. This also
// requires that the i'th input dimension is either 1 or is the same as the
// output dimension it's broadcasting into.
//
// For example, say operand = {1, 2}, i.e., a 1D tensor in shape s32[2]; the
// output shape is s32[2,2]:
// - Specifying {1} as broadcast_dimensions will generate output
//   {{1, 2},
//    {1, 2}}
// - On the other hand, specifying {0} as broadcast_dimensions
//   will generate output
//   {{1, 1},
//    {2, 2}}
XlaOp BroadcastInDim(const XlaOp& operand,
                     const absl::Span<const int64> out_dim_size,
                     const absl::Span<const int64> broadcast_dimensions);
   1140 
// Enqueues a pad operation onto the computation that pads the input with
// `padding_value` on the edges as well as between the elements of the input.
// padding_config specifies the padding amount for each dimension.
XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
          const PaddingConfig& padding_config);

// Enqueues an operation onto the computation that flattens the operand based
// on the dimension order (major/slowest-varying to minor/fastest-varying)
// given, followed by reshaping it into the shape with the given dimension
// sizes (also major to minor). Conceptually, this is a limited form of
// "shape casting".
XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> dimensions,
              absl::Span<const int64> new_sizes);

// Enqueues an operation onto the computation that collapses the operand, from
// first to last dimension (C order), then reshapes it to the given dimension
// sizes. Conceptually, this is a limited form of "shape casting".
XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> new_sizes);

// Wrapper for Reshape.
// Enqueues an operation to collapse the provided dimensions; e.g. an
// operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to
// {1024, 32} by collapsing dims {0, 1, 2} (256 * 2 * 2 = 1024). Collapsing
// dimensions must be a consecutive, in-order subsequence of the operand
// dimensions.
//
// Note that collapsing a single dimension does nothing:
//
//    {256} collapsing {0} => {256}
//    {1} collapsing {0} => {1}
//
// Collapsing multiple dimensions produces a single result dimension:
//
//    {256, 2} collapsing {0,1} => {512}
//    {256, 2, 3} collapsing {0,1} => {512, 3}
//
// This could potentially cause data to be moved -- it provides a more
// structured form of reshaping than an arbitrary Reshape operation.
XlaOp Collapse(const XlaOp& operand, absl::Span<const int64> dimensions);
   1179 
   1180 // Enqueues a slice operation onto the computation that slices the operand
   1181 // from the start indices to the limit indices; e.g.
   1182 //
   1183 //        x
   1184 //   [ 0 1 2 3 ]
   1185 // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ]
   1186 //   [ 8 9 a b ]
   1187 //
   1188 // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
   1189 // range notation.
   1190 // The 'strides' parameter determines the stride over the slice.
   1191 XlaOp Slice(const XlaOp& operand, absl::Span<const int64> start_indices,
   1192             absl::Span<const int64> limit_indices,
   1193             absl::Span<const int64> strides);
   1194
   1195 // Enqueues a slice operation in a given dimension, taking all other
   1196 // dimensions as they are; e.g. if dimno is 1 from start_index 2 to
   1197 // limit_index 4 by stride 1, and the shape is f32[7,8,9], this call is
   1198 // short-hand for:
   1199 //
   1200 //  array[:, 2:4:1, :]
   1201 XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index,
   1202                  int64 stride, int64 dimno);
   1203 
   1204 // Enqueues a slice operation onto the computation that slices the 'operand'
   1205 // from dynamic start indices which are passed in 'start_indices'.
   1206 // The size of the slice in each dimension is passed in 'slice_sizes',
   1207 // which specify the end point of exclusive slice intervals in each
   1208 // dimension [start, start + size).
   1209 // The shape of each element of 'start_indices' must be scalar, with the span
   1210 // size equal to the rank of the 'operand'. All elements of 'start_indices' must
   1211 // have the same shape.
   1212 // Slice index calculations are computed modulo input dimension sizes to
   1213 // prevent dynamic start indices from generating out-of-bound array accesses.
   1214 XlaOp DynamicSlice(const XlaOp& operand, absl::Span<const XlaOp> start_indices,
   1215                    absl::Span<const int64> slice_sizes);
   1216
   1217 ABSL_DEPRECATED("Use span-of-indices form instead")
   1218 XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
   1219                    absl::Span<const int64> slice_sizes);
   1220
   1221 // Enqueues a dynamic update slice operation onto the computation, which
   1222 // updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
   1223 // The shape of 'update' determines the shape of the slice of 'operand'
   1224 // which is updated.
   1225 // The indices specified in 'start_indices' specify the offset of the slice
   1226 // of 'operand' which is updated.
   1227 //
   1228 //               update = {10, 11} // calculated at runtime.
   1229 //   [1 2 3]     start  = {1, 1}   // calculated at runtime.  [1 2  3 ]
   1230 //   [4 5 6]  => DynamicUpdateSlice(data, update, start)   => [4 10 11]
   1231 //   [7 8 9]                                                  [7 8  9 ]
   1232 //
   1233 // The shape of each element of 'start_indices' must be scalar, with the span
   1234 // size equal to the rank of the 'operand'. All elements of 'start_indices' must
   1235 // have the same shape.
   1236 // Slice index calculations are computed modulo update dimension sizes to
   1237 // prevent dynamic start indices from generating out-of-bound array accesses.
   1238 XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
   1239                          absl::Span<const XlaOp> start_indices);
   1240
   1241 ABSL_DEPRECATED("Use span-of-indices form instead")
   1242 XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
   1243                          const XlaOp& start_indices);
   1244 
   1245 // Enqueues a concatenation of 'operands' along 'dimension' onto the
   1246 // computation. 'operands' must have >= 1 entry.
   1247 XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,
   1248                   int64 dimension);
   1249
   1250 // Enqueues a tracing operation onto the computation; the computation will emit
   1251 // a logging message with the operand.
   1252 void Trace(const string& tag, const XlaOp& operand);
   1253
   1254 // Enqueues a conditional-move-like select operation onto the computation;
   1255 // predicated on pred, selects between on_true and on_false.
   1256 XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false);
   1257
   1258 // Enqueues a tuple-creation instruction onto the computation.
   1259 XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);
   1260
   1261 // Enqueues a tuple-element-get instruction onto the computation.
   1262 XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
   1263 
   1264 // Enqueues an equal-to comparison instruction onto the computation.
   1265 XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
   1266          absl::Span<const int64> broadcast_dimensions = {});
   1267
   1268 // Enqueues a not-equal comparison instruction onto the computation.
   1269 XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
   1270          absl::Span<const int64> broadcast_dimensions = {});
   1271
   1272 // Enqueues a greater-or-equal comparison instruction onto the computation.
   1273 XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
   1274          absl::Span<const int64> broadcast_dimensions = {});
   1275
   1276 // Enqueues a greater-than comparison instruction onto the computation.
   1277 XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
   1278          absl::Span<const int64> broadcast_dimensions = {});
   1279
   1280 // Enqueues a less-than comparison instruction onto the computation.
   1281 XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
   1282          absl::Span<const int64> broadcast_dimensions = {});
   1283
   1284 // Enqueues a less-or-equal comparison instruction onto the computation.
   1285 XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
   1286          absl::Span<const int64> broadcast_dimensions = {});
   1287
   1288 // Enqueues a comparison of lhs and rhs in 'direction' onto the computation.
   1289 XlaOp Compare(const XlaOp& lhs, const XlaOp& rhs,
   1290               absl::Span<const int64> broadcast_dimensions,
   1291               ComparisonDirection direction);
   1292
   1293 // Enqueues a dot instruction onto the computation.
   1294 XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
   1295           const PrecisionConfig* precision_config = nullptr);
   1296
   1297 // Enqueues a general dot instruction onto the computation.
   1298 XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
   1299                  const DotDimensionNumbers& dimension_numbers,
   1300                  const PrecisionConfig* precision_config = nullptr);
   1301
   1302 // Enqueues a convolution instruction onto the computation, which uses the
   1303 // default convolution dimension numbers.
   1304 XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
   1305            absl::Span<const int64> window_strides, Padding padding,
   1306            int64 feature_group_count = 1, int64 batch_group_count = 1,
   1307            const PrecisionConfig* precision_config = nullptr);
   1308
   1309 // Enqueues a convolution instruction onto the computation, with the caller
   1310 // provided padding configuration in the format returned by MakePadding().
   1311 XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs,
   1312                              absl::Span<const int64> window_strides,
   1313                              absl::Span<const std::pair<int64, int64>> padding,
   1314                              int64 feature_group_count = 1,
   1315                              int64 batch_group_count = 1,
   1316                              const PrecisionConfig* precision_config = nullptr);
   1317
   1318 // Enqueues a convolution instruction onto the computation, with the caller
   1319 // provided dimension numbers configuration.
   1320 XlaOp ConvWithGeneralDimensions(
   1321     const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
   1322     Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
   1323     int64 feature_group_count = 1, int64 batch_group_count = 1,
   1324     const PrecisionConfig* precision_config = nullptr);
   1325
   1326 // Enqueues a convolution instruction onto the computation, with the caller
   1327 // provided padding configuration as well as the dimension numbers.
   1328 XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
   1329                   absl::Span<const int64> window_strides,
   1330                   absl::Span<const std::pair<int64, int64>> padding,
   1331                   const ConvolutionDimensionNumbers& dimension_numbers,
   1332                   int64 feature_group_count = 1, int64 batch_group_count = 1,
   1333                   const PrecisionConfig* precision_config = nullptr);
   1334
   1335 // Enqueues a convolution instruction onto the computation, with the caller
   1336 // provided padding configuration, dilation factors and dimension numbers.
   1337 XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
   1338                          absl::Span<const int64> window_strides,
   1339                          absl::Span<const std::pair<int64, int64>> padding,
   1340                          absl::Span<const int64> lhs_dilation,
   1341                          absl::Span<const int64> rhs_dilation,
   1342                          const ConvolutionDimensionNumbers& dimension_numbers,
   1343                          int64 feature_group_count = 1,
   1344                          int64 batch_group_count = 1,
   1345                          const PrecisionConfig* precision_config = nullptr);
   1346
   1347 // Enqueues an FFT instruction onto the computation, of the given type and
   1348 // with the given FFT length.
   1349 XlaOp Fft(const XlaOp& operand, FftType fft_type,
   1350           absl::Span<const int64> fft_length);
   1351 
   1352 // Solves systems of linear equations with lower or upper triangular coefficient
   1353 // matrices by forward- or back-substitution. Broadcasting along leading
   1354 // dimensions, this routine solves for x in one of the matrix systems
   1355 //   `op(a) * x = b`,  or `x * op(a) = b`,
   1356 // for the variable `x` given `a` and `b`, where `op(a)` is either
   1357 //   `op(a) = a`,  or `op(a) = transpose(a)`,  or `op(a) = conj(transpose(a))`.
   1358 //
   1359 // * `a` is a tensor of shape `[..., M, M]` whose innermost 2 dimensions form
   1360 //   square matrices. If `lower` is true (false), then the strictly upper
   1361 //   (lower) triangular part of each innermost matrix in `a` is assumed to be
   1362 //   zero and is not accessed.
   1363 // * `b` is a tensor of shape `[..., M, K]` if `left_side` is true, otherwise a
   1364 //   tensor of shape `[..., K, M]`.
   1365 // * `left_side` is a boolean, indicating whether to solve a system of the form
   1366 //   op(a) * x = b (true) or x * op(a) = b (false).
   1367 // * `lower` is a boolean, indicating whether the argument `a` is
   1368 //   lower-triangular (true) or upper-triangular (false).
   1369 // * If `unit_diagonal` is true, the diagonal elements of `a` are assumed to be
   1370 //   1 and not accessed.
   1371 // * `transpose_a` indicates which function `op` we use to transform the tensor
   1372 //   `a`: the identity function, transpose(a), or conjugate(transpose(a)).
   1373 XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
   1374                       bool unit_diagonal,
   1375                       TriangularSolveOptions::Transpose transpose_a);
   1376
   1377 // Computes the Cholesky decompositions of a batch of symmetric (Hermitian)
   1378 // positive definite matrices.
   1379 // `a` must be a (batched) square matrix; i.e., it must have rank >= 2 with the
   1380 // two minor dimensions equal.
   1381 // If `lower` is true, the data from the lower triangle is used; if false, the
   1382 // upper triangle is used. The input data in the other triangle of the input
   1383 // does not affect the output. Returns the output in the same lower/upper
   1384 // triangle. The data returned in the other output triangle is arbitrary and
   1385 // implementation-defined.
   1386 //
   1387 // The value returned if `a` is not Hermitian positive definite is
   1388 // implementation-defined.
   1389 XlaOp Cholesky(XlaOp a, bool lower);
   1390 
   1391 // Enqueues an infeed instruction onto the computation, which writes data of
   1392 // the given shape to the infeed buffer of the device.
   1393 XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
   1394              const string& config = "");
   1395
   1396 // Variant of Infeed which takes a token-shaped operand and produces a
   1397 // two-element tuple containing the data value and a token-shaped value.
   1398 // Tokens are used for ordering side-effecting operations.
   1399 // TODO(b/110532604): Replace all uses of the non-token form with this variant.
   1400 XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
   1401                       const string& config = "");
   1402
   1403 // Enqueues an outfeed instruction onto the computation. This instruction
   1404 // generates outgoing data transfers for the given data.
   1405 //
   1406 // shape_with_layout communicates the laid out shape that we want to outfeed
   1407 // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error
   1408 // will occur.
   1409 void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
   1410              const string& outfeed_config);
   1411
   1412 // Variant of Outfeed which takes a token-shaped operand and produces a
   1413 // token-shaped value. Tokens are used for ordering side-effecting operations.
   1414 // TODO(b/110532604): Replace all uses of the non-token form with this variant.
   1415 XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
   1416                        const Shape& shape_with_layout,
   1417                        const string& outfeed_config);
   1418
   1419 // Enqueues a call instruction invoking 'computation' on 'operands'.
   1420 XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
   1421            absl::Span<const XlaOp> operands);
   1422
   1423 // Enqueues a custom call instruction onto the computation. A custom call
   1424 // invokes code external to XLA. The |operands| are passed to the external code,
   1425 // and the external code is expected to produce a result of the given
   1426 // |shape|. The exact mechanism is backend-specific. For example, in the CPU
   1427 // backend, a call instruction is emitted which targets a symbol with the name
   1428 // |call_target_name|. |call_target_name| and |opaque| can be arbitrary
   1429 // strings, but |call_target_name| should be short as it may be used in labels.
   1430 // |opaque| can encode arbitrarily large amounts of information.
   1431 XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
   1432                  absl::Span<const XlaOp> operands, const Shape& shape,
   1433                  const string& opaque = "");
   1434
   1435 // Overload which constructs a custom call with fixed layouts. The operands will
   1436 // have the layouts specified by |operand_shapes_with_layout| when provided to
   1437 // external code, and the external code is expected to produce a result with the
   1438 // layout specified by |shape_with_layout|. All shapes in |shape_with_layout|
   1439 // and |operand_shapes_with_layout| must have layouts.
   1440 XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name,
   1441                            absl::Span<const XlaOp> operands,
   1442                            const Shape& shape_with_layout,
   1443                            absl::Span<const Shape> operand_shapes_with_layout,
   1444                            const string& opaque = "");
   1445 
   1446 // The following methods enqueue element-wise binary arithmetic operations
   1447 // onto the computation. The shapes of the operands have to match unless one
   1448 // of the operands is a scalar, or explicit broadcast_dimensions are given
   1449 // (see g3doc for more details).
   1450
   1451 // Enqueues a complex compose instruction onto the computation.
   1452 XlaOp Complex(const XlaOp& real, const XlaOp& imag,
   1453               absl::Span<const int64> broadcast_dimensions = {});
   1454
   1455 // Enqueues a complex conjugate instruction onto the computation.
   1456 XlaOp Conj(const XlaOp& operand);
   1457
   1458 // Enqueues an add instruction onto the computation.
   1459 XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
   1460           absl::Span<const int64> broadcast_dimensions = {});
   1461
   1462 // Enqueues a subtract instruction onto the computation.
   1463 XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
   1464           absl::Span<const int64> broadcast_dimensions = {});
   1465
   1466 // Enqueues a multiply instruction onto the computation.
   1467 XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
   1468           absl::Span<const int64> broadcast_dimensions = {});
   1469
   1470 // Enqueues a divide instruction onto the computation.
   1471 XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
   1472           absl::Span<const int64> broadcast_dimensions = {});
   1473
   1474 // Enqueues a remainder instruction onto the computation.
   1475 XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
   1476           absl::Span<const int64> broadcast_dimensions = {});
   1477
   1478 // Enqueues a max instruction onto the computation.
   1479 XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
   1480           absl::Span<const int64> broadcast_dimensions = {});
   1481
   1482 // Enqueues a min instruction onto the computation.
   1483 XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
   1484           absl::Span<const int64> broadcast_dimensions = {});
   1485
   1486 // Element-wise logical operators (And, Or, Xor, Not).
   1487 XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
   1488           absl::Span<const int64> broadcast_dimensions = {});
   1489 
   1490 // Overloads to call And with 3 or more operands; operands are nested
   1491 // right-associatively, e.g. And(a, b, c) == And(a, And(b, c)). This somewhat
   1492 // convoluted set disambiguates from the `broadcast_dimensions` overload above.
   1493 inline XlaOp And(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3) {
   1494   return And(op1, And(op2, op3));
   1495 }
   1496 template <typename... XlaOpTs>
   1497 XlaOp And(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3,
   1498           const XlaOpTs&... operands) {
   1499   return And(op1, And(op2, And(op3, operands...)));
   1500 }
   1501 
   1502 XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
   1503          absl::Span<const int64> broadcast_dimensions = {});
   1504
   1505 // Overloads to call Or with 3 or more operands, nested right-associatively as
   1506 // with `And`. This complicated overload set is needed to handle the default
   1507 // arg in the `Or` overload above.
   1508 inline XlaOp Or(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3) {
   1509   return Or(op1, Or(op2, op3));
   1510 }
   1511 template <typename... XlaOpTs>
   1512 XlaOp Or(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3,
   1513          const XlaOpTs&... operands) {
   1514   return Or(op1, Or(op2, Or(op3, operands...)));
   1515 }
   1516 
   1517 XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
   1518           absl::Span<const int64> broadcast_dimensions = {});
   1519
   1520 XlaOp Not(const XlaOp& operand);
   1521
   1522 XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
   1523                 absl::Span<const int64> broadcast_dimensions = {});
   1524 XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs,
   1525                            absl::Span<const int64> broadcast_dimensions = {});
   1526 XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs,
   1527                         absl::Span<const int64> broadcast_dimensions = {});
   1528
   1529 // Reduces an array along the provided dimensions, given "computation" as a
   1530 // reduction operator.
   1531 XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
   1532              const XlaComputation& computation,
   1533              absl::Span<const int64> dimensions_to_reduce);
   1534
   1535 // Reduces several arrays simultaneously along the provided dimensions, given
   1536 // "computation" as a reduction operator.
   1537 XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
   1538              absl::Span<const XlaOp> init_values,
   1539              const XlaComputation& computation,
   1540              absl::Span<const int64> dimensions_to_reduce);
   1541
   1542 // Convenience wrapper around the above that reduces all the dimensions in the
   1543 // operand shape.
   1544 XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
   1545                 const XlaComputation& computation);
   1546
   1547 // Enqueues a windowed reduce instruction onto the computation.
   1548 XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
   1549                    const XlaComputation& computation,
   1550                    absl::Span<const int64> window_dimensions,
   1551                    absl::Span<const int64> window_strides, Padding padding);
   1552
   1553 // As ReduceWindow(), but the padding is given in the format
   1554 // returned by MakePadding().
   1555 XlaOp ReduceWindowWithGeneralPadding(
   1556     const XlaOp& operand, const XlaOp& init_value,
   1557     const XlaComputation& computation,
   1558     absl::Span<const int64> window_dimensions,
   1559     absl::Span<const int64> window_strides,
   1560     absl::Span<const int64> base_dilations,
   1561     absl::Span<const int64> window_dilations,
   1562     absl::Span<const std::pair<int64, int64>> padding);
   1563
   1564 // Returns the sum of the operand value within each subgroup of replicas. All
   1565 // replicas supply one input to the sum and all replicas receive the resulting
   1566 // sum for each subgroup.
   1567 XlaOp CrossReplicaSum(const XlaOp& operand,
   1568                       absl::Span<const ReplicaGroup> replica_groups = {});
   1569
   1570 // Enqueues an operation that performs an AllReduce of the operand across cores.
   1571 // Here AllReduce means doing a reduction on the input operand across cores and
   1572 // then broadcasting the reduction result to those cores. The reduction function
   1573 // is defined by `computation`, which should be a commutative computation on
   1574 // scalars, e.g., add, min, or max. The way that AllReduce is applied is
   1575 // configured by:
   1576 //
   1577 // - `replica_groups`: each ReplicaGroup contains a list of replica ids. If
   1578 // empty, all replicas belong to one group. Allreduce will be applied within
   1579 // subgroups. For example, with 4 replicas, replica_groups={{0,2},{1,3}}
   1580 // means replicas 0 and 2 form subgroup 0 and replicas 1 and 3 form subgroup 1.
   1581 //
   1582 // - `channel_id`: for Allreduce nodes from different modules, if they have the
   1583 // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be
   1584 // applied across modules.
   1585 //
   1586 // TODO(b/117564385): Rename this to AllReduce when it's ready to use.
   1587 XlaOp CrossReplicaSum(
   1588     const XlaOp& operand, const XlaComputation& computation,
   1589     absl::Span<const ReplicaGroup> replica_groups = {},
   1590     const absl::optional<ChannelHandle>& channel_id = absl::nullopt);
   1591
   1592 // Enqueues an operation that performs an AllToAll of the operand across cores.
   1593 XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
   1594                int64 concat_dimension, int64 split_count,
   1595                const std::vector<ReplicaGroup>& replica_groups = {});
   1596
   1597 // Enqueues a collective operation that sends and receives data across replicas.
   1598 //
   1599 // - `source_target_pairs`: a list of (source_replica_id, target_replica_id)
   1600 // pairs. For each pair, the operand is sent from source replica to target
   1601 // replica. Note that, 1) any two pairs should not have the same target replica
   1602 // id, and they should not have the same source replica id; 2) if a replica id
   1603 // is not a target in any pair, then the output on that replica is a tensor
   1604 // consisting of 0(s) with the same shape as the input.
   1605 XlaOp CollectivePermute(
   1606     const XlaOp& operand,
   1607     const std::vector<std::pair<int64, int64>>& source_target_pairs);
   1608
   1609 // Enqueues an operation that returns the replica ID.
   1610 XlaOp ReplicaId(XlaBuilder* builder);
   1611 
   1612 // Enqueues an operation that scatters the `source` array to the selected
   1613 // indices of each window.
   1614 XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
   1615                        absl::Span<const int64> window_dimensions,
   1616                        absl::Span<const int64> window_strides, Padding padding,
   1617                        const XlaOp& source, const XlaOp& init_value,
   1618                        const XlaComputation& scatter);
   1619
   1620 // As SelectAndScatter(), but the padding is given in the format
   1621 // returned by MakePadding().
   1622 XlaOp SelectAndScatterWithGeneralPadding(
   1623     const XlaOp& operand, const XlaComputation& select,
   1624     absl::Span<const int64> window_dimensions,
   1625     absl::Span<const int64> window_strides,
   1626     absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
   1627     const XlaOp& init_value, const XlaComputation& scatter);
   1628
   1629 // Enqueues an abs instruction onto the computation.
   1630 XlaOp Abs(const XlaOp& operand);
   1631
   1632 // Enqueues an atan2(y, x) instruction onto the computation.
   1633 XlaOp Atan2(const XlaOp& y, const XlaOp& x,
   1634             absl::Span<const int64> broadcast_dimensions = {});
   1635
   1636 // Enqueues an exp instruction onto the computation.
   1637 XlaOp Exp(const XlaOp& operand);
   1638
   1639 // Enqueues an expm1 instruction onto the computation.
   1640 XlaOp Expm1(const XlaOp& operand);
   1641
   1642 // Enqueues a floor instruction onto the computation.
   1643 XlaOp Floor(const XlaOp& operand);
   1644
   1645 // Enqueues a ceil instruction onto the computation.
   1646 XlaOp Ceil(const XlaOp& operand);
   1647
   1648 // Enqueues a round instruction onto the computation, rounding to the nearest
   1649 // integer, with half-way cases rounding away from zero.
   1650 XlaOp Round(const XlaOp& operand);
   1651
   1652 // Enqueues a log instruction (natural logarithm) onto the computation.
   1653 XlaOp Log(const XlaOp& operand);
   1654
   1655 // Enqueues a log1p instruction (log(x+1)) onto the computation.
   1656 XlaOp Log1p(const XlaOp& operand);
   1657
   1658 // Enqueues a sign instruction onto the computation.
   1659 XlaOp Sign(const XlaOp& operand);
   1660
   1661 // Enqueues a count leading zeros instruction onto the computation.
   1662 XlaOp Clz(const XlaOp& operand);
   1663
   1664 // Enqueues a cosine instruction onto the computation.
   1665 XlaOp Cos(const XlaOp& operand);
   1666
   1667 // Enqueues a sine instruction onto the computation.
   1668 XlaOp Sin(const XlaOp& operand);
   1669
   1670 // Enqueues a tanh instruction onto the computation.
   1671 XlaOp Tanh(const XlaOp& operand);
   1672
   1673 // Enqueues a real-part instruction onto the computation.
   1674 XlaOp Real(const XlaOp& operand);
   1675
   1676 // Enqueues an imaginary-part instruction onto the computation.
   1677 XlaOp Imag(const XlaOp& operand);
   1678
   1679 // Enqueues a sqrt instruction onto the computation.
   1680 XlaOp Sqrt(const XlaOp& operand);
   1681
   1682 // Enqueues an rsqrt instruction onto the computation.
   1683 XlaOp Rsqrt(const XlaOp& operand);
   1684
   1685 // Enqueues an lhs^rhs (power) instruction onto the computation.
   1686 XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
   1687           absl::Span<const int64> broadcast_dimensions = {});
   1688
   1689 // Enqueues an operator that tests if the operand's values are finite, i.e., not
   1690 // +/-Inf or NaN.  Returns an array of booleans with the same shape where
   1691 // entries are true iff the corresponding entry was not infinite or NaN.
   1692 //
   1693 // Defined only for real-valued (i.e. not complex) floating-point types; raises
   1694 // an error for other types.
   1695 //
   1696 // See also IsInf, IsPosInf, IsNegInf, and IsNan in lib/math.h.
   1697 XlaOp IsFinite(const XlaOp& operand);
   1698 
   1699 // Enqueues an iota operation onto the computation.
   1700 XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension);
   1701
   1702 // Enqueues a rank-1 iota operation onto the computation.
   1703 XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);
   1704
   1705 // Enqueues a convert instruction onto the computation that changes the
   1706 // element type of the operand array to new_element_type.
   1707 XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type);
   1708
   1709 // Enqueues a no-op instruction onto the computation that changes
   1710 // the element type of the operand array to new_element_type. The
   1711 // bit-widths of the source and destination element types must be
   1712 // identical.
   1713 XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type);
   1714
   1715 // Enqueues a negate instruction onto the computation.
   1716 XlaOp Neg(const XlaOp& operand);
   1717
   1718 // Enqueues a transpose instruction onto the computation.
   1719 XlaOp Transpose(const XlaOp& operand, absl::Span<const int64> permutation);
   1720
   1721 // Enqueues a reverse instruction onto the computation. The order of the
   1722 // elements in the given dimensions is reversed (i.e., the element at index i
   1723 // is moved to index dimension_size - 1 - i).
   1724 XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions);
   1725
   1726 // Enqueues a sort (in increasing order) instruction onto the computation.
   1727 // If only keys are provided:
   1728 // * If the keys are a rank-1 tensor (an array), the result is a sorted array
   1729 // of keys, in ascending order.
   1730 // * If the keys have higher rank, the keys are sorted along the provided
   1731 // dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension
   1732 // value of 0 will independently sort every column, and a dimension value of 1
   1733 // will independently sort each row. If no dimension number is provided, then
   1734 // the last dimension is chosen by default.
   1735 //
   1736 // If both keys and values are provided:
   1737 // * The keys and all values must be tensors with the same dimensions. The
   1738 // element types of the tensors may be different.
   1739 // * The result is a tuple that consists of a sorted tensor of keys (along the
   1740 // provided dimension, as above) as the first element, and tensors with their
   1741 // corresponding values as the other elements.
   1742 ABSL_DEPRECATED("Use form with comparator computation instead")
   1743 XlaOp Sort(const XlaOp& keys, absl::Span<const XlaOp> values = {},
   1744            int64 dimension = -1);
   1745 
// Enqueues a sort instruction onto the computation, using 'comparator' for
// comparisons. 'comparator' needs to define a strict weak order. 'is_stable'
// determines whether a stable sort is used (defaults to false).
//
// If only one operand is provided:
// * If the operand is a rank-1 tensor (an array), the result is a sorted array.
//   The resulting sorting order has the property that for all index positions
//   i, j with i < j, either
//   comparator(value[i], value[j]) = comparator(value[j], value[i]) = false or
//   comparator(value[i], value[j]) = true.
// * If the operand has higher rank, the operand is sorted along the provided
//   dimension. For example, for a rank-2 tensor (a matrix), a dimension value
//   of 0 will independently sort every column, and a dimension value of 1 will
//   independently sort each row. If no dimension number is provided (the
//   default of -1), then the last dimension is chosen. For the dimension which
//   is sorted, the same sorting order applies as in the rank-1 case.
//
// If more than one operand is provided:
// * All operands must be tensors with the same dimensions. The element types of
//   the tensors may be different.
// * The result is a tuple that consists of the operands in sorted order (along
//   the provided dimension, as above). The same permutation as implied by the
//   comparison computation is applied to all operand tensors. When comparing
//   two index positions, 'comparator' is called with 2 * n scalar parameters,
//   where parameter 2 * i and 2 * i + 1 correspond to the value of operand i at
//   two index positions.
//
// Default comparator computations can be found in lib/comparators.h.
XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
           int64 dimension = -1, bool is_stable = false);
   1774 
// Enqueues a clamp instruction onto the computation.
XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);

// Enqueues a map instruction onto the computation, applying `computation`
// to `operands`. `static_operands` defaults to empty.
XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
          const XlaComputation& computation, absl::Span<const int64> dimensions,
          absl::Span<const XlaOp> static_operands = {});

// Enqueues an N(mu, sigma) (normal distribution) random number generation
// instruction onto the computation, producing a result of the given `shape`.
XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape);

// Enqueues a U(a, b) (uniform distribution) random number generation
// instruction onto the computation, producing a result of the given `shape`.
// Returns values in the semi-open interval [a, b).
XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);
   1790 
// Enqueues a while node onto the computation, starting from the value `init`
// and driven by the `condition` and `body` computations.
XlaOp While(const XlaComputation& condition, const XlaComputation& body,
            const XlaOp& init);

// Enqueues a conditional node onto the computation. `predicate` selects
// between `true_computation` (applied to `true_operand`) and
// `false_computation` (applied to `false_operand`).
XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
                  const XlaComputation& true_computation,
                  const XlaOp& false_operand,
                  const XlaComputation& false_computation);

// Enqueues either a predicated (if/else) or indexed (switch/case/default)
// conditional node onto the computation. N >= 1 branch_computations and
// branch_operands are matched by index. branch_index selects the branch that
// will be executed. An out-of-range branch_index selects the last
// branch_computation (index N-1) as the default.
XlaOp Conditional(const XlaOp& branch_index,
                  absl::Span<const XlaComputation* const> branch_computations,
                  absl::Span<const XlaOp> branch_operands);
   1809 
// Enqueues a ReducePrecision node onto the computation. `exponent_bits` and
// `mantissa_bits` describe the reduced floating-point format.
XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
                      const int mantissa_bits);

// Enqueues a Gather node onto the computation, gathering slices of size
// `slice_sizes` from `input` at `start_indices` as described by
// `dimension_numbers`.
XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
             const GatherDimensionNumbers& dimension_numbers,
             absl::Span<const int64> slice_sizes);

// Enqueues a Scatter node onto the computation, combining `updates` into
// `input` at `scatter_indices` using `update_computation`, as described by
// `dimension_numbers`.
XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
              const XlaOp& updates, const XlaComputation& update_computation,
              const ScatterDimensionNumbers& dimension_numbers);
   1823 
// Enqueues a Send node onto the computation for device-to-device
// communication. This operation sends the given operand to a Recv instruction
// in a different computation that shares the same channel handle. Unlike
// SendWithToken, this form returns no XlaOp.
void Send(const XlaOp& operand, const ChannelHandle& handle);

// Variant of Send which takes a token-shaped operand and produces a
// token-shaped value. Tokens are used for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
                    const ChannelHandle& handle);

// Enqueues a Recv node onto the computation for device-to-device
// communication. The data comes from a Send instruction in a different
// computation that shares the same channel handle, and its shape must be the
// same as the given shape.
XlaOp Recv(XlaBuilder* builder, const Shape& shape,
           const ChannelHandle& handle);

// Variant of Recv which takes a token-shaped operand and produces a two-element
// tuple containing the data value and a token-shaped value. Tokens are used
// for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
                    const ChannelHandle& handle);

// Enqueues a Send node which transfers data from the device to the host. The
// 'shape_with_layout' argument defines the layout of the data transferred; its
// shape must be compatible with the shape of the operand. The operand must be
// array-shaped.
// TODO(b/111544877): Support tuple shapes.
XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
                 const Shape& shape_with_layout, const ChannelHandle& handle);

// Enqueues a Recv node which transfers data from the host to the device. The
// given shape must contain a layout and must be an array.
// TODO(b/111544877): Support tuple shapes.
XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
                   const ChannelHandle& handle);
   1863 
// Enqueues an operation (AfterAll) with no operands that produces a
// token-shaped value. Tokens are used for ordering side-effecting operations.
// This is a separate function from AfterAll to facilitate the removal of
// operand-less AfterAll instructions.
// TODO(b/110532604): Remove this function when all tokens are derived from a
// single token generated or passed into the entry computation.
XlaOp CreateToken(XlaBuilder* builder);

// Enqueues an AfterAll instruction which produces a token-shaped value and
// takes a variadic number of token-shaped operands. The number of operands must
// be greater than zero (use CreateToken for the operand-less case). Used for
// joining tokens.
XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);
   1876 
// Normalizes operand across spatial and batch dimensions for each feature.
//
// Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
// is the normalized result and batch_mean and batch_var are the mean and
// variance, respectively, across batch for the operand.
XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
                        const XlaOp& offset, float epsilon,
                        int64 feature_index);

// Normalizes operand across spatial and batch dimensions for each feature.
//
// `BatchNormInference` is equivalent to calling `BatchNormTraining` without
// computing `mean` and `variance` for each batch inside the operation. It
// uses the input `mean` and `variance` instead as estimated values. The
// purpose of this op is to reduce latency in inference, hence the name
// `BatchNormInference`.
//
// The output has the same shape as `operand`, and contains the normalized
// values for each batch.
XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
                         const XlaOp& offset, const XlaOp& mean,
                         const XlaOp& variance, float epsilon,
                         int64 feature_index);

// Calculates the gradients of a batch norm op.
//
// The inputs `batch_mean` and `batch_var` represent the mean and variance
// across the batch.
//
// Returns a tuple of three elements, in this order:
//   - grad_operand: Gradient with respect to input `operand`
//   - grad_scale: Gradient with respect to input `scale`
//   - grad_offset: Gradient with respect to the `offset` of the forward
//     normalization (BatchNormGrad itself takes no offset argument)
XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
                    const XlaOp& batch_mean, const XlaOp& batch_var,
                    const XlaOp& grad_output, float epsilon,
                    int64 feature_index);
   1914 
// Enqueues an instruction that returns the size of dimension `dimension` of
// the operand. The operand must be array shaped.
XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension);
   1918 
   1919 // Implementation details below this point.
   1920 //
   1921 
   1922 // Free function template implementations.
   1923 
   1924 template <typename NativeT>
   1925 XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
   1926   return ConstantLiteral(builder, LiteralUtil::CreateR0<NativeT>(value));
   1927 }
   1928 
   1929 template <typename NativeT>
   1930 XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values) {
   1931   return ConstantLiteral(builder, LiteralUtil::CreateR1<NativeT>(values));
   1932 }
   1933 
   1934 template <typename NativeT>
   1935 XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) {
   1936   Literal literal(ShapeUtil::MakeShape(
   1937       primitive_util::NativeToPrimitiveType<NativeT>(), {length}));
   1938   literal.PopulateWithValue(value);
   1939   return ConstantLiteral(builder, literal);
   1940 }
   1941 
   1942 inline XlaOp ConstantR1(XlaBuilder* builder,
   1943                         const tensorflow::core::Bitmap& values) {
   1944   return ConstantLiteral(builder, LiteralUtil::CreateR1(values));
   1945 }
   1946 
   1947 template <typename NativeT>
   1948 XlaOp ConstantR2(XlaBuilder* builder,
   1949                  std::initializer_list<std::initializer_list<NativeT>> values) {
   1950   return ConstantLiteral(builder, LiteralUtil::CreateR2<NativeT>(values));
   1951 }
   1952 
   1953 template <typename NativeT>
   1954 XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
   1955                                   const Array<NativeT>& values,
   1956                                   const Layout& layout) {
   1957   return ConstantLiteral(
   1958       builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
   1959 }
   1960 
   1961 template <typename NativeT>
   1962 XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) {
   1963   return ConstantLiteral(builder,
   1964                          LiteralUtil::CreateFromArray<NativeT>(values));
   1965 }
   1966 
   1967 template <typename NativeT>
   1968 XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
   1969                                       const Array2D<NativeT>& values,
   1970                                       const Layout& layout) {
   1971   return ConstantLiteral(
   1972       builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
   1973 }
   1974 
   1975 template <typename NativeT>
   1976 XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
   1977                             const Array2D<NativeT>& values) {
   1978   return ConstantLiteral(builder,
   1979                          LiteralUtil::CreateR2FromArray2D<NativeT>(values));
   1980 }
   1981 
   1982 template <typename NativeT>
   1983 XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
   1984                                       const Array3D<NativeT>& values,
   1985                                       const Layout& layout) {
   1986   return ConstantLiteral(
   1987       builder,
   1988       LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
   1989 }
   1990 
   1991 template <typename NativeT>
   1992 XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
   1993                             const Array3D<NativeT>& values) {
   1994   return ConstantFromArray(builder, values);
   1995 }
   1996 
   1997 template <typename NativeT>
   1998 XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
   1999                                       const Array4D<NativeT>& values,
   2000                                       const Layout& layout) {
   2001   return ConstantFromArrayWithLayout(builder, values, layout);
   2002 }
   2003 
   2004 template <typename NativeT>
   2005 XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
   2006                             const Array4D<NativeT>& values) {
   2007   return ConstantFromArray(builder, values);
   2008 }
   2009 
   2010 }  // namespace xla
   2011 
   2012 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_
   2013