Home | History | Annotate | Download | only in service
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/implicit_broadcast_remover.h"
     17 
     18 #include <algorithm>
     19 #include <memory>
     20 #include <numeric>
     21 #include <string>
     22 #include <utility>
     23 #include <vector>
     24 
     25 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
     26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     27 #include "tensorflow/compiler/xla/service/hlo_dce.h"
     28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     30 #include "tensorflow/compiler/xla/shape_util.h"
     31 #include "tensorflow/compiler/xla/status_macros.h"
     32 #include "tensorflow/compiler/xla/types.h"
     33 #include "tensorflow/compiler/xla/util.h"
     34 #include "tensorflow/core/lib/core/errors.h"
     35 #include "tensorflow/core/lib/core/status.h"
     36 #include "tensorflow/core/platform/logging.h"
     37 #include "tensorflow/core/platform/types.h"
     38 
     39 namespace xla {
     40 
     41 namespace {
     42 
     43 // Visitor for removing implicit broadcasts.
     44 class ImplicitBroadcastVisitor : public DfsHloVisitorWithDefault {
     45  public:
     46   Status DefaultAction(HloInstruction* hlo_instruction) override {
     47     return Status::OK();
     48   }
     49 
     50   Status HandleElementwiseBinary(HloInstruction* hlo) override {
     51     return ReplaceImplicitBroadcastOperands(hlo);
     52   }
     53 
     54   Status HandleClamp(HloInstruction* hlo) override {
     55     // Clamp is the only element-wise ternary operation.
     56     return ReplaceImplicitBroadcastOperands(hlo);
     57   }
     58 
     59   // Returns whether any modification has been made to any visited instruction.
     60   bool changed() const { return changed_; }
     61 
     62  private:
     63   // Iterates through the operands of 'hlo' and replace any operands which are
     64   // implicitly broadcast with the equivalent sequence of broadcast and reshape
     65   // instructions. An operand is considered to be implicitly broadcast if the
     66   // operand shape does have the same dimensions as the shape of 'hlo'.
     67   Status ReplaceImplicitBroadcastOperands(HloInstruction* hlo) {
     68     auto fadd = [hlo](std::unique_ptr<HloInstruction> x) {
     69       return hlo->parent()->AddInstruction(std::move(x));
     70     };
     71     std::vector<HloInstruction*> operands;
     72     bool operands_changed = false;
     73     for (int i = 0; i < hlo->operand_count(); ++i) {
     74       HloInstruction* operand = hlo->mutable_operand(i);
     75       if (!ShapeUtil::SameDimensions(hlo->shape(), operand->shape())) {
     76         HloInstruction* new_operand = hlo->parent()->AddInstruction(
     77             HloInstruction::CreateBroadcastSequence(hlo->shape(), operand,
     78                                                     fadd));
     79         operands.push_back(new_operand);
     80         operands_changed = true;
     81       } else {
     82         operands.push_back(operand);
     83       }
     84     }
     85     if (operands_changed) {
     86       // Create a new HLO instruction because the HloInstruction::Replace*
     87       // methods check that the shape does not change with the replacement.
     88       HloInstruction* new_hlo = hlo->parent()->AddInstruction(
     89           hlo->CloneWithNewOperands(hlo->shape(), operands));
     90       TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_hlo));
     91       changed_ = true;
     92     }
     93     return Status::OK();
     94   }
     95 
     96   bool changed_ = false;
     97 };
     98 
     99 }  // namespace
    100 
    101 StatusOr<bool> ImplicitBroadcastRemover::Run(HloModule* module) {
    102   VLOG(1) << "Removing implicit broadcast from module " << module->name();
    103   XLA_VLOG_LINES(2,
    104                  "Before removing implicit broadcasts:\n" + module->ToString());
    105 
    106   ImplicitBroadcastVisitor visitor;
    107   for (HloComputation* computation : module->computations()) {
    108     TF_RETURN_IF_ERROR(computation->Accept(&visitor));
    109   }
    110 
    111   if (visitor.changed()) {
    112     // HLO instructions with implicitly broadcast operands are cloned and left
    113     // for dead. Remove them.
    114     HloDCE dce;
    115     TF_RETURN_IF_ERROR(dce.Run(module).status());
    116   }
    117 
    118   XLA_VLOG_LINES(2,
    119                  "After removing implicit broadcasts:\n" + module->ToString());
    120 
    121   return visitor.changed();
    122 }
    123 
    124 }  // namespace xla
    125