Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
     17 
     18 #include <memory>
     19 #include <string>
     20 #include <utility>
     21 #include <vector>
     22 
     23 #include "tensorflow/compiler/xla/layout_util.h"
     24 #include "tensorflow/compiler/xla/literal_util.h"
     25 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
     26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     27 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
     28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     30 #include "tensorflow/compiler/xla/service/hlo_query.h"
     31 #include "tensorflow/compiler/xla/shape_util.h"
     32 #include "tensorflow/compiler/xla/types.h"
     33 #include "tensorflow/core/lib/core/errors.h"
     34 
     35 namespace xla {
     36 
     37 StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
     38   auto evaluator = MakeUnique<HloEvaluator>();
     39 
     40   XLA_VLOG_LINES(2,
     41                  "HloConstantFolding::Run(), before:\n" + module->ToString());
     42   bool changed = false;
     43 
     44   for (auto* computation : module->MakeNonfusionComputations()) {
     45     for (auto instruction : computation->MakeInstructionPostOrder()) {
     46       // Skip dead code.
     47       if (instruction->user_count() == 0 &&
     48           computation->root_instruction() != instruction) {
     49         continue;
     50       }
     51       // Skip Constant, Parameter, Reduce operation.
     52       // TODO(b/35975797): Enable Reduce operation once arbitrary computation
     53       // are supported by the evaluator.
     54       // TODO(b/64407269): Enable Tuple once the timeout issue is resolved.
     55       if (instruction->opcode() == HloOpcode::kParameter ||
     56           instruction->opcode() == HloOpcode::kConstant ||
     57           instruction->opcode() == HloOpcode::kTuple ||
     58           instruction->opcode() == HloOpcode::kReduce) {
     59         continue;
     60       }
     61       // Skip instructions with non-constant operands.
     62       if (!hlo_query::AllOperandsAreConstants(*instruction)) {
     63         continue;
     64       }
     65 
     66       // Broadcasts dramatically increase the size of constants, which is often
     67       // detrimental to performance and memory capacity, so do not fold
     68       // broadcasts.
     69       if (instruction->opcode() == HloOpcode::kBroadcast) {
     70         continue;
     71       }
     72 
     73       std::unique_ptr<Literal> result = evaluator->TryEvaluate(instruction);
     74       // Currently we skip unimplemented operations.
     75       // TODO(b/35975797): Fold constant computations for more operations.
     76       if (result == nullptr) {
     77         VLOG(2) << "Constant folding failed for instruction: "
     78                 << instruction->ToString();
     79         continue;
     80       }
     81 
     82       TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
     83           instruction, HloInstruction::CreateConstant(std::move(result))));
     84       changed = true;
     85     }
     86   }
     87   XLA_VLOG_LINES(2, "HloConstantFolding::Run(), after:\n" + module->ToString());
     88   return changed;
     89 }
     90 
     91 }  // namespace xla
     92