Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_
     18 
     19 #include <utility>
     20 
     21 #include "tensorflow/compiler/xla/service/hlo_module.h"
     22 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
     23 
     24 namespace xla {
     25 
     26 class AlgebraicSimplifierOptions {
     27  public:
     28   AlgebraicSimplifierOptions() {}
     29   // Platform dependent callback to determine if a reshape `from_shape` to
     30   // `to_shape` is a bitcast.
     31   using ReshapeIsBitcastCallback =
     32       std::function<bool(const Shape& from_shape, const Shape& to_shape)>;
     33   explicit AlgebraicSimplifierOptions(
     34       ReshapeIsBitcastCallback reshape_is_bitcast_callback)
     35       : reshape_is_bitcast_callback_(std::move(reshape_is_bitcast_callback)) {}
     36 
     37   // Use the platform specific callback if set. It is not sensible to return
     38   // true here if the options are not layout sensitive.
     39   bool ReshapeIsBitcast(const Shape& from_shape, const Shape& to_shape) const {
     40     if (!is_layout_sensitive_) {
     41       return false;
     42     }
     43     if (!reshape_is_bitcast_callback_) {
     44       return ShapeUtil::ReshapeIsBitcast(from_shape, to_shape);
     45     }
     46     return reshape_is_bitcast_callback_(from_shape, to_shape);
     47   }
     48 
     49   // If is_layout_sensitive is true, then the simplifier preserves layout during
     50   // transformation. Otherwise, layout is ignored.
     51   void set_is_layout_sensitive(bool is_layout_sensitive) {
     52     is_layout_sensitive_ = is_layout_sensitive;
     53   }
     54 
     55   bool is_layout_sensitive() const { return is_layout_sensitive_; }
     56 
     57   // Enable dot simplification on platforms where it is profitable.
     58   void set_enable_dot_strength_reduction(bool enable_dot_strength_reduction) {
     59     enable_dot_strength_reduction_ = enable_dot_strength_reduction;
     60   }
     61 
     62   bool enable_dot_strength_reduction() const {
     63     return enable_dot_strength_reduction_;
     64   }
     65 
     66   // Enable convolution simplification on platforms where it is profitable.
     67   void set_enable_conv_simplification(bool enable_conv_simplification) {
     68     enable_conv_simplification_ = enable_conv_simplification;
     69   }
     70   bool enable_conv_simplification() const {
     71     return enable_conv_simplification_;
     72   }
     73 
     74   // If enable_window_reduce_replacement is true, the kReduceWindow instruction
     75   // can be optimized by replacement with simpler operations.
     76   void set_enable_window_reduce_to_reduce_replacement(
     77       bool enable_window_reduce_to_reduce_replacement) {
     78     enable_window_reduce_to_reduce_replacement_ =
     79         enable_window_reduce_to_reduce_replacement;
     80   }
     81 
     82   bool enable_window_reduce_to_reduce_replacement() const {
     83     return enable_window_reduce_to_reduce_replacement_;
     84   }
     85 
     86  private:
     87   ReshapeIsBitcastCallback reshape_is_bitcast_callback_;
     88   bool is_layout_sensitive_{false};
     89   bool enable_dot_strength_reduction_{true};
     90   bool enable_conv_simplification_{true};
     91   bool enable_window_reduce_to_reduce_replacement_{true};
     92 };
     93 
     94 // A pass which performs algebraic simplifications.
     95 class AlgebraicSimplifier : public HloModulePass {
     96  public:
     97   // If is_layout_sensitive is true, then the simplifier preserves layout during
     98   // transformation. Otherwise, layout is ignored.
     99   explicit AlgebraicSimplifier(const AlgebraicSimplifierOptions& options)
    100       : options_(options) {}
    101   ~AlgebraicSimplifier() override = default;
    102   absl::string_view name() const override { return "algsimp"; }
    103 
    104   // Run algebraic simplification on the given computation. Returns whether the
    105   // computation was changed.
    106   StatusOr<bool> Run(HloModule* module) override;
    107 
    108  private:
    109   AlgebraicSimplifierOptions options_;
    110 };
    111 
    112 }  // namespace xla
    113 
    114 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_
    115