Home | History | Annotate | Download | only in xla
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_H_
     17 #define TENSORFLOW_COMPILER_XLA_SHAPE_H_
     18 
     19 #include <string>
     20 #include <vector>
     21 
     22 #include "absl/types/optional.h"
     23 #include "tensorflow/compiler/xla/layout.h"
     24 #include "tensorflow/compiler/xla/primitive_util.h"
     25 #include "tensorflow/compiler/xla/types.h"
     26 #include "tensorflow/compiler/xla/xla_data.pb.h"
     27 #include "tensorflow/core/platform/types.h"
     28 
     29 namespace xla {
     30 
     31 // A shape describes the number of dimensions in a array, the bounds of each
     32 // dimension, and the primitive component type. For tuples, shape describes the
     33 // structure (number of elements and nesting).
     34 class Shape {
     35  public:
     36   Shape() = default;
     37 
     38   // Construct a shape from a ShapeProto.
     39   explicit Shape(const ShapeProto& shape_proto);
     40 
     41   // Returns a ShapeProto representation of the Shape.
     42   ShapeProto ToProto() const;
     43 
     44   // Returns a human-readable string that represents the given shape, with or
     45   // without layout. e.g. "F32[42,12] {0, 1}" or "F32[64]".
     46   string ToString(bool print_layout = false) const;
     47 
     48   // Returns the rank (number of dimensions) of the given shape. Shape must be
     49   // an array.
     50   int64 rank() const {
     51     CHECK(IsArray()) << "Non-arrays do not have a rank, shape: " << ToString();
     52     return dimensions_.size();
     53   }
     54 
     55   // Returns whether the shape is of the specified type (array, tuple, etc).
     56   bool IsArray() const { return primitive_util::IsArrayType(element_type()); }
     57   bool IsTuple() const { return element_type() == TUPLE; }
     58   bool IsToken() const { return element_type() == TOKEN; }
     59   bool IsOpaque() const { return element_type() == OPAQUE; }
     60 
     61   // Returns true if no array dimension in the shape is dynamically sized. Tuple
     62   // shapes are traversed recursively.
     63   bool is_static() const;
     64 
     65   // Returns true if the given dimension is dynamically-sized.
     66   bool is_dynamic_dimension(int dimension) const {
     67     return dynamic_dimensions_.at(dimension);
     68   }
     69 
     70   // Sets whether or not the given dimension is dynamically-sized.
     71   void set_dynamic_dimension(int dimension, bool is_dynamic) {
     72     dynamic_dimensions_[dimension] = is_dynamic;
     73   }
     74 
     75   const std::vector<bool>& dynamic_dimensions() const {
     76     return dynamic_dimensions_;
     77   }
     78 
     79   // Add dimension_upper_bound().
     80 
     81   // Removes the given dimension form the shape. Layout, if it exists, is
     82   // adjusted to match the modified shape.
     83   void DeleteDimension(int64 dim_to_delete);
     84 
     85   // The following methods mirror the protobuf generated code interface for the
     86   // message ShapeProto. This enabled easy migration of this data structure
     87   // from a proto to a proper C++ class.
     88   // TODO(b/29771030): Replace or augment these methods with a more ergonomic
     89   // interface.
     90 
     91   // Methods for accessing the primitive type.
     92   PrimitiveType element_type() const { return element_type_; }
     93   void set_element_type(PrimitiveType value) { element_type_ = value; }
     94 
     95   // Methods for accessing the dimensions array.
     96   int dimensions_size() const { return dimensions_.size(); }
     97   int64 dimensions(int index) const { return dimensions_.at(index); }
     98   void set_dimensions(int index, int64 value) { dimensions_.at(index) = value; }
     99   void add_dimensions(int64 value) {
    100     dimensions_.push_back(value);
    101     dynamic_dimensions_.push_back(false);
    102   }
    103   void clear_dimensions() {
    104     dimensions_.clear();
    105     dynamic_dimensions_.clear();
    106   }
    107   const std::vector<int64>& dimensions() const { return dimensions_; }
    108   absl::Span<int64> mutable_dimensions() { return absl::MakeSpan(dimensions_); }
    109 
    110   // Methods for accessing the tuple subshapes. This field only non-empty for
    111   // tuple shapes.
    112   int tuple_shapes_size() const { return tuple_shapes_.size(); }
    113   const Shape& tuple_shapes(int index) const { return tuple_shapes_.at(index); }
    114   Shape* mutable_tuple_shapes(int index) { return &tuple_shapes_.at(index); }
    115   Shape* add_tuple_shapes() {
    116     tuple_shapes_.push_back(Shape());
    117     return &tuple_shapes_.back();
    118   }
    119   void clear_tuple_shapes() { tuple_shapes_.clear(); }
    120   const std::vector<Shape>& tuple_shapes() const { return tuple_shapes_; }
    121   std::vector<Shape>* mutable_tuple_shapes() { return &tuple_shapes_; }
    122 
    123   // Methods for accessing the layout field.
    124   bool has_layout() const { return layout_.format() != INVALID_FORMAT; }
    125   const Layout& layout() const { return layout_; }
    126   Layout* mutable_layout() { return &layout_; }
    127   void clear_layout() { layout_.Clear(); }
    128 
    129   void Swap(Shape* other) {
    130     using std::swap;
    131     swap(*this, *other);
    132   }
    133 
    134   void Clear() {
    135     element_type_ = PRIMITIVE_TYPE_INVALID;
    136     dimensions_.clear();
    137     tuple_shapes_.clear();
    138     clear_layout();
    139   }
    140 
    141   string SerializeAsString() const { return ToProto().SerializeAsString(); }
    142   string ShortDebugString() const { return ToProto().ShortDebugString(); }
    143   string DebugString() const { return ToProto().DebugString(); }
    144 
    145   // Equal is a configurable functor to check the equality of two shapes.
    146   //
    147   // Examples:
    148   //
    149   // - Comparing two shapes ignoring their layout difference:
    150   //   Equal().IgnoreLayout()(shape1, shape2);
    151   //
    152   // - Comparing two shapes ignoring their layout and element type difference:
    153   //   Equal().IgnoreLayout().IgnoreElementType()(shape1, shape2);
    154   class Equal {
    155    public:
    156     Equal() = default;
    157 
    158     bool operator()(const Shape& lhs, const Shape& rhs);
    159 
    160     Equal& IgnoreLayout() {
    161       ignore_layout_ = true;
    162       return *this;
    163     }
    164     Equal& IgnoreTilesInLayout() {
    165       ignore_tiles_in_layout_ = true;
    166       return *this;
    167     }
    168     Equal& IgnoreElementSizeInLayout() {
    169       ignore_element_size_in_layout_ = true;
    170       return *this;
    171     }
    172     Equal& IgnoreElementType() {
    173       ignore_element_type_ = true;
    174       return *this;
    175     }
    176     Equal& IgnoreFpPrecision() {
    177       ignore_fp_precision_ = true;
    178       return *this;
    179     }
    180     Equal& IgnoreDynamicDimension() {
    181       ignore_dynamic_dimension_ = true;
    182       return *this;
    183     }
    184 
    185    private:
    186     bool ignore_layout_ = false;
    187     bool ignore_tiles_in_layout_ = false;
    188     bool ignore_element_size_in_layout_ = false;
    189     bool ignore_element_type_ = false;
    190     bool ignore_fp_precision_ = false;
    191     bool ignore_dynamic_dimension_ = false;
    192   };
    193 
    194   // Test that all fields of the shape are the same, equivalent to Equal().
    195   bool operator==(const Shape& other) const { return Equal()(*this, other); }
    196   bool operator!=(const Shape& other) const { return !(*this == other); }
    197 
    198  private:
    199   // The element type of this shape (tuple, array, etc).
    200   PrimitiveType element_type_ = PRIMITIVE_TYPE_INVALID;
    201 
    202   // The array bounds of the dimensions. This is nonempty only for array
    203   // shapes. For a dynamically-sized dimension, the respective value in this
    204   // vector is an inclusive upper limit of the array bound.
    205   std::vector<int64> dimensions_;
    206 
    207   // This vector is the same size as 'dimensions_' and indicates whether the
    208   // respective dimension is dynamically sized.
    209   std::vector<bool> dynamic_dimensions_;
    210 
    211   // The tuple element subshapes. This is nonempty only for tuple shapes.
    212   std::vector<Shape> tuple_shapes_;
    213 
    214   // The layout of the shape. Only relevant for arrays.
    215   Layout layout_;
    216 };
    217 
    218 // Shape of the parameters and output of an XLA computation. This is analogous
    219 // to a traditional function signature.
    220 class ProgramShape {
    221  public:
    222   ProgramShape() = default;
    223 
    224   // Creates a ProgramShape from a ProgramShapeProto protobuf.
    225   explicit ProgramShape(const ProgramShapeProto& program_shape_proto);
    226 
    227   // Returns a proto representation of the object.
    228   ProgramShapeProto ToProto() const;
    229 
    230   string ToString() const;
    231 
    232   // The following methods mirror the protobuf generated code interface for the
    233   // message ProgramShapeProto. This enabled easy migration of this data
    234   // structure from a proto to a proper C++ class.
    235   // TODO(b/29771030): Replace or augment these methods with a more ergonomic
    236   // interface.
    237 
    238   // Methods for accessing and manipulating the Shape of the parameters.
    239   int parameters_size() const { return parameters_.size(); }
    240   const Shape& parameters(int index) const { return parameters_.at(index); }
    241   Shape* mutable_parameters(int index) { return &parameters_.at(index); }
    242   Shape* add_parameters() {
    243     parameters_.emplace_back();
    244     return &parameters_.back();
    245   }
    246   void clear_parameters() { parameters_.clear(); }
    247   const std::vector<Shape>& parameters() const { return parameters_; }
    248   std::vector<Shape>* mutable_parameters() { return &parameters_; }
    249 
    250   // Methods for accessing and manipulating the Shape of the result.
    251   const Shape& result() const { return result_; }
    252   Shape* mutable_result() { return &result_; }
    253 
    254   // Methods for accessing and manipulating the names of the parameters.
    255   int parameter_names_size() const { return parameter_names_.size(); }
    256   const string& parameter_names(int index) const {
    257     return parameter_names_.at(index);
    258   }
    259   void set_parameter_names(int index, const string& value) {
    260     parameter_names_.at(index) = value;
    261   }
    262   string* mutable_parameter_names(int index) {
    263     return &parameter_names_.at(index);
    264   }
    265   void add_parameter_names(const string& value) {
    266     parameter_names_.push_back(value);
    267   }
    268   string* add_parameter_names() {
    269     parameter_names_.push_back("");
    270     return &parameter_names_.back();
    271   }
    272   void clear_parameter_names() { parameter_names_.clear(); }
    273   const std::vector<string>& parameter_names() const {
    274     return parameter_names_;
    275   }
    276   std::vector<string>* mutable_parameter_names() { return &parameter_names_; }
    277 
    278   string ShortDebugString() const { return ToProto().ShortDebugString(); }
    279   string DebugString() const { return ToProto().DebugString(); }
    280 
    281  private:
    282   // The shapes of the parameters of the computation represented by this object.
    283   std::vector<Shape> parameters_;
    284 
    285   // The names of the parameters of the computation represented by this object.
    286   std::vector<string> parameter_names_;
    287 
    288   // The shape of the result of the computation represented by this object.
    289   Shape result_;
    290 };
    291 
        // Stream-insertion overloads for Shape and ProgramShape. They are only
        // declared in this header; the definitions live elsewhere.
        // NOTE(review): presumably these write the objects' ToString() form to
        // `out` and return `out` -- confirm against the implementation file.
    292 std::ostream& operator<<(std::ostream& out, const Shape& shape);
    293 std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape);
    294 
    295 }  // namespace xla
    296 
    297 #endif  // TENSORFLOW_COMPILER_XLA_SHAPE_H_
    298