// File: tensorflow/compiler/xla/client/computation_builder.h
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_
     17 #define TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_
     18 
     19 #include <functional>
     20 #include <initializer_list>
     21 #include <memory>
     22 #include <string>
     23 #include <utility>
     24 
     25 #include "tensorflow/compiler/xla/array.h"
     26 #include "tensorflow/compiler/xla/array2d.h"
     27 #include "tensorflow/compiler/xla/array3d.h"
     28 #include "tensorflow/compiler/xla/array4d.h"
     29 #include "tensorflow/compiler/xla/client/client.h"
     30 #include "tensorflow/compiler/xla/client/computation.h"
     31 #include "tensorflow/compiler/xla/client/global_data.h"
     32 #include "tensorflow/compiler/xla/client/padding.h"
     33 #include "tensorflow/compiler/xla/literal_util.h"
     34 #include "tensorflow/compiler/xla/statusor.h"
     35 #include "tensorflow/compiler/xla/types.h"
     36 #include "tensorflow/compiler/xla/xla_data.pb.h"
     37 #include "tensorflow/core/lib/core/bitmap.h"
     38 #include "tensorflow/core/lib/core/stringpiece.h"
     39 #include "tensorflow/core/lib/gtl/array_slice.h"
     40 #include "tensorflow/core/platform/macros.h"
     41 #include "tensorflow/core/platform/stacktrace.h"
     42 #include "tensorflow/core/platform/types.h"
     43 
     44 namespace xla {
     45 
     46 // Wraps an XLA client with a convenient interface for building up
     47 // computations. Any errors encountered in building up the computation are
     48 // deferred from being handled until Build() is called.
     49 //
     50 // Thread-compatible.
     51 class ComputationBuilder {
     52  public:
     53   // client: client in which to build the computation.
     54   // computation_name: name to use for the built computation.
     55   ComputationBuilder(Client* client, const string& computation_name);
     56 
     57   ~ComputationBuilder();
     58 
     59   // Returns the client the builder was initialized with.
     60   Client* client() const { return client_; }
     61 
     62   // Returns the computation name.
     63   const string& name() const { return name_; }
     64 
     65   // Sets OpMetadata that will be added to all instructions until cleared.
     66   //
     67   // OpMetadata is often applied to a series of XLA HLO instructions. As a
     68   // result, OpMetadata is set on the Computation Builder. All subsequent
     69   // instructions generated via this Computation Builder will have the same
     70   // OpMetadata attached until a call to ClearOpMetadata.
     71   void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; }
     72 
     73   // Clears the HloMetadata state.
     74   void ClearOpMetadata() { metadata_.Clear(); }
     75 
     76   // Sets an OpSharding that will be attached to all instructions until cleared.
     77   void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }
     78 
     79   // Clears the sharding. Ops will be sharded according to the default placement
     80   // policy.
     81   void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; }
     82 
     83   // Returns the OpSharding that will be attached to all instructions.
     84   const tensorflow::gtl::optional<OpSharding>& sharding() const {
     85     return sharding_;
     86   }
     87 
     88   // Sets the builder to a mode where it will die immediately when an error is
     89   // encountered, rather than producing it in a deferred fashion when Build() is
     90   // called (which is the default).
     91   void set_die_immediately_on_error(bool enabled) {
     92     die_immediately_on_error_ = enabled;
     93   }
     94 
     95   // Enqueues a "retrieve parameter value" instruction for a parameter that was
     96   // passed to the computation.
     97   ComputationDataHandle Parameter(int64 parameter_number, const Shape& shape,
     98                                   const string& name);
     99 
    100   // Retrieves the (inferred) shape of the operand in the computation.
    101   StatusOr<std::unique_ptr<Shape>> GetShape(
    102       const ComputationDataHandle& operand);
    103 
    104   // Retrieves the (inferred) result for the current computation's shape.
    105   StatusOr<ProgramShape> GetProgramShape();
    106 
    107   // Checks that the operand has the given expected shape. Returns the operand
    108   // if yes, fails with a CHECK error if no.
    109   ComputationDataHandle CheckShape(const ComputationDataHandle& operand,
    110                                    const Shape& expected_shape);
    111 
    112   // Checks that the lhs and rhs results have the same shape.
    113   void CheckSameShape(const ComputationDataHandle& lhs,
    114                       const ComputationDataHandle& rhs);
    115 
    116   // Enqueues a constant with the value of the given literal onto the
    117   // computation.
    118   ComputationDataHandle ConstantLiteral(const Literal& literal);
    119 
    120   // Enqueues a constant onto the computation. Methods are templated on the
    121   // native host type (NativeT) which corresponds to a specific XLA
    122   // PrimitiveType as given in the following table:
    123   //
    124   //  Native Type   PrimitiveType
    125   // -----------------------------
    126   //   bool           PRED
    127   //   int32          S32
    128   //   int64          S64
    129   //   uint32         U32
    130   //   uint64         U64
    131   //   float          F32
    132   //   double         F64
    133   //
    134   // Note: not all primitive types defined in xla_data.proto have a
    135   // corresponding native type yet.
    136   template <typename NativeT>
    137   ComputationDataHandle ConstantR0(NativeT value);
    138   template <typename NativeT>
    139   ComputationDataHandle ConstantR1(tensorflow::gtl::ArraySlice<NativeT> values);
    140   ComputationDataHandle ConstantR1(const tensorflow::core::Bitmap& values);
    141   template <typename NativeT>
    142   ComputationDataHandle ConstantR2(
    143       std::initializer_list<std::initializer_list<NativeT>> values);
    144   template <typename NativeT>
    145   ComputationDataHandle ConstantFromArrayWithLayout(
    146       const Array<NativeT>& values, const Layout& layout);
    147   template <typename NativeT>
    148   ComputationDataHandle ConstantFromArray(const Array<NativeT>& values);
    149   template <typename NativeT>
    150   ComputationDataHandle ConstantR2FromArray2DWithLayout(
    151       const Array2D<NativeT>& values, const Layout& layout);
    152   template <typename NativeT>
    153   ComputationDataHandle ConstantR2FromArray2D(const Array2D<NativeT>& values);
    154   template <typename NativeT>
    155   ComputationDataHandle ConstantR3FromArray3DWithLayout(
    156       const Array3D<NativeT>& values, const Layout& layout);
    157   template <typename NativeT>
    158   ComputationDataHandle ConstantR3FromArray3D(const Array3D<NativeT>& values);
    159   template <typename NativeT>
    160   ComputationDataHandle ConstantR4FromArray4DWithLayout(
    161       const Array4D<NativeT>& values, const Layout& layout);
    162   template <typename NativeT>
    163   ComputationDataHandle ConstantR4FromArray4D(const Array4D<NativeT>& values);
    164 
    165   // Enqueues a rank one constant (vector) onto the computation. The vector has
    166   // size 'length' and every element has the value 'value'.
    167   template <typename NativeT>
    168   ComputationDataHandle ConstantR1(int64 length, NativeT value);
    169 
    170   // Adds dimensions to an array by duplicating the data in the array.
    171   //
    172   // The new dimensions are inserted on the left, i.e. if
    173   // broadcast_sizes has values {a0, ..., aN} and the operand shape
    174   // has dimensions {b0, ..., bM} then the shape of the output has
    175   // dimensions {a0, ..., aN, b0, ..., bM}.
    176   //
    177   // The new dimensions index into copies of the operand, i.e.
    178   //
    179   //   output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
    180   ComputationDataHandle Broadcast(
    181       const ComputationDataHandle& operand,
    182       tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
    183 
    184   // Enqueues a pad operation onto the computation that pads the given value on
    185   // the edges as well as between the elements of the input. padding_config
    186   // specifies the padding amount for each dimension.
    187   ComputationDataHandle Pad(const ComputationDataHandle& operand,
    188                             const ComputationDataHandle& padding_value,
    189                             const PaddingConfig& padding_config);
    190 
    191   // Enqueues an operation onto the computation that flattens the operand based
    192   // on the dimension order (major/slowest-varying to minor/fastest-varying)
    193   // given, followed by reshaping it into the shape with the given dimension
    194   // sizes (also major to minor). Conceptually, this is a limited form of
    195   // "shape casting".
    196   ComputationDataHandle Reshape(const ComputationDataHandle& operand,
    197                                 tensorflow::gtl::ArraySlice<int64> dimensions,
    198                                 tensorflow::gtl::ArraySlice<int64> new_sizes);
    199 
    200   // Enqueues an operation onto the computation that collapses the operand, from
    201   // minor to major order, then reshapes it into the shape with the given
    202   // dimension sizes, also from major to minor. Conceptually, this is a limited
    203   // form of "shape casting".
    204   ComputationDataHandle Reshape(const ComputationDataHandle& operand,
    205                                 tensorflow::gtl::ArraySlice<int64> new_sizes);
    206 
    207   // Wrapper for Reshape.
    208   // Enqueues an operation to collapse the provided dimensions; e.g. an
    209   // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to
    210   // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must
    211   // be a consecutive, in-order subsequence of the operand dimensions.
    212   //
    213   // Note that collapsing a single dimension does nothing:
    214   //
    215   //    {256} collapsing {0} => {256}
    216   //    {1} collapsing {0} => {1}
    217   //
    218   // Collapsing multiple dimensions produces a single result dimension:
    219   //
    220   //    {256, 2} collapsing {0,1} => {512}
    221   //    {256, 2, 3} collapsing {0,1} => {512, 3}
    222   //
    223   // This could potentially cause data to be moved -- it provides a more
    224   // structured form of reshaping than an arbitrary Reshape operation.
    225   ComputationDataHandle Collapse(const ComputationDataHandle& operand,
    226                                  tensorflow::gtl::ArraySlice<int64> dimensions);
    227 
    228   // Enqueues a slice operation onto the computation that slices the operand
    229   // from the start indices to the limit indices; e.g.
    230   //
    231   //        x
    232   //   [ 0 1 2 3 ]
    233   // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ]
    234   //   [ 8 9 a b ]
    235   //
    236   // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
    237   // range notation.
    238   // The strides parameter determines the stride over the slice
    239   ComputationDataHandle Slice(const ComputationDataHandle& operand,
    240                               tensorflow::gtl::ArraySlice<int64> start_indices,
    241                               tensorflow::gtl::ArraySlice<int64> limit_indices,
    242                               tensorflow::gtl::ArraySlice<int64> strides);
    243 
    244   // Enqueues a slice operation in a given dimension, taking all other
    245   // dimensions as they are; e.g. if dimno is 1 from start_index 2 to
    246   // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand
    247   // for:
    248   //
    249   //  array[:, 2:4:1, :]
    250   ComputationDataHandle SliceInDim(const ComputationDataHandle& operand,
    251                                    int64 start_index, int64 limit_index,
    252                                    int64 stride, int64 dimno);
    253 
    254   // Enqueues a slice operation onto the computation that slices the 'operand'
    255   // from dynamic start indices which are passed in 'start_indices'.
    256   // The size of the slice in each dimension is passed in 'slice_sizes',
    257   // which specify the end point of exclusive slice intervals in each
    258   // dimension [start, start + size).
    259   // The shape of 'start_indices' must be rank == 1, with dimension size
    260   // equal to the rank of the 'operand'.
    261   // Slice index calculations are computed modulo input dimension sizes to
    262   // prevent dynamic start indices from generating out-of-bound array accesses.
    263   ComputationDataHandle DynamicSlice(
    264       const ComputationDataHandle& operand,
    265       const ComputationDataHandle& start_indices,
    266       tensorflow::gtl::ArraySlice<int64> slice_sizes);
    267 
    268   // Enqueues a dynamic update slice operation onto the computation, which
    269   // updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
    270   // The shape of 'update' determines the shape of the slice of 'operand'
    271   // which is updated.
    272   // The indices specified in 'start_indices' specify the offset of the slice
    273   // of 'operand' which is updated.
    274   //
    275   //               update = {10, 11} // calculated at runtime.
    276   //   [1 2 3]     start  = {1, 1}   // calculated at runtime.  [1 2  3 ]
    277   //   [4 5 6]  => DynamicUpdateslice(data, update, start)   => [4 10 11]
    278   //   [7 8 9]                                                  [7 8  9 ]
    279   //
    280   // The shape of 'start_indices' must be rank == 1, with dimension size
    281   // equal to the rank of the 'operand'.
    282   // Slice index calculations are computed modulo update dimension sizes to
    283   // prevent dynamic start indices from generating out-of-bound array accesses.
    284   ComputationDataHandle DynamicUpdateSlice(
    285       const ComputationDataHandle& operand, const ComputationDataHandle& update,
    286       const ComputationDataHandle& start_indices);
    287 
    288   // Enqueues a concatenate instruction onto the computation. 'operands' must
    289   // have >= 1 entry.
    290   ComputationDataHandle ConcatInDim(
    291       tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
    292       int64 dimension);
    293 
    294   // Enqueue a tracing operation onto the computation; the computation will emit
    295   // a logging message with the operand.
    296   void Trace(const string& tag, const ComputationDataHandle& operand);
    297 
    298   // Enqueues a conditional-move-like select operation onto the computation;
    299   // predicated on pred, selects between on_true and on_false.
    300   ComputationDataHandle Select(const ComputationDataHandle& pred,
    301                                const ComputationDataHandle& on_true,
    302                                const ComputationDataHandle& on_false);
    303 
    304   // Enqueues a tuple-creation instruction onto the computation.
    305   ComputationDataHandle Tuple(
    306       tensorflow::gtl::ArraySlice<ComputationDataHandle> elements);
    307 
    308   // Enqueues a tuple-element-get instruction onto the computation.
    309   ComputationDataHandle GetTupleElement(const ComputationDataHandle& tuple_data,
    310                                         int64 index);
    311 
    312   // Enqueues an equal-to comparison instruction onto the computation.
    313   ComputationDataHandle Eq(
    314       const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    315       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
    316 
    317   // Enqueues a not-equal comparison instruction onto the computation.
    318   ComputationDataHandle Ne(
    319       const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    320       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
    321 
    322   // Enqueues a greater-or-equal comparison instruction onto the computation.
    323   ComputationDataHandle Ge(
    324       const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    325       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
    326 
    327   // Enqueues a greater-than comparison instruction onto the computation.
    328   ComputationDataHandle Gt(
    329       const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    330       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
    331 
    332   // Enqueues a less-than comparison instruction onto the computation.
    333   ComputationDataHandle Lt(
    334       const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    335       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
    336 
    337   // Enqueues a less-or-equal comparison instruction onto the computation.
    338   ComputationDataHandle Le(
    339       const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    340       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
    341 
    342   // Enqueues a dot instruction onto the computation.
    343   ComputationDataHandle Dot(const ComputationDataHandle& lhs,
    344                             const ComputationDataHandle& rhs);
    345 
    346   // Enqueues a general dot instruction onto the computation.
    347   ComputationDataHandle DotGeneral(
    348       const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    349       const DotDimensionNumbers& dimension_numbers);
    350 
    351   // Default dimension numbers used for a 2D convolution.
    352   static constexpr int64 kConvBatchDimension = 0;
    353   static constexpr int64 kConvFeatureDimension = 1;
    354   static constexpr int64 kConvFirstSpatialDimension = 2;
    355   static constexpr int64 kConvSecondSpatialDimension = 3;
    356   static constexpr int64 kConvKernelOutputDimension = 0;
    357   static constexpr int64 kConvKernelInputDimension = 1;
    358   static constexpr int64 kConvKernelFirstSpatialDimension = 2;
    359   static constexpr int64 kConvKernelSecondSpatialDimension = 3;
    360 
    361   // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for
    362   // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for
    363   // the kernel operand
    364   // {output_feature, input_feature, height, width} = {0, 1, 2, 3}.
    365   static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers(
    366       int num_spatial_dims = 2);
    367 
    368   // Creates a ConvolutionDimensionNumbers with the given arguments. Returns an
    369   // error if either the input or the weight dimension numbers have conflicts.
    370   static StatusOr<ConvolutionDimensionNumbers> CreateConvDimensionNumbers(
    371       int64 input_batch, int64 input_feature, int64 input_first_spatial,
    372       int64 input_second_spatial, int64 output_batch, int64 output_feature,
    373       int64 output_first_spatial, int64 output_second_spatial,
    374       int64 kernel_output_feature, int64 kernel_input_feature,
    375       int64 kernel_first_spatial, int64 kernel_second_spatial);
    376 
    377   // Enqueues a convolution instruction onto the computation, which uses the
    378   // default convolution dimension numbers.
    379   ComputationDataHandle Conv(const ComputationDataHandle& lhs,
    380                              const ComputationDataHandle& rhs,
    381                              tensorflow::gtl::ArraySlice<int64> window_strides,
    382                              Padding padding);
    383 
    384   // Enqueues a convolution instruction onto the computation, with the caller
    385   // provided padding configuration in the format returned by MakePadding().
    386   ComputationDataHandle ConvWithGeneralPadding(
    387       const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    388       tensorflow::gtl::ArraySlice<int64> window_strides,
    389       tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
    390 
    391   // Enqueues a convolution instruction onto the computation, with the caller
    392   // provided dimension numbers configuration.
    393   ComputationDataHandle ConvWithGeneralDimensions(
    394       const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    395       tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
    396       const ConvolutionDimensionNumbers& dimension_numbers);
    397 
    398   // Enqueues a convolution instruction onto the computation, with the caller
    399   // provided padding configuration as well as the dimension numbers.
    400   ComputationDataHandle ConvGeneral(
    401       const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    402       tensorflow::gtl::ArraySlice<int64> window_strides,
    403       tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
    404       const ConvolutionDimensionNumbers& dimension_numbers);
    405 
    406   // Enqueues a convolution instruction onto the computation, with the caller
    407   // provided padding configuration, dilation factors and dimension numbers.
    408   ComputationDataHandle ConvGeneralDilated(
    409       const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    410       tensorflow::gtl::ArraySlice<int64> window_strides,
    411       tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
    412       tensorflow::gtl::ArraySlice<int64> lhs_dilation,
    413       tensorflow::gtl::ArraySlice<int64> rhs_dilation,
    414       const ConvolutionDimensionNumbers& dimension_numbers);
    415 
    416   // Enqueues an FFT instruction onto the computation, of the given type and
    417   // with the given FFT length.
    418   ComputationDataHandle Fft(const ComputationDataHandle& operand,
    419                             FftType fft_type,
    420                             tensorflow::gtl::ArraySlice<int64> fft_length);
    421 
    422   // Enqueues an infeed instruction onto the computation, which writes data of
    423   // the given shape to the infeed buffer of the device.
    424   ComputationDataHandle Infeed(const Shape& shape, const string& config = "");
    425 
    426   // Enqueues an outfeed instruction onto the computation. This instruction
    427   // generates outgoing data transfers for the given data.
    428   //
    429   // shape_with_layout communicates the laid out shape that we want to outfeed
    430   // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error
    431   // will occur.
    432   void Outfeed(const ComputationDataHandle& operand,
    433                const Shape& shape_with_layout, const string& outfeed_config);
    434 
    435   // Enqueues a call instruction onto the computation.
    436   ComputationDataHandle Call(
    437       const Computation& computation,
    438       tensorflow::gtl::ArraySlice<ComputationDataHandle> operands);
    439 
    440   // Enqueues a custom call instruction onto the computation.
    441   // During code generation, a call instruction is emitted which targets a
    442   // symbol with the name |call_target_name|.  The |operands| are passed to the
    443   // call instruction.  |shape| is the resultant shape.
    444   ComputationDataHandle CustomCall(
    445       const string& call_target_name,
    446       tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
    447       const Shape& shape);
    448 
    449   // Enqueues a pseudo-op to represent host-side computation data-dependencies.
    450   // During code generation, host send and receive operations will be generated
    451   // to transfer |operands| to the host and a single result of |shape| back to
    452   // the device.  Host send/recv operations are emitted using |channel_name|.
    453   // Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO
    454   // instruction scheduling.
    455   ComputationDataHandle HostCompute(
    456       tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
    457       const string& channel_name, int64 cost_estimate_ns, const Shape& shape);
    458 
    459   // The following methods enqueue element-wise binary arithmetic operations
    460   // onto the computation. The shapes of the operands have to match unless one
    461   // of the operands is a scalar, or an explicit broadcast dimension is given
    462   // (see g3doc for more details).
    463 
    464   // Enqueues a complex compose instruction onto the computation.
    465   ComputationDataHandle Complex(
    466       const ComputationDataHandle& real, const ComputationDataHandle& imag,
    467       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
    468 
    469   // Enqueues a complex conjugate instruction onto the computation.
    470   ComputationDataHandle Conj(const ComputationDataHandle& operand);
    471 
    472   // Enqueues an add instruction onto the computation.
    473   ComputationDataHandle Add(
    474       const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    475       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
    476 
    477   // Enqueues a subtract instruction onto the computation.
    478   ComputationDataHandle Sub(
    479       const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    480       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
    481 
    482   // Enqueues a multiply instruction onto the computation.
    483   ComputationDataHandle Mul(
    484       const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    485       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
    486 
    487   // Enqueues a divide instruction onto the computation.
    488   ComputationDataHandle Div(
    489       const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    490       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
    491 
    492   // Enqueues a remainder instruction onto the computation.
    493   ComputationDataHandle Rem(
    494       const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    495       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
    496 
    497   // Enqueues a max instruction onto the computation.
    498   ComputationDataHandle Max(
    499       const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    500       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
    501 
    502   // Enqueues a min instruction onto the computation.
    503   ComputationDataHandle Min(
    504       const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    505       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
    506 
    507   // Element-wise logical operators
    508   ComputationDataHandle And(
    509       const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    510       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
    511 
    512   ComputationDataHandle Or(
    513       const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    514       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
    515 
    516   ComputationDataHandle Not(const ComputationDataHandle& operand);
    517 
    518   ComputationDataHandle ShiftLeft(
    519       const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    520       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
    521   ComputationDataHandle ShiftRightArithmetic(
    522       const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    523       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
    524   ComputationDataHandle ShiftRightLogical(
    525       const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    526       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
    527 
    528   // Reduces an array among the provided dimensions, given "computation" as a
    529   // reduction operator.
    530   ComputationDataHandle Reduce(
    531       const ComputationDataHandle& operand,
    532       const ComputationDataHandle& init_value, const Computation& computation,
    533       tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
    534 
    535   // Convenience wrapper around the above that reduces all the dimensions in the
    536   // operand shape.
    537   ComputationDataHandle ReduceAll(const ComputationDataHandle& operand,
    538                                   const ComputationDataHandle& init_value,
    539                                   const Computation& computation);
    540 
    541   // Enqueues a windowed reduce instruction onto the computation.
    542   ComputationDataHandle ReduceWindow(
    543       const ComputationDataHandle& operand,
    544       const ComputationDataHandle& init_value, const Computation& computation,
    545       tensorflow::gtl::ArraySlice<int64> window_dimensions,
    546       tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding);
    547 
    548   // As ReduceWindow(), but the padding is given in the format
    549   // returned by MakePadding().
    550   ComputationDataHandle ReduceWindowWithGeneralPadding(
    551       const ComputationDataHandle& operand,
    552       const ComputationDataHandle& init_value, const Computation& computation,
    553       tensorflow::gtl::ArraySlice<int64> window_dimensions,
    554       tensorflow::gtl::ArraySlice<int64> window_strides,
    555       tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
    556 
  // Returns the sum of the operand value across all replicas. All replicas
  // supply one input to the sum and all replicas receive the resulting sum.
  ComputationDataHandle CrossReplicaSum(const ComputationDataHandle& operand);

  // Enqueues an operation that scatters the `source` array to the selected
  // indices of each window: `select` chooses one element per window of
  // `operand`, and the corresponding `source` value is combined via `scatter`
  // into an output initialized with `init_value`.
  ComputationDataHandle SelectAndScatter(
      const ComputationDataHandle& operand, const Computation& select,
      tensorflow::gtl::ArraySlice<int64> window_dimensions,
      tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
      const ComputationDataHandle& source,
      const ComputationDataHandle& init_value, const Computation& scatter);

  // As SelectAndScatter(), but the padding is given explicitly as (low, high)
  // pairs per dimension, in the format returned by MakePadding().
  ComputationDataHandle SelectAndScatterWithGeneralPadding(
      const ComputationDataHandle& operand, const Computation& select,
      tensorflow::gtl::ArraySlice<int64> window_dimensions,
      tensorflow::gtl::ArraySlice<int64> window_strides,
      tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
      const ComputationDataHandle& source,
      const ComputationDataHandle& init_value, const Computation& scatter);
    579 
  // Enqueues an abs instruction onto the computation.
  ComputationDataHandle Abs(const ComputationDataHandle& operand);

  // Enqueues an atan2 instruction onto the computation, computing
  // atan2(y, x) elementwise.
  ComputationDataHandle Atan2(
      const ComputationDataHandle& y, const ComputationDataHandle& x,
      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});

  // Enqueues an exp instruction onto the computation.
  ComputationDataHandle Exp(const ComputationDataHandle& operand);

  // Enqueues a floor instruction onto the computation.
  ComputationDataHandle Floor(const ComputationDataHandle& operand);

  // Enqueues a ceil instruction onto the computation.
  ComputationDataHandle Ceil(const ComputationDataHandle& operand);

  // Enqueues a round instruction onto the computation, rounding to the
  // nearest integer with half-way cases rounding away from zero.
  ComputationDataHandle Round(const ComputationDataHandle& operand);

  // Enqueues a log instruction (natural logarithm) onto the computation.
  ComputationDataHandle Log(const ComputationDataHandle& operand);

  // Enqueues a sign instruction onto the computation.
  ComputationDataHandle Sign(const ComputationDataHandle& operand);

  // Enqueues a cosine instruction onto the computation.
  ComputationDataHandle Cos(const ComputationDataHandle& operand);

  // Enqueues a sine instruction onto the computation.
  ComputationDataHandle Sin(const ComputationDataHandle& operand);

  // Enqueues a tanh instruction onto the computation.
  ComputationDataHandle Tanh(const ComputationDataHandle& operand);

  // Enqueues a real-part instruction onto the computation.
  ComputationDataHandle Real(const ComputationDataHandle& operand);

  // Enqueues an imaginary-part instruction onto the computation.
  ComputationDataHandle Imag(const ComputationDataHandle& operand);

  // Enqueues a float32 sqrt instruction onto the computation.
  // (float32 is specified as there is an implicit float32 0.5f constant
  // exponent).
  ComputationDataHandle SqrtF32(const ComputationDataHandle& operand);

  // Enqueues a float32 square instruction onto the computation.
  // (float32 is specified as there is an implicit float32 2.0f constant
  // exponent).
  ComputationDataHandle SquareF32(const ComputationDataHandle& operand);

  // Enqueues a lhs^rhs computation onto the computation.
  ComputationDataHandle Pow(
      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});

  // Enqueues an operator that tests if the operand's values are finite, i.e.,
  // not Inf or NaN. Defined only for floating-point types. Returns an array of
  // booleans with the same shape where entries are true iff the corresponding
  // entry was finite.
  ComputationDataHandle IsFinite(const ComputationDataHandle& operand);

  // Enqueues a convert instruction onto the computation that changes the
  // element type of the operand array to primitive_type.
  ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand,
                                           PrimitiveType new_element_type);

  // Enqueues a no-op instruction onto the computation that changes
  // the element type of the operand array to primitive_type by
  // reinterpreting the underlying bits. The bit-widths of the source and
  // destination element types must be identical.
  ComputationDataHandle BitcastConvertType(const ComputationDataHandle& operand,
                                           PrimitiveType new_element_type);

  // Enqueues a float32 reciprocal instruction onto the computation.
  // (float32 is specified as there is an implicit float32 -1.0f constant
  // exponent).
  //
  // TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the
  // shape of the operand.
  ComputationDataHandle ReciprocalF32(const ComputationDataHandle& operand);

  // Enqueues a negate instruction onto the computation.
  ComputationDataHandle Neg(const ComputationDataHandle& operand);
    665 
  // Enqueues a transpose instruction onto the computation. `permutation`
  // gives the new ordering of the operand's dimensions.
  ComputationDataHandle Transpose(
      const ComputationDataHandle& operand,
      tensorflow::gtl::ArraySlice<int64> permutation);

  // Enqueues a reverse instruction onto the computation. The order of the
  // elements in the given dimensions is reversed (i.e., the element at index i
  // is moved to index dimension_size - 1 - i).
  ComputationDataHandle Rev(const ComputationDataHandle& operand,
                            tensorflow::gtl::ArraySlice<int64> dimensions);

  // Enqueues a sort (as increasing order) instruction onto the computation.
  ComputationDataHandle Sort(const ComputationDataHandle& operand);

  // Enqueues a clamp instruction onto the computation, limiting each element
  // of `operand` to the range [`min`, `max`].
  ComputationDataHandle Clamp(const ComputationDataHandle& min,
                              const ComputationDataHandle& operand,
                              const ComputationDataHandle& max);

  // Enqueues a map instruction onto the computation, applying `computation`
  // elementwise over the given `operands`.
  ComputationDataHandle Map(
      tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
      const Computation& computation,
      tensorflow::gtl::ArraySlice<int64> dimensions,
      tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands = {});

  // Enqueues a N(mu, sigma) random number generation instruction onto the
  // computation.
  ComputationDataHandle RngNormal(const ComputationDataHandle& mu,
                                  const ComputationDataHandle& sigma,
                                  const Shape& shape);

  // Enqueues a U(a, b) random number generation instruction onto the
  // computation. Returns values in the semi-open interval [a, b).
  ComputationDataHandle RngUniform(const ComputationDataHandle& a,
                                   const ComputationDataHandle& b,
                                   const Shape& shape);

  // Enqueues a while node onto the computation: repeatedly applies `body` to
  // the loop state (starting from `init`) while `condition` holds.
  ComputationDataHandle While(const Computation& condition,
                              const Computation& body,
                              const ComputationDataHandle& init);

  // Enqueues a conditional node onto the computation: applies
  // `true_computation` to `true_operand` when `predicate` is true, otherwise
  // `false_computation` to `false_operand`.
  ComputationDataHandle Conditional(const ComputationDataHandle& predicate,
                                    const ComputationDataHandle& true_operand,
                                    const Computation& true_computation,
                                    const ComputationDataHandle& false_operand,
                                    const Computation& false_computation);

  // Enqueues a ReducePrecision node onto the computation, reducing the
  // operand's floating-point precision to the given numbers of exponent and
  // mantissa bits.
  ComputationDataHandle ReducePrecision(const ComputationDataHandle& operand,
                                        const int exponent_bits,
                                        const int mantissa_bits);

  // Enqueues a Gather node onto the computation.
  ComputationDataHandle Gather(
      const ComputationDataHandle& input,
      const ComputationDataHandle& gather_indices,
      const GatherDimensionNumbers& dimension_numbers,
      tensorflow::gtl::ArraySlice<int64> window_bounds);

  // Enqueues a Send node onto the computation, to send the given operand to
  // a Recv instruction that shares the same channel handle.
  void Send(const ComputationDataHandle& operand, const ChannelHandle& handle);

  // Enqueues a Recv node onto the computation. The data comes from a Send
  // instruction that shares the same channel handle and its shape must
  // be the same as the given shape.
  ComputationDataHandle Recv(const Shape& shape, const ChannelHandle& handle);

  // Returns true if 'operand' is a compile-time constant. A compile-time
  // constant does not depend on parameters with index greater than or equal to
  // `num_parameters`, or on stateful operators such as `RngNormal` or `Infeed`.
  // Unlike `ComputeConstant`, `IsConstant` tests whether a computation is a
  // compile-time constant without evaluating the computation.
  StatusOr<bool> IsConstant(const ComputationDataHandle& operand,
                            int64 num_parameters = 0);
    744 
  // Normalizes operand across spatial and batch dimensions for each feature.
  //
  // Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
  // is the normalized result and batch_mean and batch_var are the mean and
  // variance, respectively, across batch for the operand.
  //
  // `feature_index` is the index of the feature dimension in `operand`.
  ComputationDataHandle BatchNormTraining(const ComputationDataHandle& operand,
                                          const ComputationDataHandle& scale,
                                          const ComputationDataHandle& offset,
                                          float epsilon, int64 feature_index);

  // Normalizes operand across spatial and batch dimensions for each feature.
  //
  // `BatchNormInference` is equivalent to calling `BatchNormTraining` without
  // computing `mean` and `variance` for each batch inside the operation. It
  // uses the input `mean` and `variance` instead as estimated values. The
  // purpose of this op is to reduce latency in inference, hence the name
  // `BatchNormInference`.
  //
  // The output has the same shape as `operand`, and contains the normalized
  // values for each batch.
  ComputationDataHandle BatchNormInference(
      const ComputationDataHandle& operand, const ComputationDataHandle& scale,
      const ComputationDataHandle& offset, const ComputationDataHandle& mean,
      const ComputationDataHandle& variance, float epsilon,
      int64 feature_index);

  // Calculates the gradients of a batch norm op.
  //
  // The inputs `batch_mean` and `batch_var` represent the mean and variance
  // across the batch.
  //
  // Returns a tuple of three elements:
  //   - grad_operand: Gradient with respect to input `operand`
  //   - grad_offset: Gradient with respect to input `offset`
  //   - grad_scale: Gradient with respect to input `scale`
  ComputationDataHandle BatchNormGrad(const ComputationDataHandle& operand,
                                      const ComputationDataHandle& scale,
                                      const ComputationDataHandle& batch_mean,
                                      const ComputationDataHandle& batch_var,
                                      const ComputationDataHandle& grad_output,
                                      float epsilon, int64 feature_index);
    786 
  // Computes the value of a constant indicated by a
  // ComputationDataHandle using a non-optimized interpreter on the host.
  //
  // The operand must be from the computation currently being built -
  // i.e., returned from this builder with no intervening call to
  // Build(). This happens to currently work regardless of that, but
  // that may stop working at any time.
  //
  // The operand must represent a constant value, which in this case
  // means that it must not statically depend on any parameter of the
  // computation that is being built other than the ones specified on the
  // parameter list. The parameters in the list will be indexed by their
  // parameter id property so the number of parameters specified should be at
  // least as many as the largest used parameter index.
  //
  // `IsConstant` can be used to test whether a computation is a compile-time
  // constant without evaluating it. `ComputeConstant` only succeeds for
  // computations where `IsConstant` returns true.
  //
  // This functionality can be useful when translating a computation
  // into XLA where something that looked dynamic is required by
  // XLA to be specified as a constant. E.g. the source
  // computation (outside of XLA) may include a dynamic
  // computation of the shape of something and ComputeConstant lets
  // you determine what the value of that computation is in the case
  // where the value can be determined at compile time.
  //
  // If output_layout is non-null, then the output of the computation
  // will be stored using that layout.
  StatusOr<std::unique_ptr<Literal>> ComputeConstant(
      const ComputationDataHandle& operand,
      const Layout* output_layout = nullptr,
      tensorflow::gtl::ArraySlice<Literal> parameters = {});

  // Returns a new ComputationBuilder whose resultant Computation is used only
  // by this ComputationBuilder. The sub-ComputationBuilder has the same
  // die_immediately_on_error behavior as the parent.
  std::unique_ptr<ComputationBuilder> CreateSubBuilder(
      const string& computation_name);

  // Modifies the computation being built so that executions of it
  // will return the value associated with operand, rather than the
  // last expression enqueued on the ComputationBuilder. Any subsequent
  // operations added to the ComputationBuilder will not have any effect unless
  // SetReturnValue is called again.
  Status SetReturnValue(const ComputationDataHandle& operand);

  // Builds the computation with the requested operations, or returns a non-ok
  // status.
  StatusOr<Computation> Build();

  // Builds the computation with the requested operations, or notes an error in
  // the parent ComputationBuilder and returns an empty computation if building
  // failed. This function is intended to be used where the returned
  // Computation is only used by the parent ComputationBuilder and hence further
  // operation on the returned Computation will simply be error'ed out if an
  // error occurred while building this computation. If the built computation is
  // to be used by a ComputationBuilder other than the parent ComputationBuilder
  // then Build() should be used instead.
  Computation BuildAndNoteError();

  // Returns the first error that was encountered while building the
  // computation. When an error is encountered, by default we return a vacuous
  // ComputationDataHandle and inform the user of the error that occurred while
  // building the computation when they make a final call to Build().
  //
  // See also set_die_immediately_on_error().
  Status first_error() const { return first_error_; }
    855 
 private:
  // Limited checking of convolution parameters. Returns false on
  // error.
  bool VerifyConvolution(const Shape& lhs_shape, const Shape& rhs_shape,
                         const ConvolutionDimensionNumbers& dimension_numbers);

  // The parent ComputationBuilder of a sub-ComputationBuilder. The
  // parent_builder_ will be nullptr if this is not a sub-ComputationBuilder.
  ComputationBuilder* parent_builder_{nullptr};

  // Helper function for creating a Window proto from user-supplied
  // data. Returns true if the user-supplied data was valid.
  bool MakeWindow(tensorflow::gtl::ArraySlice<int64> window_dimensions,
                  tensorflow::gtl::ArraySlice<int64> window_strides,
                  tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
                  tensorflow::gtl::ArraySlice<int64> lhs_dilation,
                  tensorflow::gtl::ArraySlice<int64> rhs_dilation,
                  Window* window);
    874 
    875   // Internal helper method that does the building for an arbitrary unary op.
    876   ComputationDataHandle UnaryOp(UnaryOperation binop,
    877                                 const ComputationDataHandle& operand);
    878 
  // Internal helper method that does the building for an arbitrary binary op.
  // broadcast_dimensions specifies which dimensions to use for broadcasting
  // when the operation is between tensors of different ranks.
  ComputationDataHandle BinaryOp(
      BinaryOperation binop, const ComputationDataHandle& lhs,
      const ComputationDataHandle& rhs,
      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);

  // Internal helper method that does the building for an arbitrary ternary op.
  ComputationDataHandle TernaryOp(TernaryOperation triop,
                                  const ComputationDataHandle& lhs,
                                  const ComputationDataHandle& rhs,
                                  const ComputationDataHandle& ehs);

  // Internal helper method that does the building for a random number generator
  // of a given distribution with an explicitly specified shape.
  ComputationDataHandle RngOp(
      RandomDistribution distribution,
      tensorflow::gtl::ArraySlice<ComputationDataHandle> parameters,
      const Shape& shape);

  // Populates computation_ with a valid object or returns a failing status.
  // This is used before any given operation is enqueued.
  Status PrepareComputation();

  // Notes that the error occurred by:
  // * storing it internally and capturing a backtrace if it's the first error
  //   (this deferred value will be produced on the call to Build())
  // * dying if die_immediately_on_error_ is true
  void NoteError(const Status& error);

  // Helper function that runs the given op_request, filling in op_response.
  // Before the op is run, PrepareComputation is called, and common fields in
  // the op_request are filled in.
  Status RunOp(OpRequest* op_request, OpResponse* op_response);

  // Helper function that calls RunOp and calls NoteError on failures.
  void RunOpAndNoteError(OpRequest* op_request);

  // Helper function that calls RunOp and either returns the output computation
  // data handle (on success) or a vacuous computation data handle (on failure).
  ComputationDataHandle RunOpAndParseResponse(OpRequest* op_request);

  // Helper function that implements GetShape without noting errors. This makes
  // it easier to ensure the real GetShape will note errors on every error path.
  StatusOr<std::unique_ptr<Shape>> GetShapeWithoutNoteError(
      const ComputationDataHandle& operand);

  string name_;  // Name to use for the built computation.

  // The first error encountered while building the computation.
  // This is OK until the first error is encountered.
  Status first_error_;

  // The saved stack trace from the point at which the first error occurred.
  tensorflow::SavedStackTrace first_error_backtrace_;

  // The computation that operations are enqueued onto.
  Computation computation_;

  // The client that the computation is created in. Not owned.
  Client* client_;

  // Mode bit that indicates whether to die when a first error is encountered.
  bool die_immediately_on_error_ = false;

  // The metadata to attach to each op. This is structured as a "modal"-like
  // operation, in order to simplify client code (and not sprinkle this metadata
  // throughout the TensorFlow op kernel implementations).
  OpMetadata metadata_;

  // Sharding for this operator. This is structured as a "modal"-like operation,
  // in order to simplify client code, similar to metadata_.
  tensorflow::gtl::optional<OpSharding> sharding_;

  TF_DISALLOW_COPY_AND_ASSIGN(ComputationBuilder);
    955 };
    956 
    957 template <typename NativeT>
    958 ComputationDataHandle ComputationBuilder::ConstantR0(NativeT value) {
    959   return ConstantLiteral(*Literal::CreateR0<NativeT>(value));
    960 }
    961 
    962 template <typename NativeT>
    963 ComputationDataHandle ComputationBuilder::ConstantR1(
    964     tensorflow::gtl::ArraySlice<NativeT> values) {
    965   return ConstantLiteral(*Literal::CreateR1<NativeT>(values));
    966 }
    967 
    968 template <typename NativeT>
    969 ComputationDataHandle ComputationBuilder::ConstantR1(int64 length,
    970                                                      NativeT value) {
    971   Literal literal(ShapeUtil::MakeShape(
    972       primitive_util::NativeToPrimitiveType<NativeT>(), {length}));
    973   literal.PopulateWithValue(value);
    974   return ConstantLiteral(literal);
    975 }
    976 
    977 inline ComputationDataHandle ComputationBuilder::ConstantR1(
    978     const tensorflow::core::Bitmap& values) {
    979   return ConstantLiteral(*Literal::CreateR1(values));
    980 }
    981 
    982 template <typename NativeT>
    983 ComputationDataHandle ComputationBuilder::ConstantR2(
    984     std::initializer_list<std::initializer_list<NativeT>> values) {
    985   return ConstantLiteral(*Literal::CreateR2<NativeT>(values));
    986 }
    987 
    988 template <typename NativeT>
    989 ComputationDataHandle ComputationBuilder::ConstantFromArrayWithLayout(
    990     const Array<NativeT>& values, const Layout& layout) {
    991   return ConstantLiteral(
    992       *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
    993 }
    994 
    995 template <typename NativeT>
    996 ComputationDataHandle ComputationBuilder::ConstantFromArray(
    997     const Array<NativeT>& values) {
    998   return ConstantLiteral(*Literal::CreateFromArray<NativeT>(values));
    999 }
   1000 
   1001 template <typename NativeT>
   1002 ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout(
   1003     const Array2D<NativeT>& values, const Layout& layout) {
   1004   return ConstantLiteral(
   1005       *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
   1006 }
   1007 
   1008 template <typename NativeT>
   1009 ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D(
   1010     const Array2D<NativeT>& values) {
   1011   return ConstantLiteral(*Literal::CreateR2FromArray2D<NativeT>(values));
   1012 }
   1013 
   1014 template <typename NativeT>
   1015 ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout(
   1016     const Array3D<NativeT>& values, const Layout& layout) {
   1017   return ConstantLiteral(
   1018       *Literal::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
   1019 }
   1020 
   1021 template <typename NativeT>
   1022 ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D(
   1023     const Array3D<NativeT>& values) {
   1024   return ConstantFromArray(values);
   1025 }
   1026 
   1027 template <typename NativeT>
   1028 ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout(
   1029     const Array4D<NativeT>& values, const Layout& layout) {
   1030   return ConstantFromArrayWithLayout(values, layout);
   1031 }
   1032 
   1033 template <typename NativeT>
   1034 ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D(
   1035     const Array4D<NativeT>& values) {
   1036   return ConstantFromArray(values);
   1037 }
   1038 
   1039 // RAII-style object: sets the current sharding assignment in builder on
   1040 // construction, and sets back to the previous assignment on destruction.
   1041 class ScopedShardingAssignment {
   1042  public:
   1043   ScopedShardingAssignment(xla::ComputationBuilder* builder,
   1044                            tensorflow::gtl::optional<OpSharding> sharding)
   1045       : builder_(builder), prev_sharding_(builder->sharding()) {
   1046     SetSharding(sharding);
   1047   }
   1048 
   1049   ~ScopedShardingAssignment() { SetSharding(prev_sharding_); }
   1050 
   1051  private:
   1052   void SetSharding(const tensorflow::gtl::optional<OpSharding>& sharding) {
   1053     if (sharding.has_value()) {
   1054       builder_->SetSharding(sharding.value());
   1055     } else {
   1056       builder_->ClearSharding();
   1057     }
   1058   }
   1059 
   1060   xla::ComputationBuilder* const builder_;
   1061   tensorflow::gtl::optional<OpSharding> prev_sharding_;
   1062 
   1063   TF_DISALLOW_COPY_AND_ASSIGN(ScopedShardingAssignment);
   1064 };
   1065 
   1066 }  // namespace xla
   1067 
   1068 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_
   1069