/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_

#include <functional>
#include <map>
#include <memory>
#include <ostream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stacktrace.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class XlaBuilder;

// This represents an instruction that has been enqueued using the XlaBuilder.
// It is passed to subsequent computations that depend upon the instruction as
// an operand.
class XlaOp {
 public:
  XlaOp() : handle_(-1), builder_(nullptr) {
    static_assert(std::is_trivially_destructible<XlaOp>::value,
                  "XlaOp should be trivially destructible");
  }
  ~XlaOp() = default;

  XlaOp(const XlaOp& other) = default;
  XlaOp& operator=(const XlaOp& other) = default;

  // Precondition: !IsUninitialized().
  //
  // It's very common to do foo.builder()->bar(). Without this precondition, if
  // foo.builder() is null, the call to bar will segfault at some point,
  // possibly deep in the call stack when we finally dereference `this`. The
  // precondition lets us avoid this tricky-to-debug problem.
  XlaBuilder* builder() const {
    CHECK(builder_ != nullptr);
    return builder_;
  }

  // Returns true if the XlaOp represents a valid, non-erroneous value.
  bool valid() const { return handle_ >= 0; }

  // Returns true if the XlaOp was created by the XlaOp() constructor and
  // not returned by a builder.
  bool IsUninitialized() const { return builder_ == nullptr; }

  bool IsIdenticalTo(const XlaOp& rhs) const {
    return handle_ == rhs.handle_ && builder_ == rhs.builder_;
  }

  friend std::ostream& operator<<(std::ostream& out, const XlaOp& op) {
    out << op.handle();
    return out;
  }

 private:
  explicit XlaOp(XlaBuilder* builder) : handle_(-1), builder_(builder) {}
  XlaOp(int64 handle, XlaBuilder* builder)
      : handle_(handle), builder_(builder) {}

  int64 handle() const { return handle_; }

  friend class XlaBuilder;

  // < 0 means "invalid handle".
  int64 handle_;

  // Not owned. Non-null for any handle returned by XlaBuilder, even if the
  // handle is invalid.
  XlaBuilder* builder_;
};

// Arithmetic operator overloads for the XlaOp type.
XlaOp operator-(const XlaOp& x);
XlaOp operator+(const XlaOp& x, const XlaOp& y);
XlaOp operator-(const XlaOp& x, const XlaOp& y);
XlaOp operator*(const XlaOp& x, const XlaOp& y);
XlaOp operator/(const XlaOp& x, const XlaOp& y);
XlaOp operator%(const XlaOp& x, const XlaOp& y);

// Bitwise operator overloads for the XlaOp type.
XlaOp operator~(const XlaOp& x);
XlaOp operator&(const XlaOp& x, const XlaOp& y);
XlaOp operator|(const XlaOp& x, const XlaOp& y);
XlaOp operator^(const XlaOp& x, const XlaOp& y);
XlaOp operator<<(const XlaOp& x, const XlaOp& y);
// Performs a right arithmetic shift if 'x' is a signed type, otherwise
// performs a right logical shift.
XlaOp operator>>(const XlaOp& x, const XlaOp& y);

// We don't overload the relational operators (==, !=, <, <=, >, >=) because
// the semantics might be surprising since their result types are usually
// 'bool'. Further, programmers may expect == to be structural equality. We
// also choose not to overload any of the mutating operators (e.g., +=, -=)
// because the semantics might be misleading: XLA computations are immutable.
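
// Example (illustrative sketch, not part of the API): the overloads above let
// client code compose ops with ordinary C++ expressions. Assuming the free
// functions Parameter() and ConstantR0() declared later in this file:
//
//   XlaBuilder b("axpy");
//   auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8}), "x");
//   auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {8}), "y");
//   auto a = ConstantR0<float>(&b, 2.0f);
//   auto result = a * x + y;  // Desugars to Mul/Add calls on the builder.
//   StatusOr<XlaComputation> computation = b.Build(result);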
// A convenient interface for building up computations.
//
// Thread-compatible.
class XlaBuilder {
 public:
  // computation_name: name to use for the built computation.
  XlaBuilder(const string& computation_name);

  XlaBuilder(const XlaBuilder&) = delete;
  XlaBuilder& operator=(const XlaBuilder&) = delete;

  ~XlaBuilder();

  // Returns the computation name.
  const string& name() const { return name_; }

  // Sets OpMetadata that will be added to all instructions until cleared.
  //
  // OpMetadata is often applied to a series of XLA HLO instructions. As a
  // result, OpMetadata is set on the computation builder. All subsequent
  // instructions generated via this computation builder will have the same
  // OpMetadata attached until a call to ClearOpMetadata.
  void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; }

  // Clears the OpMetadata state.
  void ClearOpMetadata() { metadata_.Clear(); }

  // Sets an OpSharding that will be attached to all instructions until
  // cleared.
  void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }

  // Clears the sharding. Ops will be sharded according to the default
  // placement policy.
  void ClearSharding() { sharding_ = absl::nullopt; }

  // Returns the OpSharding that will be attached to all instructions.
  const absl::optional<OpSharding>& sharding() const { return sharding_; }

  // Sets the builder to a mode where it will die immediately when an error is
  // encountered, rather than producing it in a deferred fashion when Build()
  // is called (which is the default).
  void set_die_immediately_on_error(bool enabled) {
    die_immediately_on_error_ = enabled;
  }

  // Default dimension numbers used for a 2D convolution.
  static constexpr int64 kConvBatchDimension = 0;
  static constexpr int64 kConvFeatureDimension = 1;
  static constexpr int64 kConvFirstSpatialDimension = 2;
  static constexpr int64 kConvSecondSpatialDimension = 3;
  static constexpr int64 kConvKernelOutputDimension = 0;
  static constexpr int64 kConvKernelInputDimension = 1;
  static constexpr int64 kConvKernelFirstSpatialDimension = 2;
  static constexpr int64 kConvKernelSecondSpatialDimension = 3;

  // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for
  // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for
  // the kernel operand
  // {output_feature, input_feature, height, width} = {0, 1, 2, 3}.
  static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers(
      int num_spatial_dims = 2);

  // Returns an error if the convolution dimension numbers have conflicts.
  static Status Validate(const ConvolutionDimensionNumbers& dnum);

  // Returns a new XlaBuilder whose resultant Computation is used only by this
  // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error
  // behavior as the parent.
  std::unique_ptr<XlaBuilder> CreateSubBuilder(const string& computation_name);
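
  // Example (illustrative sketch): a sub-builder is the usual way to create a
  // subcomputation, e.g. a scalar `max` to pass to ReduceWindow. Given an
  // XlaBuilder `b`:
  //
  //   std::unique_ptr<XlaBuilder> sub = b.CreateSubBuilder("max");
  //   auto l = Parameter(sub.get(), 0, ShapeUtil::MakeShape(F32, {}), "lhs");
  //   auto r = Parameter(sub.get(), 1, ShapeUtil::MakeShape(F32, {}), "rhs");
  //   Max(l, r);
  //   XlaComputation max_comp = sub->BuildAndNoteError();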
  // Builds the computation with the requested operations, or returns a non-ok
  // status. Note that all ops that have been enqueued will be moved to the
  // computation being returned. The root of the computation will be the last
  // added operation.
  //
  // `remove_dynamic_dimensions` tells the builder whether to remove the
  // dynamic dimensions information in all ops.
  //
  // TODO(b/121223198): Delete `remove_dynamic_dimensions` and keep the
  // dynamic dimensions information when the XLA backend can handle dynamic
  // dimensions.
  StatusOr<XlaComputation> Build(bool remove_dynamic_dimensions = true);

  // Overload of Build which specifies a particular root instruction for the
  // computation.
  StatusOr<XlaComputation> Build(XlaOp root,
                                 bool remove_dynamic_dimensions = true);

  // Builds the computation with the requested operations, or notes an error
  // in the parent XlaBuilder and returns an empty computation if building
  // failed. This function is intended to be used where the returned
  // XlaComputation is only used by the parent XlaBuilder and hence further
  // operations on the returned XlaComputation will simply be errored out if
  // an error occurred while building this computation. If the built
  // computation is to be used by an XlaBuilder other than the parent
  // XlaBuilder then Build() should be used instead.
  XlaComputation BuildAndNoteError();

  // Returns a subgraph that roots on the given root. If the root is not a
  // compile-time constant (see `IsConstant`), returns an error.
  //
  // This will copy the needed ops/computations to the subgraph.
  StatusOr<XlaComputation> BuildConstantSubGraph(const XlaOp& root_op);

  // Returns the first error that was encountered while building the
  // computation. When an error is encountered, by default we return a vacuous
  // XlaOp and inform the user of the error that occurred while
  // building the computation when they make a final call to Build().
  //
  // See also set_die_immediately_on_error().
  Status first_error() const { return first_error_; }

  // Returns the current status of the builder, complete with the stack trace
  // information.
  Status GetCurrentStatus() const;

  // Returns the shape of the given op.
  StatusOr<Shape> GetShape(const XlaOp& op) const;

  // Returns the (inferred) result for the current computation's shape. This
  // assumes the root instruction is the last added instruction.
  StatusOr<ProgramShape> GetProgramShape() const;

  // Returns the (inferred) result for the current computation's shape using
  // the given operation as the root.
  StatusOr<ProgramShape> GetProgramShape(XlaOp root) const;

  // Reports an error to the builder, by
  // * storing it internally and capturing a backtrace if it's the first error
  //   (this deferred value will be produced on the call to
  //   Build()/GetShape()/...)
  // * dying if die_immediately_on_error_ is true.
  // Returns an XlaOp with an invalid handle but a valid builder. This value
  // can be returned in place of a value in APIs that return an XlaOp.
  XlaOp ReportError(const Status& error);

  // A helper function that converts a StatusOr<XlaOp> into an XlaOp.
  // If the Status was an error, reports the error to the builder and returns
  // an invalid XlaOp handle.
  XlaOp ReportErrorOrReturn(const StatusOr<XlaOp>& op);

  // A helper function that runs a function that returns a StatusOr<XlaOp> and
  // returns an XlaOp.
  XlaOp ReportErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator);

  // Returns true if 'operand' is a compile-time constant. A compile-time
  // constant does not depend on any parameters, or on stateful operators such
  // as `RngNormal` or `Infeed`.
  //
  // This tests whether a computation is a compile-time constant without
  // evaluating the computation.
  StatusOr<bool> IsConstant(const XlaOp& operand) const;

  // Sets up a binding which indicates that the `target_dim_num` in the
  // subshape `target_param_index` of parameter `target_param_num` is a
  // dynamic dimension and its real dynamic size is represented by
  // `dynamic_param_index` in parameter `dynamic_param_num`.
  //
  // Note that this should be called before the dynamic parameters are used to
  // create other operations, otherwise created operations won't have the
  // dynamic dimensions information.
  //
  // TODO(b/119520625): Remove this API once we have more dynamic shape infra
  // ready.
  Status SetDynamicBinding(int64 dynamic_size_param_num,
                           ShapeIndex dynamic_size_param_index,
                           int64 target_param_num,
                           ShapeIndex target_param_index,
                           int64 target_dim_num);

  // Adds a new input/output alias. Since the input/output shape information
  // is not available until the computation is built, any eventual error in
  // the arguments of this API will be detected only at computation Build()
  // time.
  void SetUpAlias(const ShapeIndex& output_index, int64 param_number,
                  const ShapeIndex& param_index) {
    input_output_aliases_.push_back({output_index, param_number, param_index});
  }

  // Describes an input/output alias as inserted by the SetUpAlias() API.
  struct InputOutputAlias {
    // Specifies the index of the aliased buffer in the result tuple.
    ShapeIndex output_index;
    // Specifies the parameter containing the buffer to be aliased.
    int64 param_number;
    // Specifies the index of the aliased buffer in the parameter.
    ShapeIndex param_index;
  };
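
  // Example (illustrative sketch): alias parameter 1's buffer with element
  // {0} of the result tuple, so the backend may reuse the input buffer for
  // that output; the arguments are validated when Build() is called:
  //
  //   builder.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1,
  //                      /*param_index=*/{});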
 private:
  // Build helper which takes the id of the root operation.
  StatusOr<XlaComputation> Build(int64 root_id, bool remove_dynamic_dimensions);

  // Description for the methods below can be found in the corresponding
  // public functions section in this file.

  XlaOp Parameter(int64 parameter_number, const Shape& shape,
                  const string& name);

  XlaOp ConstantLiteral(const LiteralSlice& literal);

  XlaOp Broadcast(const XlaOp& operand,
                  absl::Span<const int64> broadcast_sizes);

  XlaOp BroadcastInDim(const XlaOp& operand,
                       const absl::Span<const int64> out_dim_size,
                       const absl::Span<const int64> broadcast_dimensions);

  XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
            const PaddingConfig& padding_config);

  XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> dimensions,
                absl::Span<const int64> new_sizes);

  XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> new_sizes);

  XlaOp Collapse(const XlaOp& operand, absl::Span<const int64> dimensions);

  XlaOp Slice(const XlaOp& operand, absl::Span<const int64> start_indices,
              absl::Span<const int64> limit_indices,
              absl::Span<const int64> strides);

  XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index,
                   int64 stride, int64 dimno);

  ABSL_DEPRECATED("Use span-of-indices form instead")
  XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
                     absl::Span<const int64> slice_sizes);
  XlaOp DynamicSlice(const XlaOp& operand,
                     absl::Span<const XlaOp> start_indices,
                     absl::Span<const int64> slice_sizes);

  ABSL_DEPRECATED("Use span-of-indices form instead")
  XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
                           const XlaOp& start_indices);
  XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
                           absl::Span<const XlaOp> start_indices);

  XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64 dimension);

  void Trace(const string& tag, const XlaOp& operand);

  XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false);

  XlaOp Tuple(absl::Span<const XlaOp> elements);

  XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);

  XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
            const PrecisionConfig* precision_config = nullptr);

  XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
                   const DotDimensionNumbers& dimension_numbers,
                   const PrecisionConfig* precision_config = nullptr);

  XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
             absl::Span<const int64> window_strides, Padding padding,
             int64 feature_group_count = 1, int64 batch_group_count = 1,
             const PrecisionConfig* precision_config = nullptr);
  XlaOp ConvWithGeneralPadding(
      const XlaOp& lhs, const XlaOp& rhs,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      int64 feature_group_count = 1, int64 batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr);

  XlaOp ConvWithGeneralDimensions(
      const XlaOp& lhs, const XlaOp& rhs,
      absl::Span<const int64> window_strides, Padding padding,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count = 1, int64 batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr);

  XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
                    absl::Span<const int64> window_strides,
                    absl::Span<const std::pair<int64, int64>> padding,
                    const ConvolutionDimensionNumbers& dimension_numbers,
                    int64 feature_group_count = 1, int64 batch_group_count = 1,
                    const PrecisionConfig* precision_config = nullptr);

  XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
                           absl::Span<const int64> window_strides,
                           absl::Span<const std::pair<int64, int64>> padding,
                           absl::Span<const int64> lhs_dilation,
                           absl::Span<const int64> rhs_dilation,
                           const ConvolutionDimensionNumbers& dimension_numbers,
                           int64 feature_group_count = 1,
                           int64 batch_group_count = 1,
                           const PrecisionConfig* precision_config = nullptr);

  XlaOp Fft(const XlaOp& operand, FftType fft_type,
            absl::Span<const int64> fft_length);

  XlaOp Infeed(const Shape& shape, const string& config = "");
  XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
                        const string& config = "");

  void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
               const string& outfeed_config);
  XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
                         const Shape& shape_with_layout,
                         const string& outfeed_config);

  XlaOp Call(const XlaComputation& computation,
             absl::Span<const XlaOp> operands);

  XlaOp CustomCall(
      const string& call_target_name, absl::Span<const XlaOp> operands,
      const Shape& shape_with_layout, const string& opaque,
      absl::optional<absl::Span<const Shape>> operand_shapes_with_layout);

  XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
               const XlaComputation& computation,
               absl::Span<const int64> dimensions_to_reduce);

  XlaOp Reduce(absl::Span<const XlaOp> operands,
               absl::Span<const XlaOp> init_values,
               const XlaComputation& computation,
               absl::Span<const int64> dimensions_to_reduce);

  XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
                  const XlaComputation& computation);

  XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
                     const XlaComputation& computation,
                     absl::Span<const int64> window_dimensions,
                     absl::Span<const int64> window_strides, Padding padding);

  XlaOp ReduceWindowWithGeneralPadding(
      const XlaOp& operand, const XlaOp& init_value,
      const XlaComputation& computation,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const int64> base_dilations,
      absl::Span<const int64> window_dilations,
      absl::Span<const std::pair<int64, int64>> padding);

  XlaOp CrossReplicaSum(const XlaOp& operand,
                        absl::Span<const ReplicaGroup> replica_groups = {});

  XlaOp CrossReplicaSum(
      const XlaOp& operand, const XlaComputation& computation,
      absl::Span<const ReplicaGroup> replica_groups = {},
      const absl::optional<ChannelHandle>& channel_id = absl::nullopt);
  XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
                 int64 concat_dimension, int64 split_count,
                 const std::vector<ReplicaGroup>& replica_groups);

  XlaOp CollectivePermute(
      const XlaOp& operand,
      const std::vector<std::pair<int64, int64>>& source_target_pairs);

  XlaOp ReplicaId();

  XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
                         absl::Span<const int64> window_dimensions,
                         absl::Span<const int64> window_strides,
                         Padding padding, const XlaOp& source,
                         const XlaOp& init_value,
                         const XlaComputation& scatter);

  XlaOp SelectAndScatterWithGeneralPadding(
      const XlaOp& operand, const XlaComputation& select,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
      const XlaOp& init_value, const XlaComputation& scatter);

  XlaOp Iota(const Shape& shape, int64 iota_dimension);

  XlaOp Iota(PrimitiveType type, int64 size);

  XlaOp ConvertElementType(const XlaOp& operand,
                           PrimitiveType new_element_type);

  XlaOp BitcastConvertType(const XlaOp& operand,
                           PrimitiveType new_element_type);

  XlaOp Transpose(const XlaOp& operand, absl::Span<const int64> permutation);

  XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions);

  ABSL_DEPRECATED("Use form with comparator computation instead")
  XlaOp Sort(const XlaOp& keys, absl::Span<const XlaOp> values = {},
             int64 dimension = -1);
  XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
             int64 dimension = -1, bool is_stable = false);

  XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);

  XlaOp Map(absl::Span<const XlaOp> operands, const XlaComputation& computation,
            absl::Span<const int64> dimensions,
            absl::Span<const XlaOp> static_operands = {});

  XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape);

  XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);

  XlaOp While(const XlaComputation& condition, const XlaComputation& body,
              const XlaOp& init);

  XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
                    const XlaComputation& true_computation,
                    const XlaOp& false_operand,
                    const XlaComputation& false_computation);

  XlaOp Conditional(const XlaOp& branch_index,
                    absl::Span<const XlaComputation* const> branch_computations,
                    absl::Span<const XlaOp> branch_operands);

  XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
                        const int mantissa_bits);

  XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
               const GatherDimensionNumbers& dimension_numbers,
               absl::Span<const int64> slice_sizes);

  XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
                const XlaOp& updates, const XlaComputation& update_computation,
                const ScatterDimensionNumbers& dimension_numbers);

  void Send(const XlaOp& operand, const ChannelHandle& handle);
  XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
                      const ChannelHandle& handle);

  XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
                   const Shape& shape_with_layout, const ChannelHandle& handle);

  XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
                     const ChannelHandle& handle);

  XlaOp CreateToken();

  XlaOp AfterAll(absl::Span<const XlaOp> tokens);

  XlaOp Recv(const Shape& shape, const ChannelHandle& handle);
  XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
                      const ChannelHandle& handle);

  XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
                          const XlaOp& offset, float epsilon,
                          int64 feature_index);

  XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
                           const XlaOp& offset, const XlaOp& mean,
                           const XlaOp& variance, float epsilon,
                           int64 feature_index);

  XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
                      const XlaOp& batch_mean, const XlaOp& batch_var,
                      const XlaOp& grad_output, float epsilon,
                      int64 feature_index);

  XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension);

  StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr, HloOpcode opcode,
                                 absl::Span<const XlaOp> operands = {});

  void AddCalledComputation(const XlaComputation& computation,
                            HloInstructionProto* instr);

  StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const;
  StatusOr<const HloInstructionProto*> LookUpInstructionByHandle(
      int64 handle) const;

  // Internal helper method that does the building for an arbitrary unary op.
  XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand);

  // Internal helper method that does the building for an arbitrary binary op.
  // broadcast_dimensions specifies which dimensions to use for broadcasting
  // when the operation is between tensors of different ranks. The direction
  // is only used if opcode is kCompare.
  XlaOp BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
                 absl::Span<const int64> broadcast_dimensions,
                 absl::optional<ComparisonDirection> direction = absl::nullopt);

  // Internal helper method that does the building for an arbitrary ternary op.
  XlaOp TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs,
                  const XlaOp& ehs);

  XlaOp RngOp(RandomDistribution distribution,
              absl::Span<const XlaOp> parameters, const Shape& shape);

  StatusOr<XlaOp> InDimBroadcast(const Shape& shape, const XlaOp& operand,
                                 absl::Span<const int64> broadcast_dimensions);

  // Internal helper method that creates a sequence of instructions that
  // performs an explicit broadcast of the operand to the target shape.
  StatusOr<XlaOp> AddBroadcastSequence(const Shape& output_shape,
                                       const XlaOp& operand);

  // Internal helper method for creating a Reshape op with the already
  // inferred shape.
  StatusOr<XlaOp> Reshape(const Shape& shape, const XlaOp& operand);

  // Returns the (inferred) result for the program shape using the given root.
  StatusOr<ProgramShape> GetProgramShape(int64 root_id) const;

  // Returns shapes for the operands.
  StatusOr<std::vector<Shape>> GetOperandShapes(
      absl::Span<const XlaOp> operands) const;

  // A visitor which checks whether an operation is a compile-time constant,
  // meaning that it doesn't depend on any parameters, or on any stateful
  // operation such as `RngNormal` or `Infeed`. The visitor walks the
  // computation starting at a given operation and sets is_constant to false
  // iff a parameter or stateful operation is encountered.
  void IsConstantVisitor(const int64 op_handle,
                         absl::flat_hash_set<int64>* visited,
                         bool* is_constant) const;

  // Checks bounds for convolution parameters.
  Status VerifyConvolution(
      const Shape& lhs_shape, const Shape& rhs_shape,
      const ConvolutionDimensionNumbers& dimension_numbers) const;

  // Helper function for creating a Window proto from user-supplied data.
  // Returns an error if the user-supplied data was invalid.
  StatusOr<Window> MakeWindow(absl::Span<const int64> window_dimensions,
                              absl::Span<const int64> window_strides,
                              absl::Span<const std::pair<int64, int64>> padding,
                              absl::Span<const int64> lhs_dilation,
                              absl::Span<const int64> rhs_dilation) const;

  int64 GetNextId() { return ++next_id_; }

  // Populates the module with the input/output alias information stored
  // within the input_output_aliases vector.
  static Status PopulateInputOutputAlias(
      HloModuleProto* module, const ProgramShape& program_shape,
      const std::vector<InputOutputAlias>& input_output_aliases);

  string name_;  // Name to use for the built computation.

  // The next sequential ID for every instruction/computation contained within
  // this computation.
  int64 next_id_ = 0;

  // The first error encountered while building the computation.
  // This is OK until the first error is encountered.
  Status first_error_;

  // The saved stack trace from the point at which the first error occurred.
  tensorflow::SavedStackTrace first_error_backtrace_;

  // The instructions of this computation.
  std::vector<HloInstructionProto> instructions_;

  // Dynamic parameter configuration of this computation.
  DynamicParameterBinding dynamic_parameter_binding_;

  // Holds the input/output alias information populated by the SetUpAlias()
  // API.
  std::vector<InputOutputAlias> input_output_aliases_;

  // A map from XlaOp::Handle to the index in the instructions_ vector where
  // the instruction is held.
  absl::flat_hash_map<int64, int64> handle_to_index_;

  // The embedded computations used by this computation. Each computation was
  // the entry computation of some XlaComputation; the key is the unique id of
  // that XlaComputation.
  std::map<int64, HloComputationProto> embedded_;

  // The unique parameter numbers.
  absl::flat_hash_set<int64> parameter_numbers_;

  // The metadata to attach to each op. This is structured as a "modal"-like
  // operation, in order to simplify client code (and not sprinkle this
  // metadata throughout the TensorFlow op kernel implementations).
  OpMetadata metadata_;

  // Sharding for this operator. This is structured as a "modal"-like
  // operation, in order to simplify client code, similar to metadata_.
  absl::optional<OpSharding> sharding_;

  // Mode bit that indicates whether to die when a first error is encountered.
  bool die_immediately_on_error_ = false;

  XlaBuilder* parent_builder_{nullptr};

  friend XlaOp Parameter(XlaBuilder* builder, int64 parameter_number,
                         const Shape& shape, const string& name);
  friend XlaOp ConstantLiteral(XlaBuilder* builder,
                               const LiteralSlice& literal);

  friend XlaOp Broadcast(const XlaOp& operand,
                         absl::Span<const int64> broadcast_sizes);

  friend XlaOp BroadcastInDim(
      const XlaOp& operand, const absl::Span<const int64> out_dim_size,
      const absl::Span<const int64> broadcast_dimensions);

  friend XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
                   const PaddingConfig& padding_config);

  friend XlaOp Reshape(const XlaOp& operand,
                       absl::Span<const int64> dimensions,
                       absl::Span<const int64> new_sizes);

  friend XlaOp Reshape(const XlaOp& operand,
                       absl::Span<const int64> new_sizes);

  friend XlaOp Collapse(const XlaOp& operand,
                        absl::Span<const int64> dimensions);

  friend XlaOp Slice(const XlaOp& operand,
                     absl::Span<const int64> start_indices,
                     absl::Span<const int64> limit_indices,
                     absl::Span<const int64> strides);

  friend XlaOp SliceInDim(const XlaOp& operand, int64 start_index,
                          int64 limit_index, int64 stride, int64 dimno);

  friend XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
                            absl::Span<const int64> slice_sizes);
  friend XlaOp DynamicSlice(const XlaOp& operand,
                            absl::Span<const XlaOp> start_indices,
                            absl::Span<const int64> slice_sizes);

  friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
                                  const XlaOp& start_indices);
  friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
                                  absl::Span<const XlaOp> start_indices);

  friend XlaOp ConcatInDim(XlaBuilder* builder,
                           absl::Span<const XlaOp> operands, int64 dimension);

  friend void Trace(const string& tag, const XlaOp& operand);

  friend XlaOp Select(const XlaOp& pred, const XlaOp& on_true,
                      const XlaOp& on_false);
  friend XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);
  friend XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
  friend XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Compare(const XlaOp& lhs, const XlaOp& rhs,
                       absl::Span<const int64> broadcast_dimensions,
                       ComparisonDirection direction);
  friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
                   const PrecisionConfig* precision_config);
  friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
                          const DotDimensionNumbers& dimension_numbers,
                          const PrecisionConfig* precision_config);
  friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
                    absl::Span<const int64> window_strides, Padding padding,
                    int64 feature_group_count, int64 batch_group_count,
                    const PrecisionConfig* precision_config);
  friend XlaOp ConvWithGeneralPadding(
      const XlaOp& lhs, const XlaOp& rhs,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config);
  friend XlaOp ConvWithGeneralDimensions(
      const XlaOp& lhs, const XlaOp& rhs,
      absl::Span<const int64> window_strides, Padding padding,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config);
  friend XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
                           absl::Span<const int64> window_strides,
                           absl::Span<const std::pair<int64, int64>> padding,
                           const ConvolutionDimensionNumbers& dimension_numbers,
                           int64 feature_group_count, int64 batch_group_count,
                           const PrecisionConfig* precision_config);
  friend XlaOp ConvGeneralDilated(
      const XlaOp& lhs, const XlaOp& rhs,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config);
  friend XlaOp Fft(const XlaOp& operand, FftType fft_type,
                   absl::Span<const int64> fft_length);
  friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
                               bool unit_diagonal,
                               TriangularSolveOptions::Transpose transpose_a);
  friend XlaOp Cholesky(XlaOp a, bool lower);
  friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
                      const string& config);
  friend void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
                      const string& outfeed_config);
  friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
                    absl::Span<const XlaOp> operands);
  friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
                          absl::Span<const XlaOp> operands, const Shape& shape,
                          const string& opaque);
  friend XlaOp CustomCallWithLayout(
      XlaBuilder* builder, const string& call_target_name,
      absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
      absl::Span<const Shape> operand_shapes_with_layout,
      const string& opaque);
  friend XlaOp Complex(const XlaOp& real, const XlaOp& imag,
                       absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Conj(const XlaOp& operand);
  friend XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Not(const XlaOp& operand);
  friend XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
                         absl::Span<const int64> broadcast_dimensions);
  friend XlaOp ShiftRightArithmetic(
      const XlaOp& lhs, const XlaOp& rhs,
      absl::Span<const int64> broadcast_dimensions);
  friend XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs,
                                 absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
                      const XlaComputation& computation,
                      absl::Span<const int64> dimensions_to_reduce);
  friend XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
                      absl::Span<const XlaOp> init_values,
                      const XlaComputation& computation,
                      absl::Span<const int64> dimensions_to_reduce);
  friend XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
                         const XlaComputation& computation);
  friend XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
                            const XlaComputation& computation,
                            absl::Span<const int64> window_dimensions,
                            absl::Span<const int64> window_strides,
                            Padding padding);
  friend XlaOp ReduceWindowWithGeneralPadding(
      const XlaOp& operand, const XlaOp& init_value,
      const XlaComputation& computation,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const int64> base_dilations,
      absl::Span<const int64> window_dilations,
      absl::Span<const std::pair<int64, int64>> padding);
  friend XlaOp CrossReplicaSum(const XlaOp& operand,
                               absl::Span<const ReplicaGroup> replica_groups);
  friend XlaOp CrossReplicaSum(const XlaOp& operand,
                               const XlaComputation& computation,
                               absl::Span<const ReplicaGroup> replica_groups,
                               const absl::optional<ChannelHandle>& channel_id);
  friend XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
                        int64 concat_dimension, int64 split_count,
                        const std::vector<ReplicaGroup>& replica_groups);
  friend XlaOp CollectivePermute(
      const XlaOp& operand,
      const std::vector<std::pair<int64, int64>>& source_target_pairs);
  friend XlaOp ReplicaId(XlaBuilder* builder);
  friend XlaOp SelectAndScatter(const XlaOp& operand,
                                const XlaComputation& select,
                                absl::Span<const int64> window_dimensions,
                                absl::Span<const int64> window_strides,
                                Padding padding, const XlaOp& source,
                                const XlaOp& init_value,
                                const XlaComputation& scatter);
  friend XlaOp SelectAndScatterWithGeneralPadding(
      const XlaOp& operand, const XlaComputation& select,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
      const XlaOp& init_value, const XlaComputation& scatter);
  friend XlaOp Abs(const XlaOp& operand);
  friend XlaOp Atan2(const XlaOp& y, const XlaOp& x,
                     absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Exp(const XlaOp& operand);
  friend XlaOp Expm1(const XlaOp& operand);
  friend XlaOp Floor(const XlaOp& operand);
  friend XlaOp Ceil(const XlaOp& operand);
  friend XlaOp Round(const XlaOp& operand);
  friend XlaOp Log(const XlaOp& operand);
  friend XlaOp Log1p(const XlaOp& operand);
  friend XlaOp Sign(const XlaOp& operand);
  friend XlaOp Clz(const XlaOp& operand);
  friend XlaOp Cos(const XlaOp& operand);
  friend XlaOp Sin(const XlaOp& operand);
  friend XlaOp Tanh(const XlaOp& operand);
  friend XlaOp Real(const XlaOp& operand);
  friend XlaOp Imag(const XlaOp& operand);
  friend XlaOp Sqrt(const XlaOp& operand);
  friend XlaOp Rsqrt(const XlaOp& operand);
  friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp IsFinite(const XlaOp& operand);
  friend XlaOp Iota(XlaBuilder* builder, const Shape& shape,
                    int64 iota_dimension);
  friend XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);
  friend XlaOp ConvertElementType(const XlaOp& operand,
                                  PrimitiveType new_element_type);
  friend XlaOp BitcastConvertType(const XlaOp& operand,
                                  PrimitiveType new_element_type);
  friend XlaOp Neg(const XlaOp& operand);
  friend XlaOp Transpose(const XlaOp& operand,
                         absl::Span<const int64> permutation);
  friend XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions);
  friend XlaOp Sort(const XlaOp& keys, absl::Span<const XlaOp> values,
                    int64 dimension);
  friend XlaOp Sort(absl::Span<const XlaOp> operands,
                    const XlaComputation& comparator, int64 dimension,
                    bool is_stable);
  friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
  friend XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
                   const XlaComputation& computation,
                   absl::Span<const int64> dimensions,
                   absl::Span<const XlaOp> static_operands);
  friend XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma,
                         const Shape& shape);
  friend XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);
  friend XlaOp While(const XlaComputation& condition,
                     const XlaComputation& body, const XlaOp& init);
  friend XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
                           const XlaComputation& true_computation,
                           const XlaOp& false_operand,
                           const XlaComputation& false_computation);
  friend XlaOp Conditional(
      const XlaOp& branch_index,
      absl::Span<const XlaComputation* const> branch_computations,
      absl::Span<const XlaOp> branch_operands);
  friend XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
                               const int mantissa_bits);
  friend XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
                      const GatherDimensionNumbers& dimension_numbers,
                      absl::Span<const int64> slice_sizes);
  friend XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
                       const XlaOp& updates,
                       const XlaComputation& update_computation,
                       const ScatterDimensionNumbers& dimension_numbers);
  friend void Send(const XlaOp& operand, const ChannelHandle& handle);
  friend XlaOp Recv(XlaBuilder* builder, const Shape& shape,
                    const ChannelHandle& handle);
  friend XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
                                 const XlaOp& offset, float epsilon,
                                 int64 feature_index);
  friend XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
                                  const XlaOp& offset, const XlaOp& mean,
                                  const XlaOp& variance, float epsilon,
                                  int64 feature_index);
  friend XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
                             const XlaOp& batch_mean, const XlaOp& batch_var,
                             const XlaOp& grad_output, float epsilon,
                             int64 feature_index);
  friend XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
                             const ChannelHandle& handle);
  friend XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
                             const ChannelHandle& handle);
  friend XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
                          const Shape& shape_with_layout,
                          const ChannelHandle& handle);
  friend XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
                            const ChannelHandle& handle);
  friend XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
                               const string& config);
  friend XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
                                const Shape& shape_with_layout,
                                const string& outfeed_config);
  friend XlaOp CreateToken(XlaBuilder* builder);
  friend XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);

  friend XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension);
};

// RAII-style object: sets the current sharding assignment in builder on
// construction, and sets back to the previous assignment on destruction.
class XlaScopedShardingAssignment {
 public:
  XlaScopedShardingAssignment(xla::XlaBuilder* builder,
                              absl::optional<OpSharding> sharding)
      : builder_(builder), prev_sharding_(builder->sharding()) {
    SetSharding(sharding);
  }

  XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete;
  XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) =
      delete;

  ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); }

 private:
  void SetSharding(const absl::optional<OpSharding>& sharding) {
    if (sharding.has_value()) {
      builder_->SetSharding(sharding.value());
    } else {
      builder_->ClearSharding();
    }
  }

  xla::XlaBuilder* const builder_;
  absl::optional<OpSharding> prev_sharding_;
};
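
// Example (illustrative sketch): scope a sharding annotation to a few ops and
// restore the previous sharding automatically when the scope ends. The
// `some_sharding` OpSharding value here is hypothetical:
//
//   {
//     XlaScopedShardingAssignment scoped_sharding(&builder, some_sharding);
//     auto sharded = Add(x, y);  // Gets `some_sharding` attached.
//   }
//   auto unsharded = Mul(x, y);  // Built with the previous sharding.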
// Free functions for building XlaOps. The intention is that these will
// become the public API for building XlaOps rather than calling methods on
// XlaBuilder directly.

// Enqueues a "retrieve parameter value" instruction for a parameter that was
// passed to the computation.
XlaOp Parameter(XlaBuilder* builder, int64 parameter_number,
                const Shape& shape, const string& name);

// Enqueues a constant with the value of the given literal onto the
// computation.
XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal);

// Enqueues a constant onto the computation. Methods are templated on the
// native host type (NativeT) which corresponds to a specific XLA
// PrimitiveType as given in the following table:
//
//   Native Type   PrimitiveType
//   -----------------------------
//   bool          PRED
//   int32         S32
//   int64         S64
//   uint32        U32
//   uint64        U64
//   float         F32
//   double        F64
//
// Note: not all primitive types defined in xla_data.proto have a
// corresponding native type yet.
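//
// Example (illustrative sketch): per the table above, float maps to F32 and
// int32 to S32, so the following enqueue F32 and S32 constants respectively:
//
//   auto pi = ConstantR0<float>(&builder, 3.14159f);            // f32[]
//   auto ids = ConstantR1<int32>(&builder, {1, 2, 3, 4});       // s32[4]
//   auto grid = ConstantR2<float>(&builder, {{1, 2}, {3, 4}});  // f32[2,2]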
template <typename NativeT>
XlaOp ConstantR0(XlaBuilder* builder, NativeT value);
template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values);
XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values);
template <typename NativeT>
XlaOp ConstantR2(XlaBuilder* builder,
                 std::initializer_list<std::initializer_list<NativeT>> values);
template <typename NativeT>
XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
                                  const Array<NativeT>& values,
                                  const Layout& layout);
template <typename NativeT>
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values);
template <typename NativeT>
XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
                                      const Array2D<NativeT>& values,
                                      const Layout& layout);
template <typename NativeT>
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
                            const Array2D<NativeT>& values);
template <typename NativeT>
XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
                                      const Array3D<NativeT>& values,
                                      const Layout& layout);
template <typename NativeT>
XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
                            const Array3D<NativeT>& values);
template <typename NativeT>
XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
                                      const Array4D<NativeT>& values,
                                      const Layout& layout);
template <typename NativeT>
XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
                            const Array4D<NativeT>& values);

// Enqueues a rank one constant (vector) onto the computation. The vector has
// size 'length' and every element has the value 'value'.
template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value);

// Adds dimensions to an array by duplicating the data in the array.
//
// The new dimensions are inserted on the left, i.e. if
// broadcast_sizes has values {a0, ..., aN} and the operand shape
// has dimensions {b0, ..., bM} then the shape of the output has
// dimensions {a0, ..., aN, b0, ..., bM}.
//
// The new dimensions index into copies of the operand, i.e.
//
//   output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
XlaOp Broadcast(const XlaOp& operand, absl::Span<const int64> broadcast_sizes);

// This op broadcasts the `operand` to an output with the given `shape`.
// `broadcast_dimensions` are the dimensions to be broadcasting into, i.e.,
// the i'th dimension of the operand is mapped to the
// broadcast_dimensions[i]'th dimension of the output. This also requires that
// the i'th input dimension is either 1 or is the same as the output dimension
// it's broadcasting into.
//
// For example, say operand = {1, 2}, i.e., a 1D tensor in shape s32[2]; the
// output shape is s32[2,2]:
// - Specifying {1} as broadcast_dimension will generate output
//   {{1, 2},
//    {1, 2}}
// - On the other hand, specifying {0} as broadcast_dimension
//   will generate output
//   {{1, 1},
//    {2, 2}}
XlaOp BroadcastInDim(const XlaOp& operand,
                     const absl::Span<const int64> out_dim_size,
                     const absl::Span<const int64> broadcast_dimensions);
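
// Example (illustrative sketch): combining the two broadcast forms above on a
// vector `v` of shape f32[3]:
//
//   auto left = Broadcast(v, /*broadcast_sizes=*/{4});           // f32[4,3]
//   auto in_dim = BroadcastInDim(v, /*out_dim_size=*/{3, 4},
//                                /*broadcast_dimensions=*/{0});  // f32[3,4]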
// Enqueues a pad operation onto the computation that pads the given value on
// the edges as well as between the elements of the input. padding_config
// specifies the padding amount for each dimension.
XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
          const PaddingConfig& padding_config);

// Enqueues an operation onto the computation that flattens the operand based
// on the dimension order (major/slowest-varying to minor/fastest-varying)
// given, followed by reshaping it into the shape with the given dimension
// sizes (also major to minor). Conceptually, this is a limited form of
// "shape casting".
XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> dimensions,
              absl::Span<const int64> new_sizes);

// Enqueues an operation onto the computation that collapses the operand,
// from first to last dimension (C order), then reshapes it to the given
// dimension sizes. Conceptually, this is a limited form of "shape casting".
XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> new_sizes);

// Wrapper for Reshape.
// Enqueues an operation to collapse the provided dimensions; e.g. an operand
// with dimensions {x=256, y=2, z=2, p=32} can be collapsed to {x=1024, y=32}
// by collapsing dims {0, 1, 2}. Collapsing dimensions must be a consecutive,
// in-order subsequence of the operand dimensions.
//
// Note that collapsing a single dimension does nothing:
//
//   {256} collapsing {0} => {256}
//   {1} collapsing {0} => {1}
//
// Collapsing multiple dimensions produces a single result dimension:
//
//   {256, 2} collapsing {0,1} => {512}
//   {256, 2, 3} collapsing {0,1} => {512, 3}
//
// This could potentially cause data to be moved -- it provides a more
// structured form of reshaping than an arbitrary Reshape operation.
XlaOp Collapse(const XlaOp& operand, absl::Span<const int64> dimensions);

// Enqueues a slice operation onto the computation that slices the operand
// from the start indices to the limit indices; e.g.
//
//          x
//     [ 0 1 2 3 ]
//   y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ]
//     [ 8 9 a b ]
//
// Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
// range notation.
// The strides parameter determines the stride over the slice.
XlaOp Slice(const XlaOp& operand, absl::Span<const int64> start_indices,
            absl::Span<const int64> limit_indices,
            absl::Span<const int64> strides);

// Enqueues a slice operation in a given dimension, taking all other
// dimensions as they are; e.g. if dimno is 1 from start_index 2 to
// limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand
// for:
//
//   array[:, 2:4:1, :]
XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index,
                 int64 stride, int64 dimno);
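
// Example (illustrative sketch): the two slice forms above on an operand `m`
// of shape f32[7,8,9]:
//
//   auto s1 = Slice(m, /*start_indices=*/{0, 2, 0},
//                   /*limit_indices=*/{7, 4, 9},
//                   /*strides=*/{1, 1, 1});             // f32[7,2,9]
//   auto s2 = SliceInDim(m, /*start_index=*/2, /*limit_index=*/4,
//                        /*stride=*/1, /*dimno=*/1);    // Same result.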
// Enqueues a slice operation onto the computation that slices the 'operand'
// from dynamic start indices which are passed in 'start_indices'.
// The size of the slice in each dimension is passed in 'slice_sizes', which
// specify the end point of exclusive slice intervals in each dimension
// [start, start + size).
// The shape of each element of 'start_indices' must be scalar, with the span
// size equal to the rank of the 'operand'. All elements of 'start_indices'
// must have the same shape.
// Slice index calculations are computed modulo input dimension sizes to
// prevent dynamic start indices from generating out-of-bounds array accesses.
XlaOp DynamicSlice(const XlaOp& operand, absl::Span<const XlaOp> start_indices,
                   absl::Span<const int64> slice_sizes);

ABSL_DEPRECATED("Use span-of-indices form instead")
XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
                   absl::Span<const int64> slice_sizes);

// Enqueues a dynamic update slice operation onto the computation, which
// updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
// The shape of 'update' determines the shape of the slice of 'operand'
// which is updated.
// The indices specified in 'start_indices' specify the offset of the slice
// of 'operand' which is updated.
//
//               update = {10, 11} // calculated at runtime.
//   [1 2 3]     start  = {1, 1}   // calculated at runtime.   [1 2  3 ]
//   [4 5 6]  => DynamicUpdateSlice(data, update, start)   =>  [4 10 11]
//   [7 8 9]                                                   [7 8  9 ]
//
// The shape of each element of 'start_indices' must be scalar, with the span
// size equal to the rank of the 'operand'. All elements of 'start_indices'
// must have the same shape.
// Slice index calculations are computed modulo update dimension sizes to
// prevent dynamic start indices from generating out-of-bounds array accesses.
XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
                         absl::Span<const XlaOp> start_indices);

ABSL_DEPRECATED("Use span-of-indices form instead")
XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
                         const XlaOp& start_indices);

// Enqueues a concatenate instruction onto the computation. 'operands' must
// have >= 1 entry.
XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,
                  int64 dimension);

// Enqueues a tracing operation onto the computation; the computation will
// emit a logging message with the operand.
void Trace(const string& tag, const XlaOp& operand);

// Enqueues a conditional-move-like select operation onto the computation;
// predicated on pred, selects between on_true and on_false.
XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false);

// Enqueues a tuple-creation instruction onto the computation.
XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);

// Enqueues a tuple-element-get instruction onto the computation.
XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);

// Enqueues an equal-to comparison instruction onto the computation.
XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a not-equal comparison instruction onto the computation.
XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a greater-or-equal comparison instruction onto the computation.
XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a greater-than comparison instruction onto the computation.
XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a less-than comparison instruction onto the computation.
XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a less-or-equal comparison instruction onto the computation.
XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a comparison instruction onto the computation.
XlaOp Compare(const XlaOp& lhs, const XlaOp& rhs,
              absl::Span<const int64> broadcast_dimensions,
              ComparisonDirection direction);
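
// Example (illustrative sketch): broadcast_dimensions lets a comparison mix
// ranks. For `m` of shape f32[2,3] and `v` of shape f32[3], mapping v's only
// dimension onto m's dimension 1 compares v against every row of m:
//
//   auto mask = Gt(m, v, /*broadcast_dimensions=*/{1});  // pred[2,3]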
// Enqueues a less-than comparison instruction onto the computation.
XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a less-or-equal comparison instruction onto the computation.
XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a comparison instruction onto the computation.
XlaOp Compare(const XlaOp& lhs, const XlaOp& rhs,
              absl::Span<const int64> broadcast_dimensions,
              ComparisonDirection direction);

// Enqueues a dot instruction onto the computation.
XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
          const PrecisionConfig* precision_config = nullptr);

// Enqueues a general dot instruction onto the computation.
XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
                 const DotDimensionNumbers& dimension_numbers,
                 const PrecisionConfig* precision_config = nullptr);

// Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers.
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
           absl::Span<const int64> window_strides, Padding padding,
           int64 feature_group_count = 1, int64 batch_group_count = 1,
           const PrecisionConfig* precision_config = nullptr);

// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration in the format returned by MakePadding().
XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs,
                             absl::Span<const int64> window_strides,
                             absl::Span<const std::pair<int64, int64>> padding,
                             int64 feature_group_count = 1,
                             int64 batch_group_count = 1,
                             const PrecisionConfig* precision_config = nullptr);

// Enqueues a convolution instruction onto the computation, with the caller
// provided dimension numbers configuration.
XlaOp ConvWithGeneralDimensions(
    const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
    Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count = 1, int64 batch_group_count = 1,
    const PrecisionConfig* precision_config = nullptr);

// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration as well as the dimension numbers.
XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> window_strides,
                  absl::Span<const std::pair<int64, int64>> padding,
                  const ConvolutionDimensionNumbers& dimension_numbers,
                  int64 feature_group_count = 1, int64 batch_group_count = 1,
                  const PrecisionConfig* precision_config = nullptr);

// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration, dilation factors and dimension numbers.
XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
                         absl::Span<const int64> window_strides,
                         absl::Span<const std::pair<int64, int64>> padding,
                         absl::Span<const int64> lhs_dilation,
                         absl::Span<const int64> rhs_dilation,
                         const ConvolutionDimensionNumbers& dimension_numbers,
                         int64 feature_group_count = 1,
                         int64 batch_group_count = 1,
                         const PrecisionConfig* precision_config = nullptr);
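
// Example (illustrative sketch only): expressing a batched matrix multiply
// with DotGeneral. Operands lhs: f32[B, M, K] and rhs: f32[B, K, N] are
// assumed for illustration; contract over K and batch over B:
//
//   DotDimensionNumbers dnums;
//   dnums.add_lhs_batch_dimensions(0);
//   dnums.add_rhs_batch_dimensions(0);
//   dnums.add_lhs_contracting_dimensions(2);
//   dnums.add_rhs_contracting_dimensions(1);
//   XlaOp out = DotGeneral(lhs, rhs, dnums);  // f32[B, M, N]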
// Enqueues an FFT instruction onto the computation, of the given type and
// with the given FFT length.
XlaOp Fft(const XlaOp& operand, FftType fft_type,
          absl::Span<const int64> fft_length);

// Solves systems of linear equations with lower or upper triangular
// coefficient matrices by forward- or back-substitution. Broadcasting along
// leading dimensions, this routine solves for x in one of the matrix systems
//   `op(a) * x = b`,  or  `x * op(a) = b`,
// for the variable `x`, given `a` and `b`, where `op(a)` is either
//   `op(a) = a`, or `op(a) = transpose(a)`, or `op(a) = conj(transpose(a))`.
//
// * `a` is a tensor of shape `[..., M, M]` whose innermost 2 dimensions form
//   square matrices. If `lower` is true (false), then the strictly upper
//   (lower) triangular part of each innermost matrix in `a` is assumed to be
//   zero and is not accessed.
// * `b` is a tensor of shape `[..., M, K]` if `left_side` is true, otherwise
//   a tensor of shape `[..., K, M]`.
// * `left_side` is a boolean, indicating whether to solve a system of the
//   form op(a) * x = b (true) or x * op(a) = b (false).
// * `lower` is a boolean, indicating whether the argument `a` is
//   lower-triangular (true) or upper-triangular (false).
// * If `unit_diagonal` is true, the diagonal elements of `a` are assumed to
//   be 1 and not accessed.
// * `transpose_a` indicates which function `op` we use to transform the
//   tensor `a`: the identity function, transpose(a), or conjugate(transpose(a)).
XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
                      bool unit_diagonal,
                      TriangularSolveOptions::Transpose transpose_a);

// Computes the Cholesky decompositions of a batch of symmetric (Hermitian)
// positive definite matrices.
// `a` must be a (batched) square matrix; i.e., it must have rank >= 2 with
// the two minor dimensions equal.
// If `lower` is true, the data from the lower triangle is used; if false,
// the upper triangle is used. The input data in the other triangle of the
// input does not affect the output. Returns the output in the same
// lower/upper triangle. The data returned in the other output triangle is
// arbitrary and implementation-defined.
//
// The value returned if `a` is not Hermitian positive definite is
// implementation-defined.
XlaOp Cholesky(XlaOp a, bool lower);

// Enqueues an infeed instruction onto the computation, which writes data of
// the given shape to the infeed buffer of the device.
XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
             const string& config = "");

// Variant of Infeed which takes a token-shaped operand and produces a
// two-element tuple containing the data value and a token-shaped value.
// Tokens are used for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
                      const string& config = "");
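
// Example (illustrative sketch only): solving a * x = b for a symmetric
// positive definite `a` via Cholesky. Shapes f32[N, N] for `a` and
// f32[N, K] for `b` are assumed for illustration:
//
//   XlaOp l = Cholesky(a, /*lower=*/true);  // a = l * transpose(l)
//   // Forward-substitute, then back-substitute.
//   XlaOp y = TriangularSolve(l, b, /*left_side=*/true, /*lower=*/true,
//                             /*unit_diagonal=*/false,
//                             TriangularSolveOptions::NO_TRANSPOSE);
//   XlaOp x = TriangularSolve(l, y, /*left_side=*/true, /*lower=*/true,
//                             /*unit_diagonal=*/false,
//                             TriangularSolveOptions::TRANSPOSE);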
// Enqueues an outfeed instruction onto the computation. This instruction
// generates outgoing data transfers for the given data.
//
// shape_with_layout communicates the laid out shape that we want to outfeed
// -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error
// will occur.
void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
             const string& outfeed_config);

// Variant of Outfeed which takes a token-shaped operand and produces a
// token-shaped value. Tokens are used for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
                       const Shape& shape_with_layout,
                       const string& outfeed_config);

// Enqueues a call instruction onto the computation.
XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
           absl::Span<const XlaOp> operands);

// Enqueues a custom call instruction onto the computation. A custom call
// invokes code external to XLA. The |operands| are passed to the external
// code, and the external code is expected to produce a result of the given
// |shape|. The exact mechanism is backend-specific. For example, in the CPU
// backend, a call instruction is emitted which targets a symbol with the
// name |call_target_name|. |call_target_name| and |opaque| can be arbitrary
// strings, but |call_target_name| should be short as it may be used in
// labels. |opaque| can encode arbitrarily large amounts of information.
XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
                 absl::Span<const XlaOp> operands, const Shape& shape,
                 const string& opaque = "");

// Overload which constructs a custom call with fixed layouts. The operands
// will have the layouts specified by |operand_shapes_with_layout| when
// provided to external code, and the external code is expected to produce a
// result with the layout specified by |shape_with_layout|. All shapes in
// |shape_with_layout| and |operand_shapes_with_layout| must have layouts.
XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name,
                           absl::Span<const XlaOp> operands,
                           const Shape& shape_with_layout,
                           absl::Span<const Shape> operand_shapes_with_layout,
                           const string& opaque = "");

// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
// of the operands is a scalar, or an explicit broadcast dimension is given
// (see g3doc for more details).

// Enqueues a complex compose instruction onto the computation.
XlaOp Complex(const XlaOp& real, const XlaOp& imag,
              absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a complex conjugate instruction onto the computation.
XlaOp Conj(const XlaOp& operand);

// Enqueues an add instruction onto the computation.
XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a subtract instruction onto the computation.
XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a multiply instruction onto the computation.
XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a divide instruction onto the computation.
XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});
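
// Example (illustrative sketch only): broadcasting a vector across a matrix
// with explicit broadcast dimensions. Operands m: f32[2, 3] and v: f32[3]
// are assumed for illustration; {1} maps v's single dimension onto dimension
// 1 of m, so v is added to every row:
//
//   XlaOp sum = Add(m, v, /*broadcast_dimensions=*/{1});  // f32[2, 3]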
// Enqueues a remainder instruction onto the computation.
XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a max instruction onto the computation.
XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a min instruction onto the computation.
XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Element-wise logical operators.
XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Overload to call And with 3 or more operands. We need the following
// somewhat convoluted overload set to disambiguate with the overload that
// takes the `broadcast_dimensions` optional param.
inline XlaOp And(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3) {
  return And(op1, And(op2, op3));
}
template <typename... XlaOpTs>
XlaOp And(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3,
          const XlaOpTs&... operands) {
  return And(op1, And(op2, And(op3, operands...)));
}

XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Overload to call Or with 3 or more operands. As with `And`, we need the
// following complicated overload set to handle the default arg in the `Or`
// overload above.
inline XlaOp Or(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3) {
  return Or(op1, Or(op2, op3));
}
template <typename... XlaOpTs>
XlaOp Or(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3,
         const XlaOpTs&... operands) {
  return Or(op1, Or(op2, Or(op3, operands...)));
}

XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});

XlaOp Not(const XlaOp& operand);

XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
                absl::Span<const int64> broadcast_dimensions = {});
XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs,
                           absl::Span<const int64> broadcast_dimensions = {});
XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs,
                        absl::Span<const int64> broadcast_dimensions = {});

// Reduces an array among the provided dimensions, given "computation" as a
// reduction operator.
XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
             const XlaComputation& computation,
             absl::Span<const int64> dimensions_to_reduce);

// Reduces several arrays simultaneously among the provided dimensions, given
// "computation" as a reduction operator.
XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
             absl::Span<const XlaOp> init_values,
             const XlaComputation& computation,
             absl::Span<const int64> dimensions_to_reduce);

// Convenience wrapper around the above that reduces all the dimensions in
// the operand shape.
XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
                const XlaComputation& computation);

// Enqueues a windowed reduce instruction onto the computation.
XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
                   const XlaComputation& computation,
                   absl::Span<const int64> window_dimensions,
                   absl::Span<const int64> window_strides, Padding padding);
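
// Example (illustrative sketch only): summing over dimension 1 of a
// f32[2, 3] operand `x` built on builder `b`. The scalar-add reduction
// computation is built with a separate builder; names are for illustration:
//
//   XlaBuilder add_builder("add");
//   Shape scalar = ShapeUtil::MakeShape(F32, {});
//   Add(Parameter(&add_builder, 0, scalar, "x"),
//       Parameter(&add_builder, 1, scalar, "y"));
//   XlaComputation add = add_builder.Build().ValueOrDie();
//
//   XlaOp row_sums = Reduce(x, ConstantR0<float>(&b, 0.0f), add,
//                           /*dimensions_to_reduce=*/{1});  // f32[2]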
// As ReduceWindow(), but the padding is given in the format
// returned by MakePadding().
XlaOp ReduceWindowWithGeneralPadding(
    const XlaOp& operand, const XlaOp& init_value,
    const XlaComputation& computation,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const int64> base_dilations,
    absl::Span<const int64> window_dilations,
    absl::Span<const std::pair<int64, int64>> padding);

// Returns the sum of the operand value within each subgroup of replicas. All
// replicas supply one input to the sum and all replicas receive the
// resulting sum for each subgroup.
XlaOp CrossReplicaSum(const XlaOp& operand,
                      absl::Span<const ReplicaGroup> replica_groups = {});

// Enqueues an operation that does an AllReduce of the operand across cores.
// Here AllReduce means doing a reduction on the input operand across cores
// and then broadcasting the reduction result to those cores. The reduction
// function is defined by `computation`, which should be a commutative
// computation on scalars, e.g., add, min, or max. The way that AllReduce is
// applied is configured by:
//
// - `replica_groups`: each ReplicaGroup contains a list of replica ids. If
// empty, all replicas belong to one group. AllReduce will be applied within
// subgroups. For example, with 4 replicas, replica_groups={{0,2},{1,3}}
// means replicas 0 and 2 are in subgroup 0, and replicas 1 and 3 are in
// subgroup 1.
//
// - `channel_id`: for AllReduce nodes from different modules, if they have
// the same channel_id, they will be 'AllReduce'd. If empty, AllReduce will
// not be applied across modules.
//
// TODO(b/117564385): Rename this to AllReduce when it's ready to use.
XlaOp CrossReplicaSum(
    const XlaOp& operand, const XlaComputation& computation,
    absl::Span<const ReplicaGroup> replica_groups = {},
    const absl::optional<ChannelHandle>& channel_id = absl::nullopt);

// Enqueues an operation that does an AllToAll of the operand across cores.
XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
               int64 concat_dimension, int64 split_count,
               const std::vector<ReplicaGroup>& replica_groups = {});

// Enqueues a collective operation that sends and receives data across
// replicas.
//
// - `source_target_pairs`: a list of (source_replica_id, target_replica_id)
// pairs. For each pair, the operand is sent from the source replica to the
// target replica. Note that 1) no two pairs may have the same target replica
// id, and no two pairs may have the same source replica id; 2) if a replica
// id is not a target in any pair, then the output on that replica is a
// tensor consisting of 0(s) with the same shape as the input.
XlaOp CollectivePermute(
    const XlaOp& operand,
    const std::vector<std::pair<int64, int64>>& source_target_pairs);

// Enqueues an operation that returns the replica ID.
XlaOp ReplicaId(XlaBuilder* builder);

// Enqueues an operation that scatters the `source` array to the selected
// indices of each window.
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
                       absl::Span<const int64> window_dimensions,
                       absl::Span<const int64> window_strides, Padding padding,
                       const XlaOp& source, const XlaOp& init_value,
                       const XlaComputation& scatter);
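
// Example (illustrative sketch only): an all-reduce sum over two subgroups
// of a 4-replica computation. `operand` and a scalar-add computation `add`
// (see the Reduce example above) are assumed:
//
//   ReplicaGroup group0, group1;
//   group0.add_replica_ids(0);
//   group0.add_replica_ids(2);
//   group1.add_replica_ids(1);
//   group1.add_replica_ids(3);
//   std::vector<ReplicaGroup> groups = {group0, group1};
//   XlaOp summed = CrossReplicaSum(operand, add, groups);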
// As SelectAndScatter(), but the padding is given in the format
// returned by MakePadding().
XlaOp SelectAndScatterWithGeneralPadding(
    const XlaOp& operand, const XlaComputation& select,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
    const XlaOp& init_value, const XlaComputation& scatter);

// Enqueues an abs instruction onto the computation.
XlaOp Abs(const XlaOp& operand);

// Enqueues an atan2 instruction onto the computation.
XlaOp Atan2(const XlaOp& y, const XlaOp& x,
            absl::Span<const int64> broadcast_dimensions = {});

// Enqueues an exp instruction onto the computation.
XlaOp Exp(const XlaOp& operand);

// Enqueues an expm1 instruction onto the computation.
XlaOp Expm1(const XlaOp& operand);

// Enqueues a floor instruction onto the computation.
XlaOp Floor(const XlaOp& operand);

// Enqueues a ceil instruction onto the computation.
XlaOp Ceil(const XlaOp& operand);

// Enqueues a round instruction onto the computation, rounding to the nearest
// integer, with half-way cases rounding away from zero.
XlaOp Round(const XlaOp& operand);

// Enqueues a log instruction (natural logarithm) onto the computation.
XlaOp Log(const XlaOp& operand);

// Enqueues a log1p instruction (log(x+1)) onto the computation.
XlaOp Log1p(const XlaOp& operand);

// Enqueues a sign instruction onto the computation.
XlaOp Sign(const XlaOp& operand);

// Enqueues a count-leading-zeros instruction onto the computation.
XlaOp Clz(const XlaOp& operand);

// Enqueues a cosine instruction onto the computation.
XlaOp Cos(const XlaOp& operand);

// Enqueues a sine instruction onto the computation.
XlaOp Sin(const XlaOp& operand);

// Enqueues a tanh instruction onto the computation.
XlaOp Tanh(const XlaOp& operand);

// Enqueues a real-part instruction onto the computation.
XlaOp Real(const XlaOp& operand);

// Enqueues an imaginary-part instruction onto the computation.
XlaOp Imag(const XlaOp& operand);

// Enqueues a sqrt computation onto the computation.
XlaOp Sqrt(const XlaOp& operand);

// Enqueues a rsqrt computation onto the computation.
XlaOp Rsqrt(const XlaOp& operand);

// Enqueues a lhs^rhs computation onto the computation.
XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues an operator that tests if the operand's values are finite, i.e.,
// not +/-Inf or NaN. Returns an array of booleans with the same shape where
// entries are true iff the corresponding entry was not infinite or NaN.
//
// Defined only for real-valued (i.e. not complex) floating-point types;
// raises an error for other types.
//
// See also IsInf, IsPosInf, IsNegInf, and IsNan in lib/math.h.
XlaOp IsFinite(const XlaOp& operand);

// Enqueues an iota operation onto the computation.
XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension);

// Enqueues a rank-1 iota operation onto the computation.
XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);
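
// Example (illustrative sketch only): generating index values with Iota,
// assuming a builder `b`:
//
//   XlaOp idx = Iota(&b, S32, 5);  // s32[5] = {0, 1, 2, 3, 4}
//   // Rank-2 variant: each element is its coordinate along dimension 0.
//   XlaOp rows = Iota(&b, ShapeUtil::MakeShape(S32, {3, 4}),
//                     /*iota_dimension=*/0);  // rows[i][j] = i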
// Enqueues a convert instruction onto the computation that changes the
// element type of the operand array to new_element_type.
XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type);

// Enqueues a no-op instruction onto the computation that changes
// the element type of the operand array to new_element_type. The
// bit-widths of the source and destination element types must be
// identical.
XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type);

// Enqueues a negate instruction onto the computation.
XlaOp Neg(const XlaOp& operand);

// Enqueues a transpose instruction onto the computation.
XlaOp Transpose(const XlaOp& operand, absl::Span<const int64> permutation);

// Enqueues a reverse instruction onto the computation. The order of the
// elements in the given dimensions is reversed (i.e., the element at index i
// is moved to index dimension_size - 1 - i).
XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions);

// Enqueues a sort (as increasing order) instruction onto the computation.
// If only keys are provided:
// * If the keys are a rank-1 tensor (an array), the result is a sorted array
// of keys, in ascending order.
// * If the keys have higher rank, the keys are sorted along the provided
// dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension
// value of 0 will independently sort every column, and a dimension value of 1
// will independently sort each row. If no dimension number is provided, then
// the last dimension is chosen by default.
//
// If both keys and values are provided:
// * The keys and all values must be tensors with the same dimensions. The
// element types of the tensors may be different.
// * The result is a tuple that consists of a sorted tensor of keys (along the
// provided dimension, as above) as the first element, and tensors with their
// corresponding values as the other elements.
ABSL_DEPRECATED("Use form with comparator computation instead")
XlaOp Sort(const XlaOp& keys, absl::Span<const XlaOp> values = {},
           int64 dimension = -1);
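
// Example (illustrative sketch only): Transpose permutes dimensions; result
// dimension i has the size of operand dimension permutation[i]. An operand
// `x` of shape f32[2, 3, 4] is assumed for illustration:
//
//   XlaOp t = Transpose(x, /*permutation=*/{2, 0, 1});  // f32[4, 2, 3]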
// Enqueues a sort instruction onto the computation, using 'comparator' for
// comparisons. 'comparator' needs to define a strict weak order. 'is_stable'
// determines whether stable sorting should be used.
// If only one operand is provided:
// * If the operand is a rank-1 tensor (an array), the result is a sorted
// array. The resulting sorting order has the property that for all index
// positions i, j with i < j, either
// comparator(value[i], value[j]) = comparator(value[j], value[i]) = false or
// comparator(value[i], value[j]) = true.
// * If the operand has higher rank, the operand is sorted along the provided
// dimension. For example, for a rank-2 tensor (a matrix), a dimension value
// of 0 will independently sort every column, and a dimension value of 1 will
// independently sort each row. If no dimension number is provided, then the
// last dimension is chosen by default. For the dimension which is sorted, the
// same sorting order applies as in the rank-1 case.
//
// If more than one operand is provided:
// * All operands must be tensors with the same dimensions. The element types
// of the tensors may be different.
// * The result is a tuple that consists of the operands in sorted order
// (along the provided dimension, as above). The same permutation as implied
// by the comparison computation is applied to all operand tensors. When
// comparing two index positions, 'comparator' is called with 2 * n scalar
// parameters, where parameters 2 * i and 2 * i + 1 correspond to the values
// of operand i at the two index positions.
// Default comparator computations can be found in lib/comparators.h.
XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
           int64 dimension = -1, bool is_stable = false);

// Enqueues a clamp instruction onto the computation.
XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);

// Enqueues a map instruction onto the computation.
XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
          const XlaComputation& computation,
          absl::Span<const int64> dimensions,
          absl::Span<const XlaOp> static_operands = {});

// Enqueues a N(mu, sigma) random number generation instruction onto the
// computation.
XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape);

// Enqueues a U(a, b) random number generation instruction onto the
// computation. Returns values in the semi-open interval [a, b).
XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);

// Enqueues a while node onto the computation.
XlaOp While(const XlaComputation& condition, const XlaComputation& body,
            const XlaOp& init);

// Enqueues a conditional node onto the computation.
XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
                  const XlaComputation& true_computation,
                  const XlaOp& false_operand,
                  const XlaComputation& false_computation);

// Enqueues either a predicated (if/else) or indexed (switch/case/default)
// conditional node onto the computation. N >= 1 branch_computations and
// branch_operands are matched by index. branch_index selects the branch that
// will be executed. An out-of-range branch_index selects the N-1'th
// branch_computation as the default.
XlaOp Conditional(const XlaOp& branch_index,
                  absl::Span<const XlaComputation* const> branch_computations,
                  absl::Span<const XlaOp> branch_operands);

// Enqueues a ReducePrecision node onto the computation.
XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
                      const int mantissa_bits);

// Enqueues a Gather node onto the computation.
XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
             const GatherDimensionNumbers& dimension_numbers,
             absl::Span<const int64> slice_sizes);

// Enqueues a Scatter node onto the computation.
XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
              const XlaOp& updates, const XlaComputation& update_computation,
              const ScatterDimensionNumbers& dimension_numbers);

// Enqueues a Send node onto the computation for device-to-device
// communication. This operation sends the given operand to
// a Recv instruction in a different computation that shares the same channel
// handle.
void Send(const XlaOp& operand, const ChannelHandle& handle);
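
// Example (illustrative sketch only): a counting loop with While. The
// condition and body computations each take the loop state (here a single
// s32 scalar) as their sole parameter; `b` is the enclosing builder and all
// names are for illustration:
//
//   Shape state_shape = ShapeUtil::MakeShape(S32, {});
//
//   XlaBuilder cond_builder("cond");
//   XlaOp c_state = Parameter(&cond_builder, 0, state_shape, "state");
//   Lt(c_state, ConstantR0<int32>(&cond_builder, 10));
//   XlaComputation cond = cond_builder.Build().ValueOrDie();
//
//   XlaBuilder body_builder("body");
//   XlaOp b_state = Parameter(&body_builder, 0, state_shape, "state");
//   Add(b_state, ConstantR0<int32>(&body_builder, 1));
//   XlaComputation body = body_builder.Build().ValueOrDie();
//
//   // Iterates until the state reaches 10.
//   XlaOp result = While(cond, body, ConstantR0<int32>(&b, 0));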
// Variant of Send which takes a token-shaped operand and produces a
// token-shaped value. Tokens are used for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
                    const ChannelHandle& handle);

// Enqueues a Recv node onto the computation for device-to-device
// communication. The data comes from a Send instruction in a different
// computation that shares the same channel handle and its shape must be the
// same as the given shape.
XlaOp Recv(XlaBuilder* builder, const Shape& shape,
           const ChannelHandle& handle);

// Variant of Recv which takes a token-shaped operand and produces a
// two-element tuple containing the data value and a token-shaped value.
// Tokens are used for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
                    const ChannelHandle& handle);

// Enqueues a Send node which transfers data from the device to the host. The
// 'shape_with_layout' argument defines the layout of the data transferred;
// its shape must be compatible with the shape of the operand. The operand
// must be array-shaped.
// TODO(b/111544877): Support tuple shapes.
XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
                 const Shape& shape_with_layout, const ChannelHandle& handle);

// Enqueues a Recv node which transfers data from the host to the device. The
// given shape must contain a layout and must be an array.
// TODO(b/111544877): Support tuple shapes.
XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
                   const ChannelHandle& handle);

// Enqueues an operation (AfterAll) with no operands that produces a
// token-shaped value. Tokens are used for ordering side-effecting operations.
// This is a separate method from AfterAll to facilitate the removal of
// operand-less AfterAll instructions.
// TODO(b/110532604): Remove this function when all tokens are derived from a
// single token generated or passed into the entry computation.
XlaOp CreateToken(XlaBuilder* builder);

// Enqueues an AfterAll instruction which produces a token-shaped value and
// takes a variadic number of token-shaped operands. The number of operands
// must be greater than zero. Used for joining tokens.
XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);

// Normalizes operand across spatial and batch dimensions for each feature.
//
// Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
// is the normalized result and batch_mean and batch_var are the mean and
// variance, respectively, across batch for the operand.
XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
                        const XlaOp& offset, float epsilon,
                        int64 feature_index);
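
// Example (illustrative sketch only): using tokens to order side effects.
// Two infeeds on a builder `b` are sequenced by threading a token from one
// into the next:
//
//   Shape data_shape = ShapeUtil::MakeShape(F32, {8});
//   XlaOp token0 = CreateToken(&b);
//   XlaOp infeed0 = InfeedWithToken(token0, data_shape);  // (data, token)
//   XlaOp data0 = GetTupleElement(infeed0, 0);
//   XlaOp token1 = GetTupleElement(infeed0, 1);
//   XlaOp infeed1 = InfeedWithToken(token1, data_shape);
//   XlaOp data1 = GetTupleElement(infeed1, 0);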
// Normalizes operand across spatial and batch dimensions for each feature.
//
// `BatchNormInference` is equivalent to calling `BatchNormTraining` without
// computing `mean` and `variance` for each batch inside the operation. It
// uses the input `mean` and `variance` instead as estimated values. The
// purpose of this op is to reduce latency in inference, hence the name
// `BatchNormInference`.
//
// The output has the same shape as `operand`, and contains the normalized
// values for each batch.
XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
                         const XlaOp& offset, const XlaOp& mean,
                         const XlaOp& variance, float epsilon,
                         int64 feature_index);

// Calculates the gradients of a batch norm op.
//
// The inputs `batch_mean` and `batch_var` represent the mean and variance
// across the batch.
//
// Returns a tuple of three elements:
//   - grad_operand: Gradient with respect to input `operand`
//   - grad_offset: Gradient with respect to input `offset`
//   - grad_scale: Gradient with respect to input `scale`
XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
                    const XlaOp& batch_mean, const XlaOp& batch_var,
                    const XlaOp& grad_output, float epsilon,
                    int64 feature_index);

// Returns the size of the given dimension of the operand. The operand must
// be array shaped.
XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension);

// Implementation details below this point.
//

// Free function template implementations.

template <typename NativeT>
XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
  return ConstantLiteral(builder, LiteralUtil::CreateR0<NativeT>(value));
}

template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values) {
  return ConstantLiteral(builder, LiteralUtil::CreateR1<NativeT>(values));
}

template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) {
  Literal literal(ShapeUtil::MakeShape(
      primitive_util::NativeToPrimitiveType<NativeT>(), {length}));
  literal.PopulateWithValue(value);
  return ConstantLiteral(builder, literal);
}

inline XlaOp ConstantR1(XlaBuilder* builder,
                        const tensorflow::core::Bitmap& values) {
  return ConstantLiteral(builder, LiteralUtil::CreateR1(values));
}

template <typename NativeT>
XlaOp ConstantR2(XlaBuilder* builder,
                 std::initializer_list<std::initializer_list<NativeT>> values) {
  return ConstantLiteral(builder, LiteralUtil::CreateR2<NativeT>(values));
}

template <typename NativeT>
XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
                                  const Array<NativeT>& values,
                                  const Layout& layout) {
  return ConstantLiteral(
      builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}

template <typename NativeT>
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) {
  return ConstantLiteral(builder,
                         LiteralUtil::CreateFromArray<NativeT>(values));
}

template <typename NativeT>
XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
                                      const Array2D<NativeT>& values,
                                      const Layout& layout) {
  return ConstantLiteral(
      builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}

template <typename NativeT>
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
                            const Array2D<NativeT>& values) {
  return ConstantLiteral(builder,
                         LiteralUtil::CreateR2FromArray2D<NativeT>(values));
}

template <typename NativeT>
XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
                                      const Array3D<NativeT>& values,
                                      const Layout& layout) {
  return ConstantLiteral(
      builder,
      LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
}
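
// Example (illustrative sketch only): building constants of various ranks on
// a builder `b`:
//
//   XlaOp pi  = ConstantR0<float>(&b, 3.14159f);          // f32[]
//   XlaOp vec = ConstantR1<float>(&b, {1.0f, 2.0f});      // f32[2]
//   XlaOp mat = ConstantR2<float>(&b, {{1, 2}, {3, 4}});  // f32[2, 2]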
template <typename NativeT>
XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
                            const Array3D<NativeT>& values) {
  return ConstantFromArray(builder, values);
}

template <typename NativeT>
XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
                                      const Array4D<NativeT>& values,
                                      const Layout& layout) {
  return ConstantFromArrayWithLayout(builder, values, layout);
}

template <typename NativeT>
XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
                            const Array4D<NativeT>& values) {
  return ConstantFromArray(builder, values);
}

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_