/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package xla;
option cc_enable_arenas = true;

// Primitive types are the individual values that can be held in rectangular
// multidimensional arrays. A description of the rectangular multidimensional
// array dimensions / primitive type is given by Shape, below.
enum PrimitiveType {
  // Invalid primitive type to serve as default.
  PRIMITIVE_TYPE_INVALID = 0;

  // Predicates are two-state booleans.
  PRED = 1;

  // Signed integral values of fixed width.
  S8 = 2;
  S16 = 3;
  S32 = 4;
  S64 = 5;

  // Unsigned integral values of fixed width.
  U8 = 6;
  U16 = 7;
  U32 = 8;
  U64 = 9;

  // Floating-point values of fixed width.
  //
  // Note: if f16s are not natively supported on the device, they will be
  // converted to f16 from f32 at arbitrary points in the computation.
  F16 = 10;
  F32 = 11;

  // Truncated 16-bit floating-point format. This is similar to IEEE's 16-bit
  // floating-point format, but uses 1 bit for the sign, 8 bits for the
  // exponent and 7 bits for the mantissa.
  BF16 = 16;

  F64 = 12;

  // Complex values of fixed width.
  C64 = 15;  // Paired F32 (real, imag), as in std::complex<float>.

  // A tuple is a polymorphic sequence; e.g. a shape that holds different
  // sub-shapes. Tuples are used for things like returning multiple values from
  // a computation; e.g. a computation that returns weights and biases may have
  // a signature that results in a tuple like (f32[784x2000], f32[2000]).
  //
  // If a shape proto has the tuple element type, it may not have any entries
  // in the dimensions field.
  TUPLE = 13;

  // An opaque type used for passing context-specific data to a custom
  // operation.
  OPAQUE = 14;

  // Next = 17
}

// Describes the value held inside padding elements.
enum PaddingValue {
  INVALID_PAD = 0;

  // Zero padding must be 0-values that correspond to the shape's element type.
  ZERO_PAD = 1;

  // One padding must be 1-values that correspond to the shape's element type.
  ONE_PAD = 2;

  // "Lowest" padding must be the lowest values in the shape's element type,
  // used as padding for operations like max-accumulation.
  LOWEST_PAD = 3;

  // "Highest" padding must be the largest values in the shape's element type,
  // used as padding for operations like min-accumulation.
  HIGHEST_PAD = 4;

  // Unknown padding could be anything; e.g. floating-point NaNs!
  UNKNOWN_PAD = 5;
}

// Describes the padding configuration for the Pad operation. The padding
// amount on both edges, as well as between the elements, is specified for each
// dimension.
message PaddingConfig {
  // Describes the padding configuration for a dimension.
  message PaddingConfigDimension {
    // Padding amount on the low end (next to index 0).
    int64 edge_padding_low = 1;

    // Padding amount on the high end (next to the highest index).
    int64 edge_padding_high = 2;

    // Padding amount between the elements.
    int64 interior_padding = 3;
  }

  // The padding configuration for all dimensions.
  repeated PaddingConfigDimension dimensions = 1;
}
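
// For illustration only (not part of the proto definition): padding a
// dimension of size d with edge_padding_low = l, edge_padding_high = h and
// interior_padding = i yields a dimension of size l + h + d + (d - 1) * i.
// For example, a hypothetical config padding a size-3 dimension as
//   dimensions { edge_padding_low: 1 edge_padding_high: 2 interior_padding: 1 }
// would produce 1 + 2 + 3 + 2 = 8 elements along that dimension.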

// A format specifies the method used by a layout to store an array in memory.
enum Format {
  INVALID_FORMAT = 0;
  // The default layout, with exactly one storage location per element
  // (ignoring padding).
  DENSE = 1;
  // A sparsely encoded layout, providing only the index/value pairs of
  // non-zero elements.
  SPARSE = 2;
}

// A layout describes how the array is placed in (1D) memory space.  This
// includes the minor-to-major ordering of dimensions within a shape, as well
// as any padding present in those dimensions.
//
// Clients must specify the layouts of input Literals to the
// computation. Layouts specified in interior operations which take Shapes (for
// example, Convert) are ignored.
//
// See the XLA documentation for more information on shapes and layouts.
message Layout {
  // The method used to store the data in memory. The format determines which
  // of the other fields are used by the layout.
  Format format = 4;

  // Sequence of dimension numbers, from minor (fastest varying index) to major
  // (slowest varying index). This field is required.
  repeated int64 minor_to_major = 1;

  // The width to which each dimension is padded. If present, the size of
  // padded_dimensions must equal the rank of the shape. The padding appears at
  // the end of a dimension, not at the beginning. This kind of padding, unlike
  // padding in e.g. convolution, is not part of the shape. This field must be
  // unset unless the format is DENSE.
  repeated int64 padded_dimensions = 2;

  // Describes the values in the padding specified by padded_dimensions. This
  // field must be unset unless the format is DENSE.
  PaddingValue padding_value = 3;

  // The maximum number of elements that can be stored for SPARSE formats.
  // This can be used to determine the maximum size in bytes of arrays stored
  // in memory.  This field must be unset unless the format is SPARSE.
  int64 max_sparse_elements = 5;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal()
  // appropriately to account for the new field.
}
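
// For illustration only (not part of the proto definition): for a rank-2
// dense array, a row-major layout (last dimension varies fastest) would be
// expressed as
//   format: DENSE  minor_to_major: [1, 0]
// while a column-major layout would use minor_to_major: [0, 1].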

// A shape describes the number of dimensions in the array, the size of each
// dimension, and the primitive component type.
//
// Tuples are a special case in that they have rank zero and have tuple_shapes
// defined.
//
// See the XLA documentation for more information on shapes and layouts.
message Shape {
  reserved 1;
  reserved "rank";

  // The element type for this shape.
  PrimitiveType element_type = 2;

  // The size (number of elements) for each dimension.
  // In XLA, dimensions are numbered from 0 to N-1 for an
  // N-dimensional array. The first element of 'dimensions' is the size of
  // dimension 0, the second element is the size of dimension 1, and so forth.
  // An empty list indicates a scalar.
  repeated int64 dimensions = 3;

  // For tuples only, the shapes of the constituent shapes in the tuple
  // sequence.
  repeated Shape tuple_shapes = 4;

  // The layout used to back this shape.
  Layout layout = 5;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal() and
  // ShapeUtil::Compatible() appropriately to account for the new field.
}
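
// For illustration only (not part of the proto definition): a possible
// text-format encoding of the shape f32[2,3] with a row-major layout is
//   element_type: F32
//   dimensions: [2, 3]
//   layout { format: DENSE minor_to_major: [1, 0] }
// while a tuple shape such as (f32[2,3], pred[]) would leave dimensions empty
// and list its components in tuple_shapes.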

// Shape of the parameters and output of a computation (like a traditional
// function signature).
message ProgramShape {
  repeated Shape parameters = 1;
  Shape result = 2;
  repeated string parameter_names = 3;
}
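
// For illustration only (not part of the proto definition): a computation with
// the signature (f32[10], f32[10]) -> f32[10] would be described by a
// ProgramShape with two f32[10] entries in 'parameters' (plus their names in
// 'parameter_names') and an f32[10] 'result' shape.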

// Statistics of a computation.
message ComputationStats {
  // The number of floating point operations in the computation.
  double flop_count = 1;

  // The number of transcendental operations (e.g., exp) in the computation.
  double transcendental_count = 2;
}

// Symbolization metadata for HLO Instructions.
//
// This metadata is used for debugging XLA code generation, as well as
// performance profiling of XLA-generated executables.
message OpMetadata {
  // The framework op name that generated this XLA op.
  //
  // Frameworks that build on top of XLA should mirror the names of their ops
  // back to users by specifying the op_type. In this way, even if the
  // framework's "ops" are implemented as multiple XLA HLO Ops, they can be
  // grouped appropriately. (e.g. if a SoftMax layer is emitted into XLA as
  // multiple ops, then each op should have its op_type set to "SoftMax".)
  string op_type = 1;
  // The user-specified name of the op.
  //
  // This name is often unique within a computation. Note: some frameworks
  // add auto-generated names if the user does not provide one.
  string op_name = 2;
  // Indicates the file and line that this op is associated with in a user's
  // program.
  //
  // e.g. it could be the file and line of user code that generated the op.
  string source_file = 3;
  int32 source_line = 4;
}

// Profile data from the execution of a computation.
message ExecutionProfile {
  // Whether the executable was read from the compilation cache.
  bool compilation_cache_hit = 1;

  // The time in milliseconds spent to compile the computation. This is only
  // set if the executable was not read from the compilation cache
  // (compilation_cache_hit == false).
  int64 compile_time_ms = 2;

  // The number of cycles spent for the computation. This does not include the
  // time taken for the data transfers between the host and the device. This is
  // a target-dependent field and only used for debugging purposes.
  int64 compute_cycle_count = 3;

  // The time in nanoseconds spent for the computation, without data transfer.
  int64 compute_time_ns = 4;

  // The time in nanoseconds spent for the entire computation, including the
  // result data transfer time. The current implementation does not spend any
  // cycles for the input data transfer since the memory is initialized with
  // the proper values before the execution.
  int64 compute_and_transfer_time_ns = 5;
}

// Handle given to a user that represents a computation that the user builds up
// before execution.
message ComputationHandle {
  int64 handle = 1;
}

// Handle given to a user that represents an execution that the user launched
// asynchronously on the device.
message ExecutionHandle {
  int64 handle = 1;
}

// Handle given to a user that represents a globally accessible allocation.
// Contrast this against a ComputationDataHandle, which is not globally
// accessible, since it only exists within a specific computation.
message GlobalDataHandle {
  int64 handle = 1;
}

// Handle given to a user that represents a data result in a computation.
// This is used to pass to subsequent computations that depend upon the data as
// an operand.
message ComputationDataHandle {
  int64 handle = 1;
}

// Handle given to a user that represents a replicated virtual device. Each
// replicated device represents N physical devices for execution, where N is
// the number of replicas.
message DeviceHandle {
  int64 handle = 1;

  // The number of model-parallel virtual devices that communicate via XLA
  // Send/Recv instructions.
  int64 device_count = 2;
}

// Handle given to a user to represent a channel between two computations
// via a Send and Recv instruction pair. Channels are unbuffered, so Send
// instructions will be blocked until the data is transferred.
message ChannelHandle {
  int64 handle = 1;
}

// DeviceAssignmentProto is a serialized form of the DeviceAssignment class,
// which represents the device ids assigned to a set of replicated
// computations. See the xla::DeviceAssignment class comment for more details.
message DeviceAssignmentProto {
  int32 replica_count = 1;
  int32 computation_count = 2;

  // Each logical computation runs on replica_count physical devices.
  // ComputationDevice represents the device ids assigned to the replicas.
  message ComputationDevice {
    repeated int32 replica_device_ids = 1;
  }
  repeated ComputationDevice computation_devices = 3;
}
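
// For illustration only (not part of the proto definition): with
// computation_count = 2 and replica_count = 2, a possible assignment placing
// the first computation on devices 0 and 1 and the second on devices 2 and 3
// would be
//   computation_devices { replica_device_ids: [0, 1] }
//   computation_devices { replica_device_ids: [2, 3] }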

// Literals are used when the server and client need to exchange materialized
// data / results. Literals are also used to describe constants used in
// computations.
//
// Transfers to/from the client are encoded in literal form, and the structure
// of the repeated fields is implied by the shape.
message LiteralProto {
  Shape shape = 1;
  repeated bool preds = 2;
  bytes u8s = 3;
  repeated int32 s32s = 4;
  repeated int64 s64s = 5;
  repeated uint32 u32s = 6;
  repeated uint64 u64s = 7;
  repeated float f32s = 8;
  repeated double f64s = 9;
  repeated float c64s = 12;  // Stored as interleaved real, imag floats.
  repeated LiteralProto tuple_literals = 10;
  // The F16s and BF16s are encoded in little-endian byte order.
  bytes f16s = 11;
  bytes bf16s = 13;
  repeated int64 sparse_indices = 14;
  // Next = 15
}
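
// For illustration only (not part of the proto definition): a 2x2 f32 array
// with a row-major layout could be encoded as
//   shape {
//     element_type: F32
//     dimensions: [2, 2]
//     layout { format: DENSE minor_to_major: [1, 0] }
//   }
//   f32s: [1.0, 2.0, 3.0, 4.0]
// where the f32s values are listed in the order implied by the layout.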

message WindowDimension {
  // The size of the window in this dimension. For a rectangle, this would be
  // the width or height.
  int64 size = 1;

  // The stride at which the window moves across the base area in this
  // dimension. In other words, this is the spacing between different
  // positions of the window in this dimension.
  int64 stride = 2;

  // If positive, the amount of padding with zeroes to add to the base area at
  // the low end of this dimension; if negative, its absolute value is the
  // number of elements removed from the low end of this dimension. For
  // example, in the horizontal dimension of a rectangle, this would be the
  // number of zeroes to pad on the left, given that indices increase when
  // going right.
  int64 padding_low = 3;

  // As padding_low, but on the high end of this dimension. For
  // example, in the horizontal dimension of a rectangle, this would
  // be the number of zeroes to pad on the right, given that indices
  // increase when going right.
  int64 padding_high = 4;

  // Dilation factor of the sliding window in this dimension. A dilation factor
  // of 1 means no dilation. window_dilation - 1 no-op entries ("holes") are
  // implicitly placed between each kernel element. See documentation for
  // convolution.
  int64 window_dilation = 5;

  // Dilation factor of the base area in this dimension. A dilation factor of 1
  // means no dilation. base_dilation - 1 no-op entries ("holes") are implicitly
  // placed between each base area element. See documentation for convolution.
  int64 base_dilation = 6;

  // Window reversal means that this dimension was logically reversed before
  // the operation.
  bool window_reversal = 7;
}

// Describes the windowing in an operation such as convolution.
//
// The window is moved across a base area and for each position of the
// window a computation is performed. The field below describes the
// window and the movement of the window across a base area.
message Window {
  repeated WindowDimension dimensions = 1;
}
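
// For illustration only (not part of the proto definition): with no padding
// and no dilation, a window of size w moved with stride s over a base
// dimension of size d produces floor((d - w) / s) + 1 output positions.
// For example, a hypothetical dimension
//   dimensions { size: 3 stride: 2 }
// over a base dimension of size 10 yields floor((10 - 3) / 2) + 1 = 4
// positions.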

// Describes the dimension numbers for a gather operation.
//
// See https://www.tensorflow.org/performance/xla/operation_semantics#gather
// for more details.
message GatherDimensionNumbers {
  // "Window indices" is a term for a set of indices that index into the
  // interior of a dynamic-slice from the input tensor, the starting indices
  // for which were computed from output_gather_dims (see the operation
  // semantics for how this is defined) and the gather_indices tensor.
  //
  // The window indices for a specific output index Out are computed as:
  //
  //  i = 0
  //  for (k : [0, input_tensor_shape.rank))
  //    window_indices[k] =
  //      if k in elided_window_dims
  //      then 0
  //      else Out[output_window_dims[i++]]
  repeated int64 output_window_dims = 1;
  repeated int64 elided_window_dims = 2;

  // This is interpreted as a map from i to gather_dims_to_operand_dims[i]. It
  // transforms the gather index looked up from the gather_indices tensor into
  // the starting index in the input space.
  repeated int64 gather_dims_to_operand_dims = 3;
}

// Operation requests that are all collected as a tagged union with a oneof
// field in OpRequest.

message ConstantRequest {
  LiteralProto literal = 2;
}

message GetTupleElementRequest {
  ComputationDataHandle operand = 2;
  int64 index = 3;
}

message SliceRequest {
  ComputationDataHandle operand = 2;
  repeated int64 start_indices = 3;
  repeated int64 limit_indices = 4;
  repeated int64 strides = 5;
}

message DynamicSliceRequest {
  // Operand from which to slice at dynamic 'start_indices'.
  ComputationDataHandle operand = 2;
  // Dynamically computed 'start_indices' for slice operation.
  ComputationDataHandle start_indices = 3;
  // Slice sizes for each dimension (note that index calculations are computed
  // modulo dimension sizes to avoid out-of-bound array accesses).
  repeated int64 slice_sizes = 4;
}

message DynamicUpdateSliceRequest {
  // Operand on which slice 'update' is to be applied.
  ComputationDataHandle operand = 2;
  // The slice update to apply to 'operand'.
  ComputationDataHandle update = 3;
  // Dynamically computed start indices for the update slice operation.
  ComputationDataHandle start_indices = 4;
}

message ConvolutionDimensionNumbers {
  // The number of the dimension that represents batch in the input.
  int64 input_batch_dimension = 7;

  // The number of the dimension that represents features in the input.
  int64 input_feature_dimension = 8;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the input.
  repeated int64 input_spatial_dimensions = 11;

  // The number of the dimension that represents input features in the
  // convolutional kernel (rhs).
  int64 kernel_input_feature_dimension = 3;

  // The number of the dimension that represents output features in
  // the convolutional kernel (rhs).
  int64 kernel_output_feature_dimension = 4;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the kernel (rhs). window.strides(0) is the
  // stride in the kernel_spatial_dimensions(0) dimension.
  repeated int64 kernel_spatial_dimensions = 6;

  // The number of the dimension that represents batch in the output.
  int64 output_batch_dimension = 9;

  // The number of the dimension that represents features in the output.
  int64 output_feature_dimension = 10;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the output.
  repeated int64 output_spatial_dimensions = 12;

  // Next = 13
};
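
// For illustration only (not part of the proto definition): for an NHWC input
// and an HWIO kernel, a possible assignment of dimension numbers would be
//   input_batch_dimension: 0     input_feature_dimension: 3
//   input_spatial_dimensions: [1, 2]
//   kernel_input_feature_dimension: 2   kernel_output_feature_dimension: 3
//   kernel_spatial_dimensions: [0, 1]
//   output_batch_dimension: 0    output_feature_dimension: 3
//   output_spatial_dimensions: [1, 2]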

message ConvolveRequest {
  ComputationDataHandle lhs = 2;
  ComputationDataHandle rhs = 3;  // This is the filter/kernel.
  Window window = 4;              // Describes the filter/kernel window.
  ConvolutionDimensionNumbers dimension_numbers = 5;
}

enum FftType {
  FFT = 0;    // Forward FFT; complex in, complex out.
  IFFT = 1;   // Inverse FFT; complex in, complex out.
  RFFT = 2;   // Forward real FFT; real in, fft_length / 2 + 1 complex out.
  IRFFT = 3;  // Inverse real FFT; fft_length / 2 + 1 complex in,
              //                   fft_length real out.
}

message FftRequest {
  FftType fft_type = 1;
  repeated int64 fft_length = 2;  // Multivalent for higher-order FFT.
  ComputationDataHandle operand = 3;
}

message InfeedRequest {
  // The shape of the data returned by reading the device's infeed buffer.
  Shape shape = 2;

  // Additional infeed configuration for the backend.
  bytes config = 3;
}

message OutfeedRequest {
  // The shape of the data returned by reading the device's outfeed buffer.
  Shape shape = 1;

  // Operand to the Outfeed. Supports tuple.
  ComputationDataHandle operand = 2;

  // Backend-specific information for how to perform the outfeed.
  bytes outfeed_config = 3;
}

message CallRequest {
  ComputationHandle to_apply = 2;
  repeated ComputationDataHandle operands = 3;
}

message CustomCallRequest {
  string call_target_name = 2;
  repeated ComputationDataHandle operands = 3;
  Shape shape = 4;
}

message HostComputeRequest {
  // Operand to the HostCompute. Supports tuple.
  repeated ComputationDataHandle operands = 1;

  // Name used to identify HostSend/Recv channels.
  string channel_name = 2;

  // Cost estimate in nanoseconds.
  int64 cost_estimate_ns = 3;

  // The shape of any data returned by the host.
  Shape shape = 4;
}

message DotDimensionNumbers {
  // The dimension numbers that represent the 'lhs' contracting dimensions.
  repeated int64 lhs_contracting_dimensions = 1;
  // The dimension numbers that represent the 'rhs' contracting dimensions.
  repeated int64 rhs_contracting_dimensions = 2;
  // The dimension numbers that represent the 'lhs' batch dimensions.
  repeated int64 lhs_batch_dimensions = 3;
  // The dimension numbers that represent the 'rhs' batch dimensions.
  repeated int64 rhs_batch_dimensions = 4;
};
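
// For illustration only (not part of the proto definition): an ordinary matrix
// multiplication f32[m,k] x f32[k,n] -> f32[m,n] contracts lhs dimension 1
// against rhs dimension 0 and uses no batch dimensions:
//   lhs_contracting_dimensions: [1]
//   rhs_contracting_dimensions: [0]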

message DotRequest {
  ComputationDataHandle lhs = 2;
  ComputationDataHandle rhs = 3;
  DotDimensionNumbers dimension_numbers = 4;
}

message MapRequest {
  repeated ComputationDataHandle operands = 2;
  ComputationHandle to_apply = 3;
  repeated ComputationDataHandle static_operands = 4;
  // The dimensions over which to map.
  // Example mapping a Dot operation along the batch dimension 0:
  //   operand0.shape = [2, 2, 2], operand1.shape = [2, 2, 3]
  //   Map({operand0, operand1}, Dot, {0})
  repeated int64 dimensions = 5;
}

message ReduceRequest {
  // Operand to the reduction.
  ComputationDataHandle operand = 2;

  // Initial value for the reduction. This must be consistent with the result
  // shape of to_apply.
  ComputationDataHandle init_value = 3;

  // The dimensions to reduce over.
  repeated int64 dimensions = 4;

  // The computation to apply in the reduction.
  ComputationHandle to_apply = 5;
}

message ReduceWindowRequest {
  ComputationDataHandle operand = 2;
  ComputationDataHandle init_value = 3;
  Window window = 4;
  ComputationHandle to_apply = 5;
}

message BatchNormTrainingRequest {
  ComputationDataHandle operand = 1;
  ComputationDataHandle scale = 2;
  ComputationDataHandle offset = 3;
  float epsilon = 4;
  int64 feature_index = 5;
}

message BatchNormInferenceRequest {
  ComputationDataHandle operand = 1;
  ComputationDataHandle scale = 2;
  ComputationDataHandle offset = 3;
  ComputationDataHandle mean = 4;
  ComputationDataHandle variance = 5;
  float epsilon = 6;
  int64 feature_index = 7;
}

message BatchNormGradRequest {
  ComputationDataHandle operand = 1;
  ComputationDataHandle scale = 2;
  ComputationDataHandle mean = 3;
  ComputationDataHandle variance = 4;
  ComputationDataHandle grad_output = 5;
  float epsilon = 6;
  int64 feature_index = 7;
}

message CrossReplicaSumRequest {
  ComputationDataHandle operand = 2;
}

message SelectAndScatterRequest {
  // Operand array on which the windows slide.
  ComputationDataHandle operand = 2;

  // Source array for the data to scatter.
  ComputationDataHandle source = 3;

  // Initial scalar value for each element in the output.
  ComputationDataHandle init_value = 4;

  // Window configuration.
  Window window = 5;

  // Binary function used to select an element from each window.
  ComputationHandle select = 6;

  // Binary function used to combine each scattered value from source with the
  // current output value at the selected location.
  ComputationHandle scatter = 7;
}

message ReverseRequest {
  ComputationDataHandle operand = 2;
  repeated int64 dimensions = 3;
}

message BroadcastRequest {
  ComputationDataHandle operand = 2;
  repeated int64 broadcast_sizes = 3;
}

message PadRequest {
  ComputationDataHandle operand = 2;
  ComputationDataHandle padding_value = 3;
  PaddingConfig padding_config = 4;
}

message ReshapeRequest {
  ComputationDataHandle operand = 2;

  // The dimension order for collapse (from fastest-changing to slowest).
  repeated int64 dimensions = 3;

  // The new dimension sizes (from dimension 0 to n-1).
  repeated int64 new_sizes = 4;
}

message TransposeRequest {
  ComputationDataHandle operand = 2;

  // The permutation of the operand's dimensions (in the range 0 to n-1).
  repeated int64 dimensions = 3;
}

message ParameterRequest {
  Shape shape = 2;
  int64 parameter = 3;
  string name = 4;
}

message GetLocalShapeRequest {
  ComputationHandle computation = 1;
  ComputationDataHandle operand = 2;
}

message GetLocalShapeResponse {
  Shape shape = 1;
}

message TraceRequest {
  string tag = 2;
  ComputationDataHandle operand = 3;
}

message ConvertRequest {
  ComputationDataHandle operand = 2;
  PrimitiveType new_element_type = 3;
}

message ConcatenateRequest {
  repeated ComputationDataHandle operands = 2;
  // The dimension in which we concatenate; e.g. if you had dimension arrays of
  // [4, 1] and [5, 1], you'd concatenate in dimension 0 to produce a [9, 1].
  // Attempting to concatenate those in dimension 1 would produce an error, as
  // 4 != 5 (and there is no ragged array support).
  int64 dimension = 3;
}

message ConditionalRequest {
  ComputationDataHandle predicate = 2;
  ComputationDataHandle true_operand = 3;
  ComputationHandle true_computation = 4;
  ComputationDataHandle false_operand = 5;
  ComputationHandle false_computation = 6;
}

message WhileRequest {
  ComputationHandle condition = 2;
  ComputationHandle body = 3;
  ComputationDataHandle init = 4;
}

enum UnaryOperation {
  UNOP_INVALID = 0;

  // Elementwise, logical negation on booleans and bitwise negation on ints.
  UNOP_NOT = 1;

  // Elementwise, computes e^x.
  UNOP_EXP = 2;

  // Elementwise, computes -x.
  UNOP_NEGATE = 3;

  // Puts the elements in the operand into sorted order.
  UNOP_SORT = 4;

  // Elementwise, computes tanh(x).
  UNOP_TANH = 5;

  // Elementwise, computes the natural logarithm of x.
  UNOP_LOG = 6;

  // Elementwise, computes the floor of x.
  UNOP_FLOOR = 7;

  // Elementwise, computes the ceil of x.
  UNOP_CEIL = 8;

  // Elementwise, computes the abs of x.
  UNOP_ABS = 9;

  // Elementwise, computes the sign of x.
  UNOP_SIGN = 10;

  // Elementwise, tests if values are finite (not NaN or inf).
  UNOP_IS_FINITE = 11;

  // Elementwise, computes the cosine of x.
  UNOP_COS = 12;

  // Elementwise, computes the sine of x.
  UNOP_SIN = 13;

  // Elementwise, rounds x to nearest integral value, rounding half-way cases
  // away from zero.
  UNOP_ROUND_NEAREST_AFZ = 14;

  // Elementwise, extracts the real component of complex x.
  UNOP_REAL = 15;

  // Elementwise, extracts the imaginary component of complex x.
  UNOP_IMAG = 16;
}

message UnaryOpRequest {
  UnaryOperation unop = 2;
  ComputationDataHandle operand = 3;
}

enum BinaryOperation {
  BINOP_INVALID = 0;

  // Arithmetic operations.
  BINOP_ADD = 1;
  BINOP_DIV = 2;
  BINOP_MUL = 3;
  BINOP_SUB = 4;

  // Comparison operators.
  BINOP_EQ = 5;
  BINOP_GE = 6;
  BINOP_GT = 7;
  BINOP_LE = 8;
  BINOP_LT = 9;
  BINOP_NE = 10;

  // Element-wise maximum.
  BINOP_MAX = 14;

  // Element-wise minimum.
  BINOP_MIN = 15;

  // Raises the left-hand-side to the right-hand-side power.
  BINOP_POW = 16;

  // Remainder operation.
  BINOP_REM = 17;

  // Element-wise, logical operators on booleans and bitwise operators on ints.
  BINOP_AND = 18;
  BINOP_OR = 19;

  BINOP_SHIFT_LEFT = 20;
  BINOP_SHIFT_RIGHT_ARITHMETIC = 21;
  BINOP_SHIFT_RIGHT_LOGICAL = 22;

  // Complex from real, imag.
  BINOP_COMPLEX = 23;

  // Computes the 4-quadrant arctangent of the y, x input arguments.
  BINOP_ATAN2 = 24;
}

message BinaryOpRequest {
  BinaryOperation binop = 2;
  ComputationDataHandle lhs = 3;
  ComputationDataHandle rhs = 4;
  repeated int64 broadcast_dimensions = 5;
}

enum RandomDistribution {
  RNG_INVALID = 0;

  // Creates a uniform-distribution-generated random number on the semi-open
  // interval [parameter[0], parameter[1]).
  RNG_UNIFORM = 1;

  // Creates a normal-distribution-generated random number with mean
  // parameter[0] and standard deviation parameter[1].
  RNG_NORMAL = 2;

  // Next: 4
}

message RngRequest {
  RandomDistribution distribution = 2;
  repeated ComputationDataHandle parameter = 3;
  Shape shape = 4;
}

enum TernaryOperation {
  TRIOP_INVALID = 0;

  // Given a predicate and two operands, selects operand0 if the predicate is
  // true and operand1 if the predicate is false.
  TRIOP_SELECT = 1;

  // Given a min, a max, and an operand, returns the operand if it is between
  // min and max; otherwise returns min if the operand is less than min, or max
  // if the operand is greater than max.
  TRIOP_CLAMP = 3;
}

message TernaryOpRequest {
  TernaryOperation triop = 2;
  ComputationDataHandle lhs = 3;
  ComputationDataHandle rhs = 4;
  ComputationDataHandle ehs = 5;
}

enum VariadicOperation {
  VAROP_INVALID = 0;

  // Creates a tuple from its operands.
  VAROP_TUPLE = 1;
}

message VariadicOpRequest {
  VariadicOperation varop = 2;
  repeated ComputationDataHandle operands = 3;
}

message ReducePrecisionRequest {
  ComputationDataHandle operand = 1;
  int32 exponent_bits = 2;
  int32 mantissa_bits = 3;
}

message SendRequest {
  ComputationDataHandle operand = 1;
  ChannelHandle channel_handle = 2;
}

message RecvRequest {
  Shape shape = 1;
  ChannelHandle channel_handle = 2;
}

message GatherRequest {
  ComputationDataHandle input = 1;
  ComputationDataHandle gather_indices = 2;
  GatherDimensionNumbers dimension_numbers = 3;
  repeated int64 window_bounds = 4;
}

message OpSharding {
  enum Type {
    // This sharding is replicated across all devices (implies maximal;
    // all other fields are unused).
    REPLICATED = 0;
    // This sharding is maximal - one device runs the entire operation.
    MAXIMAL = 1;
    // This sharding is a tuple - only the tuple_shardings field is valid.
    TUPLE = 2;
    // None of the above; tile_shape and tile_assignment are both used.
    OTHER = 3;
  }
  Type type = 1;
  // The shape of the sharded tile.
  Shape tile_shape = 2;
  // The shape of the tile assignment tensor - this must have the same rank as
  // tile_shape and the product of its dimensions must equal
  // tile_assignment_devices.size().
  repeated int64 tile_assignment_dimensions = 3;
  // Flattened list of device IDs. The order of flattening is the same as used
  // by IndexUtil::MultiToLinearIndex(tile_assignment_shape).
  repeated int64 tile_assignment_devices = 4;
  // If type == TUPLE, the sub-shardings, one per leaf node in the tuple shape,
  // in pre-order. The tuple shape could be nested; here we store just a
  // flattened list of all leaves in the tuple shape. Note that the tuple shape
  // is not stored here; shardings do not store the shapes to which they are
  // applied, this is inferred from the instruction this sharding gets attached
  // to.
  repeated OpSharding tuple_shardings = 5;
}
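
// For illustration only (not part of the proto definition): splitting an
// f32[8,8] array into two f32[4,8] tiles along dimension 0, placed on devices
// 0 and 1, could be expressed as
//   type: OTHER
//   tile_shape { element_type: F32 dimensions: [4, 8] }
//   tile_assignment_dimensions: [2, 1]
//   tile_assignment_devices: [0, 1]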

message OpRequest {
  ComputationHandle computation = 1;
  OpMetadata metadata = 33;
  OpSharding sharding = 40;

  oneof op {
    BinaryOpRequest binary_op_request = 2;
    BroadcastRequest broadcast_request = 3;
    CallRequest call_request = 4;
    ConcatenateRequest concatenate_request = 5;
    ConstantRequest constant_request = 6;
    ConvertRequest convert_request = 7;
    ConvolveRequest convolve_request = 8;
    CrossReplicaSumRequest cross_replica_sum_request = 9;
    CustomCallRequest custom_call_request = 10;
    DotRequest dot_request = 43;
    DynamicSliceRequest dynamic_slice_request = 11;
    DynamicUpdateSliceRequest dynamic_update_slice_request = 12;
    GetTupleElementRequest get_tuple_element_request = 13;
    InfeedRequest infeed_request = 14;
    MapRequest map_request = 15;
    PadRequest pad_request = 16;
    ParameterRequest parameter_request = 17;
    ReducePrecisionRequest reduce_precision_request = 36;
    ReduceRequest reduce_request = 18;
    ReduceWindowRequest reduce_window_request = 19;
    ReshapeRequest reshape_request = 20;
    ReverseRequest reverse_request = 21;
    RngRequest rng_request = 22;
    SelectAndScatterRequest select_and_scatter_request = 23;
    SliceRequest slice_request = 24;
    TernaryOpRequest ternary_op_request = 25;
    TraceRequest trace_request = 26;
    TransposeRequest transpose_request = 34;
    UnaryOpRequest unary_op_request = 27;
    VariadicOpRequest variadic_op_request = 28;
    WhileRequest while_request = 29;
    SendRequest send_request = 30;
    RecvRequest recv_request = 31;
    OutfeedRequest outfeed_request = 32;
    BatchNormTrainingRequest batch_norm_training_request = 35;
    BatchNormGradRequest batch_norm_grad_request = 37;
    BatchNormInferenceRequest batch_norm_inference_request = 38;
    FftRequest fft_request = 41;
    ConvertRequest bitcast_convert_request = 42;
    ConditionalRequest conditional_request = 44;
    HostComputeRequest host_compute_request = 45;
    GatherRequest gather_request = 46;
    // Next: 47
  }
}

message OpResponse {
  ComputationDataHandle output = 1;
}
   1017