Home | History | Annotate | Download | only in service
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PARAMETER_BINDING_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PARAMETER_BINDING_H_
     18 
     #include <functional>
     #include <iosfwd>
     #include <utility>

     #include "absl/container/flat_hash_map.h"
     #include "absl/types/optional.h"
     #include "tensorflow/compiler/xla/service/hlo.pb.h"
     #include "tensorflow/compiler/xla/shape_tree.h"
     #include "tensorflow/compiler/xla/shape_util.h"
     26 
     27 namespace xla {
     28 
     29 class HloModule;
     30 // We currently use an explicit API that takes an extra parameter to indicate
     31 // the runtime size of a dynamic dimension. DynamicParameterBinding indicates
     32 // the relationship between parameter: We can have a dynamic parameter that
     33 // points to another target parameter to indicate that the target parameter is
     34 // dynamic.
     35 //
     36 //
     37 // TODO(b/119520625): Remove this API once we have more dynamic shape infra
     38 // ready.
     39 class DynamicParameterBinding {
     40  public:
     41   // DynamicParameter represents a special parameter that is used to represent
     42   // the runtime size of a dimension of another parameter. A dynamic parameter
     43   // has to be a scalar value.
     44   struct DynamicParameter {
     45     // The parameter number of dynamic parameter.
     46     int64 parameter_num;
     47     // The index of the parameter.
     48     ShapeIndex parameter_index;
     49   };
     50 
     51   // DynamicDimension represents a dimension whose size is determined at
     52   // runtime. A DynamicDimension's runtime size is determined by the binded
     53   // DynamicParameter using `DynamicParameterBinding::Bind` method.
     54   struct DynamicDimension {
     55     // The parameter number of dynamic dimension.
     56     int64 parameter_num;
     57     // The subshape index of the parameter.
     58     ShapeIndex parameter_index;
     59     // The dimension number in the subshape.
     60     int64 dimension;
     61 
     62     // "friend" keyword are added so these functions can be found by ADL.
     63     template <typename H>
     64     friend H AbslHashValue(H h, const DynamicDimension& m) {
     65       return H::combine(std::move(h), m.parameter_num, m.parameter_index,
     66                         m.dimension);
     67     }
     68 
     69     friend bool operator==(const DynamicDimension& lhs,
     70                            const DynamicDimension& rhs) {
     71       return lhs.parameter_num == rhs.parameter_num &&
     72              lhs.parameter_index == rhs.parameter_index &&
     73              lhs.dimension == rhs.dimension;
     74     }
     75   };
     76 
     77   DynamicParameterBinding() = default;
     78 
     79   virtual ~DynamicParameterBinding() = default;
     80 
     81   // Adds binding which indicates that the dimension indicated by
     82   // `dynamic_dimension` is dynamic, and its runtime size is represented by
     83   // `dynamic_parameter`.
     84   Status Bind(const DynamicParameter& dynamic_parameter,
     85               const DynamicDimension& dynamic_dimension);
     86 
     87   // Returns the parameter and the index representing the runtime size of
     88   // dimension `dim_num` of parameter `param_num` at `param_index`.
     89   //
     90   // Returns nullopt if the binding is not set.
     91   absl::optional<DynamicParameter> GetBinding(
     92       const DynamicDimension& dynamic_dimension) const;
     93 
     94   using BindingFn =
     95       std::function<Status(const DynamicParameter& dynamic_parameter,
     96                            const DynamicDimension& dynamic_dimension)>;
     97 
     98   // Iterate through each binding.
     99   Status ForEachBinding(BindingFn fn) const;
    100 
    101   DynamicParameterBindingProto ToProto() const;
    102 
    103   static StatusOr<DynamicParameterBinding> CreateFromProto(
    104       const DynamicParameterBindingProto& proto);
    105 
    106   string ToString() const;
    107 
    108   // Verifies that the given binding is valid for the given module.
    109   // Specifically, the binding's parameter and parameter size should be valid.
    110   Status Verify(const HloModule& module) const;
    111 
    112  private:
    113   // Keeps track of mappings from DynamicDimension to DynamicParameter. The
    114   // direction of is chosen so that we can easily query if a dimension is
    115   // dynamic and which dynamic parameter represents the real size of that
    116   // dimension.
    117   absl::flat_hash_map<DynamicDimension, DynamicParameter> bindings_;
    118 };
    119 
    120 std::ostream& operator<<(std::ostream& out,
    121                          const DynamicParameterBinding& binding);
    122 
    123 }  // namespace xla
    124 
    125 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PARAMETER_BINDING_H_
    126