Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/inliner.h"
     17 
     18 #include <memory>
     19 #include <string>
     20 
     21 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
     22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     25 #include "tensorflow/compiler/xla/service/hlo_query.h"
     26 #include "tensorflow/compiler/xla/status_macros.h"
     27 #include "tensorflow/compiler/xla/types.h"
     28 #include "tensorflow/core/lib/core/errors.h"
     29 #include "tensorflow/core/lib/core/status.h"
     30 #include "tensorflow/core/lib/gtl/array_slice.h"
     31 #include "tensorflow/core/platform/logging.h"
     32 
     33 namespace xla {
     34 
     35 // InlinerVisitor traverses the HLO computation and inlines maps.
     36 class InlinerVisitor : public DfsHloVisitorWithDefault {
     37  public:
     38   explicit InlinerVisitor(HloComputation* computation)
     39       : computation_(computation) {}
     40 
     41   // Default visitor action is to do nothing and return OK.
     42   Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
     43     return Status::OK();
     44   }
     45 
     46   Status HandleMap(HloInstruction* map) override;
     47 
     48   // Runs the visitor on a computation.
     49   StatusOr<bool> Run(HloComputation* computation);
     50 
     51  private:
     52   // Current HloComputation instance the InlinerVisitor is traversing.
     53   HloComputation* computation_;
     54 
     55   // Whether algebraic simplification has occurred.
     56   bool changed_ = false;
     57 };
     58 
     59 StatusOr<bool> InlinerVisitor::Run(HloComputation* computation) {
     60   changed_ = false;
     61   computation_ = computation;
     62   TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(this));
     63   return changed_;
     64 }
     65 
     66 Status InlinerVisitor::HandleMap(HloInstruction* map) {
     67   HloComputation* function = map->to_apply();
     68   HloInstruction& root = *function->root_instruction();
     69   // TODO(b/29249531): Add DCE pass to remove unused HloComputations.
     70   // Only inlining functions that are simply a single operation until a better
     71   // profitability model for inlining is defined.
     72   if (hlo_query::AllOperandsAreParameters(root)) {
     73     if (root.opcode() == HloOpcode::kFusion ||
     74         root.opcode() == HloOpcode::kParameter ||
     75         root.opcode() == HloOpcode::kTrace) {
     76       // Cloning not supported for these instructions.
     77       return Status::OK();
     78     }
     79     VLOG(10) << "inlining map({X ... Y}, op) => : op(X ... Y) with function "
     80              << root.ToShortString();
     81     // If the input is a constant then the shape of the constant could be
     82     // different than the map shape. Hence, a broadcast is needed, else the
     83     // cloned operand with new shape and operands work.
     84     if (root.opcode() != HloOpcode::kConstant) {
     85       std::vector<HloInstruction*> params;
     86       for (int64 o = 0; o < root.operands().size(); o++) {
     87         params.push_back(map->operands()[root.operand(o)->parameter_number()]);
     88       }
     89       HloInstruction* placed_instruction = computation_->AddInstruction(
     90           root.CloneWithNewOperands(map->shape(), params));
     91       TF_RETURN_IF_ERROR(
     92           computation_->ReplaceInstruction(map, placed_instruction));
     93     } else {
     94       // The constant is in an embedded computation and needs to be recreated
     95       // as part of the computation that the broadcast is inserted into.
     96       HloInstruction* constant = computation_->AddInstruction(root.Clone());
     97       HloInstruction* placed_instruction = computation_->AddInstruction(
     98           HloInstruction::CreateBroadcast(map->shape(), constant, {}));
     99       TF_RETURN_IF_ERROR(
    100           computation_->ReplaceInstruction(map, placed_instruction));
    101     }
    102     changed_ = true;
    103     return Status::OK();
    104   }
    105 
    106   return Status::OK();
    107 }
    108 
    109 StatusOr<bool> Inliner::Run(HloModule* module) {
    110   InlinerVisitor visitor(/*computation=*/nullptr);
    111   bool changed = false;
    112   for (HloComputation* computation : module->computations()) {
    113     TF_ASSIGN_OR_RETURN(bool computation_changed, visitor.Run(computation));
    114     changed |= computation_changed;
    115   }
    116   return changed;
    117 }
    118 
    119 }  // namespace xla
    120