/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_

#include "tensorflow/compiler/xla/service/bfloat16_support.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"

namespace xla {

// A pass which folds F32 <-> BF16 conversions into their operands or users,
// when the changed precision is supported by the backend.
//
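// For illustration, consider this HLO snippet (hypothetical; assumes the
// backend supports BF16 output for add):
//
//   %sum = f32[4] add(%a, %b)
//   %out = bf16[4] convert(%sum)
//
// The convert can be folded into its operand, so that the add produces
// bf16[4] directly and the convert becomes dead.
//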
// This pass follows the passed-in backend-specific BF16 support rules, but it
// can introduce mixed precision in individual HLOs, which breaks the
// assumptions of some other HLO passes. It should therefore run at the end of
// the HLO optimization pipeline, followed by a DCE pass. If other passes must
// run after this pass, run BFloat16MixedPrecisionRemoval first to undo some of
// the changes made by this pass.
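//
// A minimal usage sketch (hypothetical pipeline setup; assumes a
// backend-specific BFloat16Support instance and the standard HloPassPipeline
// and HloDCE passes):
//
//   HloPassPipeline pipeline("bf16-folding");
//   pipeline.AddPass<BFloat16ConversionFolding>(&bf16_support);
//   pipeline.AddPass<HloDCE>();  // removes conversions made dead by folding
//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));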
class BFloat16ConversionFolding : public HloPassInterface {
 public:
  explicit BFloat16ConversionFolding(const BFloat16Support* bfloat16_support)
      : bfloat16_support_(bfloat16_support) {}

  ~BFloat16ConversionFolding() override = default;
  tensorflow::StringPiece name() const override { return "bfloat16-fold"; }

  // Runs BF16 conversion folding on the given module. Returns whether the
  // module was changed.
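  //
  // Example invocation (hypothetical; `bf16_support` and `module` are
  // caller-provided):
  //
  //   BFloat16ConversionFolding pass(&bf16_support);
  //   TF_ASSIGN_OR_RETURN(bool changed, pass.Run(module));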
  StatusOr<bool> Run(HloModule* module) override;

 private:
  const BFloat16Support* bfloat16_support_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_