1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_ 17 #define TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_ 18 19 #include <functional> 20 #include <initializer_list> 21 #include <memory> 22 #include <string> 23 #include <utility> 24 25 #include "tensorflow/compiler/xla/array.h" 26 #include "tensorflow/compiler/xla/array2d.h" 27 #include "tensorflow/compiler/xla/array3d.h" 28 #include "tensorflow/compiler/xla/array4d.h" 29 #include "tensorflow/compiler/xla/client/client.h" 30 #include "tensorflow/compiler/xla/client/computation.h" 31 #include "tensorflow/compiler/xla/client/global_data.h" 32 #include "tensorflow/compiler/xla/client/padding.h" 33 #include "tensorflow/compiler/xla/literal_util.h" 34 #include "tensorflow/compiler/xla/statusor.h" 35 #include "tensorflow/compiler/xla/types.h" 36 #include "tensorflow/compiler/xla/xla_data.pb.h" 37 #include "tensorflow/core/lib/core/bitmap.h" 38 #include "tensorflow/core/lib/core/stringpiece.h" 39 #include "tensorflow/core/lib/gtl/array_slice.h" 40 #include "tensorflow/core/platform/macros.h" 41 #include "tensorflow/core/platform/stacktrace.h" 42 #include "tensorflow/core/platform/types.h" 43 44 namespace xla { 45 46 // Wraps an XLA client with a convenient interface for building up 47 // computations. 
Any errors encountered in building up the computation are 48 // deferred from being handled until Build() is called. 49 // 50 // Thread-compatible. 51 class ComputationBuilder { 52 public: 53 // client: client in which to build the computation. 54 // computation_name: name to use for the built computation. 55 ComputationBuilder(Client* client, const string& computation_name); 56 57 ~ComputationBuilder(); 58 59 // Returns the client the builder was initialized with. 60 Client* client() const { return client_; } 61 62 // Returns the computation name. 63 const string& name() const { return name_; } 64 65 // Sets OpMetadata that will be added to all instructions until cleared. 66 // 67 // OpMetadata is often applied to a series of XLA HLO instructions. As a 68 // result, OpMetadata is set on the Computation Builder. All subsequent 69 // instructions generated via this Computation Builder will have the same 70 // OpMetadata attached until a call to ClearOpMetadata. 71 void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; } 72 73 // Clears the HloMetadata state. 74 void ClearOpMetadata() { metadata_.Clear(); } 75 76 // Sets an OpSharding that will be attached to all instructions until cleared. 77 void SetSharding(const OpSharding& sharding) { sharding_ = sharding; } 78 79 // Clears the sharding. Ops will be sharded according to the default placement 80 // policy. 81 void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; } 82 83 // Returns the OpSharding that will be attached to all instructions. 84 const tensorflow::gtl::optional<OpSharding>& sharding() const { 85 return sharding_; 86 } 87 88 // Sets the builder to a mode where it will die immediately when an error is 89 // encountered, rather than producing it in a deferred fashion when Build() is 90 // called (which is the default). 
91 void set_die_immediately_on_error(bool enabled) { 92 die_immediately_on_error_ = enabled; 93 } 94 95 // Enqueues a "retrieve parameter value" instruction for a parameter that was 96 // passed to the computation. 97 ComputationDataHandle Parameter(int64 parameter_number, const Shape& shape, 98 const string& name); 99 100 // Retrieves the (inferred) shape of the operand in the computation. 101 StatusOr<std::unique_ptr<Shape>> GetShape( 102 const ComputationDataHandle& operand); 103 104 // Retrieves the (inferred) result for the current computation's shape. 105 StatusOr<ProgramShape> GetProgramShape(); 106 107 // Checks that the operand has the given expected shape. Returns the operand 108 // if yes, fails with a CHECK error if no. 109 ComputationDataHandle CheckShape(const ComputationDataHandle& operand, 110 const Shape& expected_shape); 111 112 // Checks that the lhs and rhs results have the same shape. 113 void CheckSameShape(const ComputationDataHandle& lhs, 114 const ComputationDataHandle& rhs); 115 116 // Enqueues a constant with the value of the given literal onto the 117 // computation. 118 ComputationDataHandle ConstantLiteral(const Literal& literal); 119 120 // Enqueues a constant onto the computation. Methods are templated on the 121 // native host type (NativeT) which corresponds to a specific XLA 122 // PrimitiveType as given in the following table: 123 // 124 // Native Type PrimitiveType 125 // ----------------------------- 126 // bool PRED 127 // int32 S32 128 // int64 S64 129 // uint32 U32 130 // uint64 U64 131 // float F32 132 // double F64 133 // 134 // Note: not all primitive types defined in xla_data.proto have a 135 // corresponding native type yet. 
136 template <typename NativeT> 137 ComputationDataHandle ConstantR0(NativeT value); 138 template <typename NativeT> 139 ComputationDataHandle ConstantR1(tensorflow::gtl::ArraySlice<NativeT> values); 140 ComputationDataHandle ConstantR1(const tensorflow::core::Bitmap& values); 141 template <typename NativeT> 142 ComputationDataHandle ConstantR2( 143 std::initializer_list<std::initializer_list<NativeT>> values); 144 template <typename NativeT> 145 ComputationDataHandle ConstantFromArrayWithLayout( 146 const Array<NativeT>& values, const Layout& layout); 147 template <typename NativeT> 148 ComputationDataHandle ConstantFromArray(const Array<NativeT>& values); 149 template <typename NativeT> 150 ComputationDataHandle ConstantR2FromArray2DWithLayout( 151 const Array2D<NativeT>& values, const Layout& layout); 152 template <typename NativeT> 153 ComputationDataHandle ConstantR2FromArray2D(const Array2D<NativeT>& values); 154 template <typename NativeT> 155 ComputationDataHandle ConstantR3FromArray3DWithLayout( 156 const Array3D<NativeT>& values, const Layout& layout); 157 template <typename NativeT> 158 ComputationDataHandle ConstantR3FromArray3D(const Array3D<NativeT>& values); 159 template <typename NativeT> 160 ComputationDataHandle ConstantR4FromArray4DWithLayout( 161 const Array4D<NativeT>& values, const Layout& layout); 162 template <typename NativeT> 163 ComputationDataHandle ConstantR4FromArray4D(const Array4D<NativeT>& values); 164 165 // Enqueues a rank one constant (vector) onto the computation. The vector has 166 // size 'length' and every element has the value 'value'. 167 template <typename NativeT> 168 ComputationDataHandle ConstantR1(int64 length, NativeT value); 169 170 // Adds dimensions to an array by duplicating the data in the array. 171 // 172 // The new dimensions are inserted on the left, i.e. 
if 173 // broadcast_sizes has values {a0, ..., aN} and the operand shape 174 // has dimensions {b0, ..., bM} then the shape of the output has 175 // dimensions {a0, ..., aN, b0, ..., bM}. 176 // 177 // The new dimensions index into copies of the operand, i.e. 178 // 179 // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] 180 ComputationDataHandle Broadcast( 181 const ComputationDataHandle& operand, 182 tensorflow::gtl::ArraySlice<int64> broadcast_sizes); 183 184 // Enqueues a pad operation onto the computation that pads the given value on 185 // the edges as well as between the elements of the input. padding_config 186 // specifies the padding amount for each dimension. 187 ComputationDataHandle Pad(const ComputationDataHandle& operand, 188 const ComputationDataHandle& padding_value, 189 const PaddingConfig& padding_config); 190 191 // Enqueues an operation onto the computation that flattens the operand based 192 // on the dimension order (major/slowest-varying to minor/fastest-varying) 193 // given, followed by reshaping it into the shape with the given dimension 194 // sizes (also major to minor). Conceptually, this is a limited form of 195 // "shape casting". 196 ComputationDataHandle Reshape(const ComputationDataHandle& operand, 197 tensorflow::gtl::ArraySlice<int64> dimensions, 198 tensorflow::gtl::ArraySlice<int64> new_sizes); 199 200 // Enqueues an operation onto the computation that collapses the operand, from 201 // minor to major order, then reshapes it into the shape with the given 202 // dimension sizes, also from major to minor. Conceptually, this is a limited 203 // form of "shape casting". 204 ComputationDataHandle Reshape(const ComputationDataHandle& operand, 205 tensorflow::gtl::ArraySlice<int64> new_sizes); 206 207 // Wrapper for Reshape. 208 // Enqueues an operation to collapse the provided dimensions; e.g. an 209 // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to 210 // {x=1024, y=32} by collapsing dims {0, 1, 2}. 
Collapsing dimensions must 211 // be a consecutive, in-order subsequence of the operand dimensions. 212 // 213 // Note that collapsing a single dimension does nothing: 214 // 215 // {256} collapsing {0} => {256} 216 // {1} collapsing {0} => {1} 217 // 218 // Collapsing multiple dimensions produces a single result dimension: 219 // 220 // {256, 2} collapsing {0,1} => {512} 221 // {256, 2, 3} collapsing {0,1} => {512, 3} 222 // 223 // This could potentially cause data to be moved -- it provides a more 224 // structured form of reshaping than an arbitrary Reshape operation. 225 ComputationDataHandle Collapse(const ComputationDataHandle& operand, 226 tensorflow::gtl::ArraySlice<int64> dimensions); 227 228 // Enqueues a slice operation onto the computation that slices the operand 229 // from the start indices to the limit indices; e.g. 230 // 231 // x 232 // [ 0 1 2 3 ] 233 // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ] 234 // [ 8 9 a b ] 235 // 236 // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D 237 // range notation. 238 // The strides parameter determines the stride over the slice 239 ComputationDataHandle Slice(const ComputationDataHandle& operand, 240 tensorflow::gtl::ArraySlice<int64> start_indices, 241 tensorflow::gtl::ArraySlice<int64> limit_indices, 242 tensorflow::gtl::ArraySlice<int64> strides); 243 244 // Enqueues a slice operation in a given dimension, taking all other 245 // dimensions as they are; e.g. if dimno is 1 from start_index 2 to 246 // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand 247 // for: 248 // 249 // array[:, 2:4:1, :] 250 ComputationDataHandle SliceInDim(const ComputationDataHandle& operand, 251 int64 start_index, int64 limit_index, 252 int64 stride, int64 dimno); 253 254 // Enqueues a slice operation onto the computation that slices the 'operand' 255 // from dynamic start indices which are passed in 'start_indices'. 
256 // The size of the slice in each dimension is passed in 'slice_sizes', 257 // which specify the end point of exclusive slice intervals in each 258 // dimension [start, start + size). 259 // The shape of 'start_indices' must be rank == 1, with dimension size 260 // equal to the rank of the 'operand'. 261 // Slice index calculations are computed modulo input dimension sizes to 262 // prevent dynamic start indices from generating out-of-bound array accesses. 263 ComputationDataHandle DynamicSlice( 264 const ComputationDataHandle& operand, 265 const ComputationDataHandle& start_indices, 266 tensorflow::gtl::ArraySlice<int64> slice_sizes); 267 268 // Enqueues a dynamic update slice operation onto the computation, which 269 // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. 270 // The shape of 'update' determines the shape of the slice of 'operand' 271 // which is updated. 272 // The indices specified in 'start_indices' specify the offset of the slice 273 // of 'operand' which is updated. 274 // 275 // update = {10, 11} // calculated at runtime. 276 // [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ] 277 // [4 5 6] => DynamicUpdateSlice(data, update, start) => [4 10 11] 278 // [7 8 9] [7 8 9 ] 279 // 280 // The shape of 'start_indices' must be rank == 1, with dimension size 281 // equal to the rank of the 'operand'. 282 // Slice index calculations are computed modulo update dimension sizes to 283 // prevent dynamic start indices from generating out-of-bound array accesses. 284 ComputationDataHandle DynamicUpdateSlice( 285 const ComputationDataHandle& operand, const ComputationDataHandle& update, 286 const ComputationDataHandle& start_indices); 287 288 // Enqueues a concatenate instruction onto the computation. 'operands' must 289 // have >= 1 entry. 
290 ComputationDataHandle ConcatInDim( 291 tensorflow::gtl::ArraySlice<ComputationDataHandle> operands, 292 int64 dimension); 293 294 // Enqueue a tracing operation onto the computation; the computation will emit 295 // a logging message with the operand. 296 void Trace(const string& tag, const ComputationDataHandle& operand); 297 298 // Enqueues a conditional-move-like select operation onto the computation; 299 // predicated on pred, selects between on_true and on_false. 300 ComputationDataHandle Select(const ComputationDataHandle& pred, 301 const ComputationDataHandle& on_true, 302 const ComputationDataHandle& on_false); 303 304 // Enqueues a tuple-creation instruction onto the computation. 305 ComputationDataHandle Tuple( 306 tensorflow::gtl::ArraySlice<ComputationDataHandle> elements); 307 308 // Enqueues a tuple-element-get instruction onto the computation. 309 ComputationDataHandle GetTupleElement(const ComputationDataHandle& tuple_data, 310 int64 index); 311 312 // Enqueues an equal-to comparison instruction onto the computation. 313 ComputationDataHandle Eq( 314 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 315 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); 316 317 // Enqueues a not-equal comparison instruction onto the computation. 318 ComputationDataHandle Ne( 319 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 320 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); 321 322 // Enqueues a greater-or-equal comparison instruction onto the computation. 323 ComputationDataHandle Ge( 324 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 325 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); 326 327 // Enqueues a greater-than comparison instruction onto the computation. 
328 ComputationDataHandle Gt( 329 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 330 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); 331 332 // Enqueues a less-than comparison instruction onto the computation. 333 ComputationDataHandle Lt( 334 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 335 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); 336 337 // Enqueues a less-or-equal comparison instruction onto the computation. 338 ComputationDataHandle Le( 339 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 340 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); 341 342 // Enqueues a dot instruction onto the computation. 343 ComputationDataHandle Dot(const ComputationDataHandle& lhs, 344 const ComputationDataHandle& rhs); 345 346 // Enqueues a general dot instruction onto the computation. 347 ComputationDataHandle DotGeneral( 348 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 349 const DotDimensionNumbers& dimension_numbers); 350 351 // Default dimension numbers used for a 2D convolution. 352 static constexpr int64 kConvBatchDimension = 0; 353 static constexpr int64 kConvFeatureDimension = 1; 354 static constexpr int64 kConvFirstSpatialDimension = 2; 355 static constexpr int64 kConvSecondSpatialDimension = 3; 356 static constexpr int64 kConvKernelOutputDimension = 0; 357 static constexpr int64 kConvKernelInputDimension = 1; 358 static constexpr int64 kConvKernelFirstSpatialDimension = 2; 359 static constexpr int64 kConvKernelSecondSpatialDimension = 3; 360 361 // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for 362 // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for 363 // the kernel operand 364 // {output_feature, input_feature, height, width} = {0, 1, 2, 3}. 
365 static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers( 366 int num_spatial_dims = 2); 367 368 // Creates a ConvolutionDimensionNumbers with the given arguments. Returns an 369 // error if either the input or the weight dimension numbers have conflicts. 370 static StatusOr<ConvolutionDimensionNumbers> CreateConvDimensionNumbers( 371 int64 input_batch, int64 input_feature, int64 input_first_spatial, 372 int64 input_second_spatial, int64 output_batch, int64 output_feature, 373 int64 output_first_spatial, int64 output_second_spatial, 374 int64 kernel_output_feature, int64 kernel_input_feature, 375 int64 kernel_first_spatial, int64 kernel_second_spatial); 376 377 // Enqueues a convolution instruction onto the computation, which uses the 378 // default convolution dimension numbers. 379 ComputationDataHandle Conv(const ComputationDataHandle& lhs, 380 const ComputationDataHandle& rhs, 381 tensorflow::gtl::ArraySlice<int64> window_strides, 382 Padding padding); 383 384 // Enqueues a convolution instruction onto the computation, with the caller 385 // provided padding configuration in the format returned by MakePadding(). 386 ComputationDataHandle ConvWithGeneralPadding( 387 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 388 tensorflow::gtl::ArraySlice<int64> window_strides, 389 tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding); 390 391 // Enqueues a convolution instruction onto the computation, with the caller 392 // provided dimension numbers configuration. 393 ComputationDataHandle ConvWithGeneralDimensions( 394 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 395 tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding, 396 const ConvolutionDimensionNumbers& dimension_numbers); 397 398 // Enqueues a convolution instruction onto the computation, with the caller 399 // provided padding configuration as well as the dimension numbers. 
400 ComputationDataHandle ConvGeneral( 401 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 402 tensorflow::gtl::ArraySlice<int64> window_strides, 403 tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, 404 const ConvolutionDimensionNumbers& dimension_numbers); 405 406 // Enqueues a convolution instruction onto the computation, with the caller 407 // provided padding configuration, dilation factors and dimension numbers. 408 ComputationDataHandle ConvGeneralDilated( 409 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 410 tensorflow::gtl::ArraySlice<int64> window_strides, 411 tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, 412 tensorflow::gtl::ArraySlice<int64> lhs_dilation, 413 tensorflow::gtl::ArraySlice<int64> rhs_dilation, 414 const ConvolutionDimensionNumbers& dimension_numbers); 415 416 // Enqueues an FFT instruction onto the computation, of the given type and 417 // with the given FFT length. 418 ComputationDataHandle Fft(const ComputationDataHandle& operand, 419 FftType fft_type, 420 tensorflow::gtl::ArraySlice<int64> fft_length); 421 422 // Enqueues an infeed instruction onto the computation, which writes data of 423 // the given shape to the infeed buffer of the device. 424 ComputationDataHandle Infeed(const Shape& shape, const string& config = ""); 425 426 // Enqueues an outfeed instruction onto the computation. This instruction 427 // generates outgoing data transfers for the given data. 428 // 429 // shape_with_layout communicates the laid out shape that we want to outfeed 430 // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error 431 // will occur. 432 void Outfeed(const ComputationDataHandle& operand, 433 const Shape& shape_with_layout, const string& outfeed_config); 434 435 // Enqueues a call instruction onto the computation. 
436 ComputationDataHandle Call( 437 const Computation& computation, 438 tensorflow::gtl::ArraySlice<ComputationDataHandle> operands); 439 440 // Enqueues a custom call instruction onto the computation. 441 // During code generation, a call instruction is emitted which targets a 442 // symbol with the name |call_target_name|. The |operands| are passed to the 443 // call instruction. |shape| is the resultant shape. 444 ComputationDataHandle CustomCall( 445 const string& call_target_name, 446 tensorflow::gtl::ArraySlice<ComputationDataHandle> operands, 447 const Shape& shape); 448 449 // Enqueues a pseudo-op to represent host-side computation data-dependencies. 450 // During code generation, host send and receive operations will be generated 451 // to transfer |operands| to the host and a single result of |shape| back to 452 // the device. Host send/recv operations are emitted using |channel_name|. 453 // Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO 454 // instruction scheduling. 455 ComputationDataHandle HostCompute( 456 tensorflow::gtl::ArraySlice<ComputationDataHandle> operands, 457 const string& channel_name, int64 cost_estimate_ns, const Shape& shape); 458 459 // The following methods enqueue element-wise binary arithmetic operations 460 // onto the computation. The shapes of the operands have to match unless one 461 // of the operands is a scalar, or an explicit broadcast dimension is given 462 // (see g3doc for more details). 463 464 // Enqueues a complex compose instruction onto the computation. 465 ComputationDataHandle Complex( 466 const ComputationDataHandle& real, const ComputationDataHandle& imag, 467 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); 468 469 // Enqueues a complex conjugate instruction onto the computation. 470 ComputationDataHandle Conj(const ComputationDataHandle& operand); 471 472 // Enqueues an add instruction onto the computation. 
473 ComputationDataHandle Add( 474 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 475 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); 476 477 // Enqueues a subtract instruction onto the computation. 478 ComputationDataHandle Sub( 479 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 480 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); 481 482 // Enqueues a multiply instruction onto the computation. 483 ComputationDataHandle Mul( 484 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 485 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); 486 487 // Enqueues a divide instruction onto the computation. 488 ComputationDataHandle Div( 489 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 490 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); 491 492 // Enqueues a remainder instruction onto the computation. 493 ComputationDataHandle Rem( 494 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 495 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); 496 497 // Enqueues a max instruction onto the computation. 498 ComputationDataHandle Max( 499 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 500 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); 501 502 // Enqueues a min instruction onto the computation. 
503 ComputationDataHandle Min( 504 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 505 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); 506 507 // Element-wise logical operators 508 ComputationDataHandle And( 509 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 510 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); 511 512 ComputationDataHandle Or( 513 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 514 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); 515 516 ComputationDataHandle Not(const ComputationDataHandle& operand); 517 518 ComputationDataHandle ShiftLeft( 519 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 520 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); 521 ComputationDataHandle ShiftRightArithmetic( 522 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 523 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); 524 ComputationDataHandle ShiftRightLogical( 525 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 526 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); 527 528 // Reduces an array among the provided dimensions, given "computation" as a 529 // reduction operator. 530 ComputationDataHandle Reduce( 531 const ComputationDataHandle& operand, 532 const ComputationDataHandle& init_value, const Computation& computation, 533 tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce); 534 535 // Convenience wrapper around the above that reduces all the dimensions in the 536 // operand shape. 537 ComputationDataHandle ReduceAll(const ComputationDataHandle& operand, 538 const ComputationDataHandle& init_value, 539 const Computation& computation); 540 541 // Enqueues a windowed reduce instruction onto the computation. 
542 ComputationDataHandle ReduceWindow( 543 const ComputationDataHandle& operand, 544 const ComputationDataHandle& init_value, const Computation& computation, 545 tensorflow::gtl::ArraySlice<int64> window_dimensions, 546 tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding); 547 548 // As ReduceWindow(), but the padding is given in the format 549 // returned by MakePadding(). 550 ComputationDataHandle ReduceWindowWithGeneralPadding( 551 const ComputationDataHandle& operand, 552 const ComputationDataHandle& init_value, const Computation& computation, 553 tensorflow::gtl::ArraySlice<int64> window_dimensions, 554 tensorflow::gtl::ArraySlice<int64> window_strides, 555 tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding); 556 557 // Returns the sum of the operand value across all replicas. All replicas 558 // supply one input to the sum and all replicas receive the resulting sum. 559 ComputationDataHandle CrossReplicaSum(const ComputationDataHandle& operand); 560 561 // Enqueues an operation that scatters the `source` array to the selected 562 // indices of each window. 563 ComputationDataHandle SelectAndScatter( 564 const ComputationDataHandle& operand, const Computation& select, 565 tensorflow::gtl::ArraySlice<int64> window_dimensions, 566 tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding, 567 const ComputationDataHandle& source, 568 const ComputationDataHandle& init_value, const Computation& scatter); 569 570 // As SelectAndScatter(), but the padding is given in the format 571 // returned by MakePadding(). 
572 ComputationDataHandle SelectAndScatterWithGeneralPadding( 573 const ComputationDataHandle& operand, const Computation& select, 574 tensorflow::gtl::ArraySlice<int64> window_dimensions, 575 tensorflow::gtl::ArraySlice<int64> window_strides, 576 tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, 577 const ComputationDataHandle& source, 578 const ComputationDataHandle& init_value, const Computation& scatter); 579 580 // Enqueues an abs instruction onto the computation. 581 ComputationDataHandle Abs(const ComputationDataHandle& operand); 582 583 // Enqueues an atan2 instruction onto the computation. 584 ComputationDataHandle Atan2( 585 const ComputationDataHandle& y, const ComputationDataHandle& x, 586 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); 587 588 // Enqueues an exp instruction onto the computation. 589 ComputationDataHandle Exp(const ComputationDataHandle& operand); 590 591 // Enqueues a floor instruction onto the computation. 592 ComputationDataHandle Floor(const ComputationDataHandle& operand); 593 594 // Enqueues a ceil instruction onto the computation. 595 ComputationDataHandle Ceil(const ComputationDataHandle& operand); 596 597 // Enqueues a round instruction onto the computation, rounding to the nearest 598 // integer, with half-way cases rounding away from zero. 599 ComputationDataHandle Round(const ComputationDataHandle& operand); 600 601 // Enqueues a log instruction (natural logarithm) onto the computation. 602 ComputationDataHandle Log(const ComputationDataHandle& operand); 603 604 // Enqueues a sign instruction onto the computation. 605 ComputationDataHandle Sign(const ComputationDataHandle& operand); 606 607 // Enqueues a cosine instruction onto the computation. 608 ComputationDataHandle Cos(const ComputationDataHandle& operand); 609 610 // Enqueues a sine instruction onto the computation. 611 ComputationDataHandle Sin(const ComputationDataHandle& operand); 612 613 // Enqueues a tanh instruction onto the computation. 
614 ComputationDataHandle Tanh(const ComputationDataHandle& operand); 615 616 // Enqueues a real-part instruction onto the computation. 617 ComputationDataHandle Real(const ComputationDataHandle& operand); 618 619 // Enqueues an imaginary-part instruction onto the computation. 620 ComputationDataHandle Imag(const ComputationDataHandle& operand); 621 622 // Enqueues a float32 sqrt instruction onto the computation. 623 // (float32 is specified as there is an implicit float32 0.5f constant 624 // exponent). 625 ComputationDataHandle SqrtF32(const ComputationDataHandle& operand); 626 627 // Enqueues a float32 square instruction onto the computation. 628 // (float32 is specified as there is an implicit float32 2.0f constant 629 // exponent). 630 ComputationDataHandle SquareF32(const ComputationDataHandle& operand); 631 632 // Enqueues a lhs^rhs computation onto the computation. 633 ComputationDataHandle Pow( 634 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 635 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); 636 637 // Enqueues an operator that tests if the operand's values are finite, i.e., 638 // not Inf or NaN. Defined only for floating-point types. Returns an array of 639 // booleans with the same shape where entries are true iff the corresponding 640 // entry was finite. 641 ComputationDataHandle IsFinite(const ComputationDataHandle& operand); 642 643 // Enqueues a convert instruction onto the computation that changes the 644 // element type of the operand array to primitive_type. 645 ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand, 646 PrimitiveType new_element_type); 647 648 // Enqueues a no-op instruction onto the computation that changes 649 // the element type of the operand array to primitive_type. The 650 // bit-widths of the source and destination element types must be 651 // identical. 
652 ComputationDataHandle BitcastConvertType(const ComputationDataHandle& operand, 653 PrimitiveType new_element_type); 654 655 // Enqueues a float32 reciprocal instruction onto the computation. 656 // (float32 is specified as there is an implicit float32 -1.0f constant 657 // exponent). 658 // 659 // TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the 660 // shape of the operand. 661 ComputationDataHandle ReciprocalF32(const ComputationDataHandle& operand); 662 663 // Enqueues a negate instruction onto the computation. 664 ComputationDataHandle Neg(const ComputationDataHandle& operand); 665 666 // Enqueues a transpose instruction onto the computation. 667 ComputationDataHandle Transpose( 668 const ComputationDataHandle& operand, 669 tensorflow::gtl::ArraySlice<int64> permutation); 670 671 // Enqueues a reverse instruction onto the computation. The order of the 672 // elements in the given dimensions is reversed (i.e., the element at index i 673 // is moved to index dimension_size - 1 - i). 674 ComputationDataHandle Rev(const ComputationDataHandle& operand, 675 tensorflow::gtl::ArraySlice<int64> dimensions); 676 677 // Enqueues a sort (as increasing order) instruction onto the computation. 678 ComputationDataHandle Sort(const ComputationDataHandle& operand); 679 680 // Enqueues a clamp instruction onto the computation. 681 ComputationDataHandle Clamp(const ComputationDataHandle& min, 682 const ComputationDataHandle& operand, 683 const ComputationDataHandle& max); 684 685 // Enqueues a map instruction onto the computation. 686 ComputationDataHandle Map( 687 tensorflow::gtl::ArraySlice<ComputationDataHandle> operands, 688 const Computation& computation, 689 tensorflow::gtl::ArraySlice<int64> dimensions, 690 tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands = {}); 691 692 // Enqueues a N(mu, sigma) random number generation instruction onto the 693 // computation. 
694 ComputationDataHandle RngNormal(const ComputationDataHandle& mu, 695 const ComputationDataHandle& sigma, 696 const Shape& shape); 697 698 // Enqueues a U(a, b) random number generation instruction onto the 699 // computation. Returns values in the semi-open interval [a, b). 700 ComputationDataHandle RngUniform(const ComputationDataHandle& a, 701 const ComputationDataHandle& b, 702 const Shape& shape); 703 704 // Enqueues a while node onto the computation. 705 ComputationDataHandle While(const Computation& condition, 706 const Computation& body, 707 const ComputationDataHandle& init); 708 709 // Enqueues a conditional node onto the computation. 710 ComputationDataHandle Conditional(const ComputationDataHandle& predicate, 711 const ComputationDataHandle& true_operand, 712 const Computation& true_computation, 713 const ComputationDataHandle& false_operand, 714 const Computation& false_computation); 715 716 // Enqueues a ReducePrecision node onto the computation. 717 ComputationDataHandle ReducePrecision(const ComputationDataHandle& operand, 718 const int exponent_bits, 719 const int mantissa_bits); 720 721 // Enqueues a Gather node onto the computation. 722 ComputationDataHandle Gather( 723 const ComputationDataHandle& input, 724 const ComputationDataHandle& gather_indices, 725 const GatherDimensionNumbers& dimension_numbers, 726 tensorflow::gtl::ArraySlice<int64> window_bounds); 727 728 // Enqueues a Send node onto the computation, to send the given operand to 729 // a Recv instruction that shares the same channel handle. 730 void Send(const ComputationDataHandle& operand, const ChannelHandle& handle); 731 732 // Enqueues a Recv node onto the computation. The data comes from a Send 733 // instruction that shares the same channel handle and its shape must 734 // be the same as the given shape. 735 ComputationDataHandle Recv(const Shape& shape, const ChannelHandle& handle); 736 737 // Returns true if 'operand' is a compile-time constant. 
A compile-time 738 // constant does not depend on parameters with index greater than or equal to 739 // `num_parameters`, or on stateful operators such as `RngNormal` or `Infeed`. 740 // Unlike `ComputeConstant`, `IsConstant` tests whether a computation is a 741 // compile-time constant without evaluating the computation. 742 StatusOr<bool> IsConstant(const ComputationDataHandle& operand, 743 int64 num_parameters = 0); 744 745 // Normalizes operand across spatial and batch dimensions for each feature. 746 // 747 // Returns a tuple (normalized, batch_mean, batch_var) where `normalized` 748 // is the normalized result and batch_mean and batch_var are the mean and 749 // variance, respectively, across batch for the operand. 750 ComputationDataHandle BatchNormTraining(const ComputationDataHandle& operand, 751 const ComputationDataHandle& scale, 752 const ComputationDataHandle& offset, 753 float epsilon, int64 feature_index); 754 755 // Normalizes operand across spatial and batch dimensions for each feature. 756 // 757 // `BatchNormInference` is equivalent to calling `BatchNormTraining` without 758 // computing `mean` and `variance` for each batch inside the operation. It 759 // uses the input `mean` and `variance` instead as estimated values. The 760 // purpose of this op is to reduce latency in inference, hence the name 761 // `BatchNormInference`. 762 // 763 // The output has the same shape as `operand`, and contains the normalized 764 // values for each batch. 765 ComputationDataHandle BatchNormInference( 766 const ComputationDataHandle& operand, const ComputationDataHandle& scale, 767 const ComputationDataHandle& offset, const ComputationDataHandle& mean, 768 const ComputationDataHandle& variance, float epsilon, 769 int64 feature_index); 770 771 // Calculates the gradients of a batch norm op. 772 // 773 // The inputs `batch_mean` and `batch_var` represent the mean and variance 774 // across the batch. 
775 // 776 // Returns a tuple of three elements: 777 // - grad_operand: Gradient with respect to input `operand` 778 // - grad_offset: Gradient with respect to input `offset` 779 // - grad_scale: Gradient with respect to input `scale` 780 ComputationDataHandle BatchNormGrad(const ComputationDataHandle& operand, 781 const ComputationDataHandle& scale, 782 const ComputationDataHandle& batch_mean, 783 const ComputationDataHandle& batch_var, 784 const ComputationDataHandle& grad_output, 785 float epsilon, int64 feature_index); 786 787 // Computes the value of a constant indicated by a 788 // ComputationDataHandle using a non-optimized interpreter on the host. 789 // 790 // The operand must be from the computation currently being built - 791 // i.e., returned from this builder with no intervening call to 792 // Build(). This happens to currently work regardless of that, but 793 // that may stop working at any time. 794 // 795 // The operand must represent a constant value, which in this case 796 // means that it must not statically depend on any parameter of the 797 // computation that is being built other then the ones specified on the 798 // parameter list. The parameters in the list will be indexed by their 799 // parameter id property so the number of parameters specified should be at 800 // least as many as the largest used parameter index. 801 // 802 // `IsConstant` can be used to test whether a computation is a compile-time 803 // constant without evaluation it. `ComputeConstant` only succeeds for 804 // computations where `IsConstant` returns true. 805 // 806 // This functionality can be useful when translating a computation 807 // into XLA where something that looked dynamic is required by 808 // XLA to be specified as a constant. E.g. 
the source 809 // computation (outside of XLA) may include a dynamic 810 // computation of the shape of something and ComputeConstant lets 811 // you determine what the value of that computation is in the case 812 // where the value can be determined at compile time. 813 // 814 // If output_layout is non-null, then the output of the computation 815 // will be stored using that layout. 816 StatusOr<std::unique_ptr<Literal>> ComputeConstant( 817 const ComputationDataHandle& operand, 818 const Layout* output_layout = nullptr, 819 tensorflow::gtl::ArraySlice<Literal> parameters = {}); 820 821 // Returns a new ComputationBuilder whose resultant Computation is used only 822 // by this ComputationBuilder. The sub-ComputationBuilder has the same 823 // die_immediately_on_error behavior as the parent. 824 std::unique_ptr<ComputationBuilder> CreateSubBuilder( 825 const string& computation_name); 826 827 // Modifies the computation being built so that executions of it 828 // will return the value associated with operand, rather than the 829 // last expression enqueued on the ComputationBuilder. Any subsequent 830 // operations added to the ComputationBuilder will not have any effect unless 831 // SetReturnValue is called again. 832 Status SetReturnValue(const ComputationDataHandle& operand); 833 834 // Builds the computation with the requested operations, or returns a non-ok 835 // status. 836 StatusOr<Computation> Build(); 837 838 // Builds the computation with the requested operations, or notes an error in 839 // the parent ComputationBuilder and returns an empty computation if building 840 // failed. This function is intended to be used where the returned 841 // Computation is only used by the parent ComputationBuilder and hence further 842 // operation on the returned Computation will simply be error'ed out if an 843 // error occurred while building this computation. 
If the built computation is 844 // to be used by a ComputationBuilder other than the parent ComputationBuilder 845 // then Build() should be used instead. 846 Computation BuildAndNoteError(); 847 848 // Returns the first error that was encountered while building the 849 // computation. When an error is encountered, by default we return a vacuous 850 // ComputationDataHandle and inform the user of the error that occurred while 851 // building the computation when they make a final call to Build(). 852 // 853 // See also set_die_immediately_on_error(). 854 Status first_error() const { return first_error_; } 855 856 private: 857 // Limited checking of convolution parameters. Returns false on 858 // error. 859 bool VerifyConvolution(const Shape& lhs_shape, const Shape& rhs_shape, 860 const ConvolutionDimensionNumbers& dimension_numbers); 861 862 // The parent ComputationBuilder of a sub-ComputationBuilder. The 863 // parent_builder_ will be the nullptr if not a sub-ComputationBuilder. 864 ComputationBuilder* parent_builder_{nullptr}; 865 866 // Helper function for creating a Window proto from user-supplied 867 // data. Returns true if the user-supplied data was valid. 868 bool MakeWindow(tensorflow::gtl::ArraySlice<int64> window_dimensions, 869 tensorflow::gtl::ArraySlice<int64> window_strides, 870 tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, 871 tensorflow::gtl::ArraySlice<int64> lhs_dilation, 872 tensorflow::gtl::ArraySlice<int64> rhs_dilation, 873 Window* window); 874 875 // Internal helper method that does the building for an arbitrary unary op. 876 ComputationDataHandle UnaryOp(UnaryOperation binop, 877 const ComputationDataHandle& operand); 878 879 // Internal helper method that does the building for an arbitrary binary op. 880 // broadcast_dimensions specifies which dimensions to use for broadcasting 881 // when the operation is between tensors of different ranks. 
882 ComputationDataHandle BinaryOp( 883 BinaryOperation binop, const ComputationDataHandle& lhs, 884 const ComputationDataHandle& rhs, 885 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); 886 887 // Internal helper method that does the building for an arbitrary ternary op. 888 ComputationDataHandle TernaryOp(TernaryOperation triop, 889 const ComputationDataHandle& lhs, 890 const ComputationDataHandle& rhs, 891 const ComputationDataHandle& ehs); 892 893 // Internal helper method that does the building for a random number generator 894 // of a given distribution with an explicitly specified shape. 895 ComputationDataHandle RngOp( 896 RandomDistribution distribution, 897 tensorflow::gtl::ArraySlice<ComputationDataHandle> parameters, 898 const Shape& shape); 899 900 // Populates computation_ with a valid object or returns a failing status. 901 // This is used before any given operation is enqueued. 902 Status PrepareComputation(); 903 904 // Notes that the error occurred by: 905 // * storing it internally and capturing a backtrace if it's the first error 906 // (this deferred value will be produced on the call to Build()) 907 // * dying if die_immediately_on_error_ is true 908 void NoteError(const Status& error); 909 910 // Helper function that runs the given op_request, filling in op_response. 911 // Before the op is run, PrepareComputation is called, and common fields in 912 // the op_request are filled in. 913 Status RunOp(OpRequest* op_request, OpResponse* op_response); 914 915 // Helper function that calls RunOp and calls NoteError on failures. 916 void RunOpAndNoteError(OpRequest* op_request); 917 918 // Helper function that calls RunOp and either returns the output computation 919 // data handle (on success) or a vacuous computation data handle (on failure). 920 ComputationDataHandle RunOpAndParseResponse(OpRequest* op_request); 921 922 // Helper function that implements GetShape without noting errors. 
This makes 923 // it easier to ensure the real GetShape will note errors on every error path. 924 StatusOr<std::unique_ptr<Shape>> GetShapeWithoutNoteError( 925 const ComputationDataHandle& operand); 926 927 string name_; // Name to use for the built computation. 928 929 // The first error encountered while building the computation. 930 // This is OK until the first error is encountered. 931 Status first_error_; 932 933 // The saved stack trace from the point at which the first error occurred. 934 tensorflow::SavedStackTrace first_error_backtrace_; 935 936 // The computation that operations are enqueued onto. 937 Computation computation_; 938 939 // The client that the computation is created in. Not owned. 940 Client* client_; 941 942 // Mode bit that indicates whether to die when a first error is encountered. 943 bool die_immediately_on_error_ = false; 944 945 // The metadata to attach to each op. This is structured as a "modal"-like 946 // operation, in order to simplify client code (and not sprinkle this metadata 947 // throughout the TensorFlow op kernel implementations). 948 OpMetadata metadata_; 949 950 // Sharding for this operator. This is structured as a "model"-like operation, 951 // in order to simplify client code, similar to metadata_. 
952 tensorflow::gtl::optional<OpSharding> sharding_; 953 954 TF_DISALLOW_COPY_AND_ASSIGN(ComputationBuilder); 955 }; 956 957 template <typename NativeT> 958 ComputationDataHandle ComputationBuilder::ConstantR0(NativeT value) { 959 return ConstantLiteral(*Literal::CreateR0<NativeT>(value)); 960 } 961 962 template <typename NativeT> 963 ComputationDataHandle ComputationBuilder::ConstantR1( 964 tensorflow::gtl::ArraySlice<NativeT> values) { 965 return ConstantLiteral(*Literal::CreateR1<NativeT>(values)); 966 } 967 968 template <typename NativeT> 969 ComputationDataHandle ComputationBuilder::ConstantR1(int64 length, 970 NativeT value) { 971 Literal literal(ShapeUtil::MakeShape( 972 primitive_util::NativeToPrimitiveType<NativeT>(), {length})); 973 literal.PopulateWithValue(value); 974 return ConstantLiteral(literal); 975 } 976 977 inline ComputationDataHandle ComputationBuilder::ConstantR1( 978 const tensorflow::core::Bitmap& values) { 979 return ConstantLiteral(*Literal::CreateR1(values)); 980 } 981 982 template <typename NativeT> 983 ComputationDataHandle ComputationBuilder::ConstantR2( 984 std::initializer_list<std::initializer_list<NativeT>> values) { 985 return ConstantLiteral(*Literal::CreateR2<NativeT>(values)); 986 } 987 988 template <typename NativeT> 989 ComputationDataHandle ComputationBuilder::ConstantFromArrayWithLayout( 990 const Array<NativeT>& values, const Layout& layout) { 991 return ConstantLiteral( 992 *Literal::CreateFromArrayWithLayout<NativeT>(values, layout)); 993 } 994 995 template <typename NativeT> 996 ComputationDataHandle ComputationBuilder::ConstantFromArray( 997 const Array<NativeT>& values) { 998 return ConstantLiteral(*Literal::CreateFromArray<NativeT>(values)); 999 } 1000 1001 template <typename NativeT> 1002 ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout( 1003 const Array2D<NativeT>& values, const Layout& layout) { 1004 return ConstantLiteral( 1005 *Literal::CreateFromArrayWithLayout<NativeT>(values, 
layout)); 1006 } 1007 1008 template <typename NativeT> 1009 ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D( 1010 const Array2D<NativeT>& values) { 1011 return ConstantLiteral(*Literal::CreateR2FromArray2D<NativeT>(values)); 1012 } 1013 1014 template <typename NativeT> 1015 ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout( 1016 const Array3D<NativeT>& values, const Layout& layout) { 1017 return ConstantLiteral( 1018 *Literal::CreateR3FromArray3DWithLayout<NativeT>(values, layout)); 1019 } 1020 1021 template <typename NativeT> 1022 ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D( 1023 const Array3D<NativeT>& values) { 1024 return ConstantFromArray(values); 1025 } 1026 1027 template <typename NativeT> 1028 ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout( 1029 const Array4D<NativeT>& values, const Layout& layout) { 1030 return ConstantFromArrayWithLayout(values, layout); 1031 } 1032 1033 template <typename NativeT> 1034 ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D( 1035 const Array4D<NativeT>& values) { 1036 return ConstantFromArray(values); 1037 } 1038 1039 // RAII-style object: sets the current sharding assignment in builder on 1040 // construction, and sets back to the previous assignment on destruction. 
1041 class ScopedShardingAssignment { 1042 public: 1043 ScopedShardingAssignment(xla::ComputationBuilder* builder, 1044 tensorflow::gtl::optional<OpSharding> sharding) 1045 : builder_(builder), prev_sharding_(builder->sharding()) { 1046 SetSharding(sharding); 1047 } 1048 1049 ~ScopedShardingAssignment() { SetSharding(prev_sharding_); } 1050 1051 private: 1052 void SetSharding(const tensorflow::gtl::optional<OpSharding>& sharding) { 1053 if (sharding.has_value()) { 1054 builder_->SetSharding(sharding.value()); 1055 } else { 1056 builder_->ClearSharding(); 1057 } 1058 } 1059 1060 xla::ComputationBuilder* const builder_; 1061 tensorflow::gtl::optional<OpSharding> prev_sharding_; 1062 1063 TF_DISALLOW_COPY_AND_ASSIGN(ScopedShardingAssignment); 1064 }; 1065 1066 } // namespace xla 1067 1068 #endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_ 1069