Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // Shape inference is used by the XLA service as the user builds up
     17 // computation requests.
     18 
     19 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_
     20 #define TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_
     21 
     22 #include <vector>
     23 
     24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     26 #include "tensorflow/compiler/xla/statusor.h"
     27 #include "tensorflow/compiler/xla/types.h"
     28 #include "tensorflow/compiler/xla/xla_data.pb.h"
     29 #include "tensorflow/core/lib/gtl/array_slice.h"
     30 #include "tensorflow/core/platform/macros.h"
     31 #include "tensorflow/core/platform/types.h"
     32 
     33 namespace xla {
     34 
     35 // For a given operation and input shapes, infers what the resulting shape is
     36 // for the operation. With this functionality, the user does not need to specify
     37 // the expected result type for computations that are built up via the API --
     38 // the shape that results from an operation is inferred. Some methods have
     39 // overloads for inferring shape at the HLO level.
     40 //
     41 // TODO(b/73352135): Shape inference does not issue very good error messages, in
     42 // part because HloInstruction::ToString() is not available since shape
     43 // inference runs before the HloInstruction object is created. We need a
     44 // solution for this.
     45 class ShapeInference {
     46  public:
     47   // Infers the shape produced by applying the given unary operation to the
     48   // given input shape. The HloOpcode overload takes the operand instruction.
     49   static StatusOr<Shape> InferUnaryOpShape(UnaryOperation operation,
     50                                            const Shape& arg);
     51   static StatusOr<Shape> InferUnaryOpShape(HloOpcode opcode,
     52                                            const HloInstruction* operand);
     53 
     54   // Infers the shape produced by applying the given binary operation to the
     55   // given input shapes. The HloOpcode overload takes the operand instructions.
     56   static StatusOr<Shape> InferBinaryOpShape(
     57       BinaryOperation operation, const Shape& lhs, const Shape& rhs,
     58       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
     59   static StatusOr<Shape> InferBinaryOpShape(HloOpcode opcode,
     60                                             const HloInstruction* lhs,
     61                                             const HloInstruction* rhs);
     62 
     63   // Infers the shape produced by applying the given ternary operation to the
     64   // given input shapes. The HloOpcode overload takes the operand instructions.
     65   static StatusOr<Shape> InferTernaryOpShape(TernaryOperation operation,
     66                                              const Shape& lhs, const Shape& rhs,
     67                                              const Shape& ehs);
     68   static StatusOr<Shape> InferTernaryOpShape(HloOpcode opcode,
     69                                              const HloInstruction* lhs,
     70                                              const HloInstruction* rhs,
     71                                              const HloInstruction* ehs);
     72 
     73   // Infers the shape produced by applying the given variadic operation to the
     74   // given operand shapes. The HloOpcode overload takes operand instructions.
     75   static StatusOr<Shape> InferVariadicOpShape(
     76       VariadicOperation operation,
     77       tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
     78   static StatusOr<Shape> InferVariadicOpShape(
     79       HloOpcode opcode,
     80       tensorflow::gtl::ArraySlice<const HloInstruction*> operands);
     81 
     82   // Infers the shape produced by applying the given mapping computation shape
     83   // to the given operand shapes.
     84   static StatusOr<Shape> InferMapShape(
     85       tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
     86       const ProgramShape& to_apply,
     87       tensorflow::gtl::ArraySlice<int64> dimensions);
     88 
     89   // Infers the shape produced by a BatchNormTraining operation with the given
     90   // operand shapes.
     91   static StatusOr<Shape> InferBatchNormTrainingShape(const Shape& operand_shape,
     92                                                      const Shape& scale_shape,
     93                                                      const Shape& offset_shape,
     94                                                      int64 feature_index);
     95 
     96   // Infers the shape produced by a BatchNormInference operation with the given
     97   // operand shapes.
     98   static StatusOr<Shape> InferBatchNormInferenceShape(
     99       const Shape& operand_shape, const Shape& scale_shape,
    100       const Shape& offset_shape, const Shape& mean_shape,
    101       const Shape& variance_shape, int64 feature_index);
    102 
    103   // Infers the shape produced by a BatchNormGrad operation with the given shapes.
    104   static StatusOr<Shape> InferBatchNormGradShape(const Shape& operand_shape,
    105                                                  const Shape& scale_shape,
    106                                                  const Shape& mean_shape,
    107                                                  const Shape& var_shape,
    108                                                  const Shape& output_grad_shape,
    109                                                  int64 feature_index);
    110 
    111   // Infers the shape produced by applying the given convolutional
    112   // filter (rhs) to lhs in the way specified by the fields on window.
    113   static StatusOr<Shape> InferConvolveShape(
    114       const Shape& lhs, const Shape& rhs, const Window& window,
    115       const ConvolutionDimensionNumbers& dimension_numbers);
    116 
    117   // Infers the shape produced by the given FFT type on the given operand.
    118   static StatusOr<Shape> InferFftShape(
    119       const Shape& in, FftType fft_type,
    120       tensorflow::gtl::ArraySlice<int64> fft_length);
    121 
    122   // Infers the shape produced by a cross replica sum with the given operand
    123   // shapes.
    124   static StatusOr<Shape> InferCrossReplicaSumShape(
    125       tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
    126 
    127   // Infers the shape produced by applying the given reduction computation
    128   // shape to the given input operand shape.
    129   //
    130   // NOTE(review): earlier text here described a pass_index flag, but this
    131   // declaration has no such parameter; the contract to_apply must satisfy is
    132   // not visible in this header -- confirm against the implementation.
    133   static StatusOr<Shape> InferReduceShape(
    134       const Shape& arg, const Shape& init_value,
    135       tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
    136       const ProgramShape& to_apply);
    137 
    138   // Infers the shape produced by applying the given computation to the operand
    139   // shape with the given window and stride dimensions.
    140   static StatusOr<Shape> InferReduceWindowShape(
    141       const Shape& operand_shape, const Shape& init_value, const Window& window,
    142       const ProgramShape& to_apply_shape);
    143 
    144   // Infers the shape produced by scattering the given source shape to the
    145   // selected indices of each window on the operand shape.
    146   static StatusOr<Shape> InferSelectAndScatterShape(
    147       const Shape& operand_shape, const ProgramShape& select_shape,
    148       const Window& window, const Shape& source_shape,
    149       const Shape& init_value_shape, const ProgramShape& scatter_shape);
    150 
    151   // Infers the shape produced by a reverse operation that reverses the order
    152   // of the elements in the given dimensions.
    153   static StatusOr<Shape> InferReverseShape(
    154       const Shape& operand_shape,
    155       tensorflow::gtl::ArraySlice<int64> dimensions);
    156 
    157   // Infers the shape produced by a slice operation spanning from the starts to
    158   // the limits in the original shape's dimensions.
    159   //
    160   // e.g. slice f32[32x32] 0:16 0:16 -> f32[16x16]
    161   static StatusOr<Shape> InferSliceShape(
    162       const Shape& arg, tensorflow::gtl::ArraySlice<int64> starts,
    163       tensorflow::gtl::ArraySlice<int64> limits,
    164       tensorflow::gtl::ArraySlice<int64> strides);
    165 
    166   // Infers the shape produced by a dynamic slice operation of size specified
    167   // in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'.
    168   static StatusOr<Shape> InferDynamicSliceShape(
    169       const Shape& operand_shape, const Shape& start_indices_shape,
    170       tensorflow::gtl::ArraySlice<int64> slice_sizes);
    171 
    172   // Infers the shape produced by a dynamic update slice operation based
    173   // on the shape of operand and update.
    174   static StatusOr<Shape> InferDynamicUpdateSliceShape(
    175       const Shape& operand_shape, const Shape& update_shape,
    176       const Shape& start_indices_shape);
    177 
    178   // Infers the shape produced by doing a compile-time-constant indexing into
    179   // the given input shape. This is essential for operations on tuples, because
    180   // it is impossible to infer the type that comes out of the tuple indexing if
    181   // it is not a compile time constant.
    182   static StatusOr<Shape> InferGetTupleElementShape(const Shape& arg,
    183                                                    int64 index);
    184 
    185   // Infers the shape produced from a while node. condition and body are the
    186   // shapes of computations for the condition and the body of a while node, and
    187   // init is the shape of data initially passed in to the body as an argument.
    188   // The shapes must match; condition: T -> PRED, body: T -> T, init: T
    189   static StatusOr<Shape> InferWhileShape(const ProgramShape& condition,
    190                                          const ProgramShape& body,
    191                                          const Shape& init);
    192 
    193   // Infers the shape produced by a conditional operation.
    194   static StatusOr<Shape> InferConditionalShape(
    195       const Shape& predicate, const Shape& true_operand,
    196       const Shape& false_operand, const ProgramShape& true_computation,
    197       const ProgramShape& false_computation);
    198 
    199   // Infers the shape produced by a broadcast operation.
    200   static StatusOr<Shape> InferBroadcastShape(
    201       const Shape& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
    202 
    203   // Infers the shape produced by a reshape operation from the element type of
    204   // its operand and the new dimension sizes specified.
    205   static StatusOr<Shape> InferReshapeShape(
    206       const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions,
    207       tensorflow::gtl::ArraySlice<int64> new_sizes);
    208 
    209   // Infers the shape produced by a transpose operation from the element type of
    210   // its operand and its dimensions field.
    211   static StatusOr<Shape> InferTransposeShape(
    212       const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions);
    213 
    214   // Helper that infers the shape produced by performing a concatenate operation
    215   // with the given operand shapes.
    216   static StatusOr<Shape> InferConcatOpShape(
    217       tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, int64 dimension);
    218 
    219   // Helper that validates the given operand shape can be converted to
    220   // new_element_type via a convert instruction -- the requirement is that
    221   // the shape is identical except for the element type.
    222   static StatusOr<Shape> InferConvertShape(const Shape& operand_shape,
    223                                            PrimitiveType new_element_type);
    224 
    225   // Helper that validates the given operand shape can be bitcast converted to
    226   // new_element_type via a bitcast convert instruction -- the
    227   // requirement is that the shape is identical except for the element type and
    228   // the element types have identical bit-widths.
    229   static StatusOr<Shape> InferBitcastConvertShape(
    230       const Shape& operand_shape, PrimitiveType new_element_type);
    231 
    232   // Helper that validates the input data type for a reduce-precision operation,
    233   // and returns the result shape.
    234   static StatusOr<Shape> InferReducePrecisionShape(const Shape& operand_shape,
    235                                                    const int exponent_bits,
    236                                                    const int mantissa_bits);
    237 
    238   // Helper that infers the shape produced by a pad operation based on the
    239   // padding configuration.
    240   static StatusOr<Shape> InferPadShape(const Shape& operand_shape,
    241                                        const Shape& padding_value_shape,
    242                                        const PaddingConfig& padding_config);
    243 
    244   // Helper that validates the given arg_shapes are compatible with the shape of
    245   // the to_apply parameters, and returns the to_apply result shape.
    246   static StatusOr<Shape> InferCallShape(
    247       tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
    248       const ProgramShape& to_apply);
    249 
    250   // Helper that infers the shape produced by performing a dot operation with
    251   // the given LHS and RHS shapes.
    252   static StatusOr<Shape> InferDotOpShape(
    253       const Shape& lhs, const Shape& rhs,
    254       const DotDimensionNumbers& dimension_numbers);
    255 
    256   // Helper that infers the shape of the tensor produced by a gather operation
    257   // with the given input shape, gather indices shape and gather dimension
    258   // numbers.
    259   static StatusOr<Shape> InferGatherShape(
    260       const Shape& input_shape, const Shape& gather_indices_shape,
    261       const GatherDimensionNumbers& gather_dim_numbers,
    262       tensorflow::gtl::ArraySlice<int64> window_bounds);
    263 
    264  private:
    265   // Helper that infers the shape produced by performing an element-wise binary
    266   // operation with the given LHS and RHS shapes.
    267   // Note: By "element-wise" we mean operations that look at a single element in
    268   // the LHS and a single element in the RHS to produce a single output element,
    269   // even in the presence of broadcasting of one of the operands over the other.
    270   static StatusOr<Shape> InferElementwiseBinaryOpShape(
    271       BinaryOperation operation, const Shape& lhs, const Shape& rhs,
    272       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
    273 
    274   // Helper for inferring the shape of Clamp ops.
    275   static StatusOr<Shape> InferClampShape(const Shape& min, const Shape& operand,
    276                                          const Shape& max);
    277 
    278   // Helper for inferring the shape of Select ops.
    279   static StatusOr<Shape> InferSelectShape(const Shape& pred,
    280                                           const Shape& on_true,
    281                                           const Shape& on_false);
    282 
    283   // Helper for inferring shapes of binary operations which use degenerate
    284   // dimension broadcasting (a dimension of size 1 in one operand is broadcast
    285   // up to match the size of the dimension in the other operand).
    286   static StatusOr<Shape> InferDegenerateDimensionBroadcastShape(
    287       BinaryOperation operation, const Shape& lhs, const Shape& rhs);
    288 
    289   // Helper for inferring shapes of binary operations using "InDim"
    290   // broadcasting. This is the broadcasting used in the *InDim binary operations
    291   // (for example ComputationBuilder::AddInDim). smaller_shape must be a
    292   // lower-rank shape than larger_shape. Returns the shape that the
    293   // smaller_shape is broadcast to.
    294   static StatusOr<Shape> InferInDimBroadcastShape(
    295       BinaryOperation operation, const Shape& smaller_shape,
    296       const Shape& larger_shape,
    297       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
    298 
    299   TF_DISALLOW_COPY_AND_ASSIGN(ShapeInference);  // Every member is static.
    300 };
    301 
    302 }  // namespace xla
    303 
    304 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_
    305