Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
     17 
     18 #include <memory>
     19 #include <string>
     20 #include <utility>
     21 #include <vector>
     22 
     23 #include "tensorflow/compiler/xla/layout_util.h"
     24 #include "tensorflow/compiler/xla/literal_util.h"
     25 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
     26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     27 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
     28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     30 #include "tensorflow/compiler/xla/service/hlo_query.h"
     31 #include "tensorflow/compiler/xla/shape_util.h"
     32 #include "tensorflow/compiler/xla/types.h"
     33 #include "tensorflow/core/lib/core/errors.h"
     34 
     35 namespace xla {
     36 namespace {
     37 
     38 HloInstruction* ToElementType(HloInstruction* hlo, PrimitiveType type) {
     39   if (hlo->shape().element_type() != type) {
     40     Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type);
     41     hlo = hlo->parent()->AddInstruction(
     42         HloInstruction::CreateConvert(shape, hlo));
     43   }
     44   CHECK_EQ(hlo->shape().element_type(), type);
     45   return hlo;
     46 }
     47 
     48 bool HasOperandType(HloInstruction* hlo, PrimitiveType type) {
     49   for (HloInstruction* operand : hlo->operands()) {
     50     if (operand->shape().element_type() == type) {
     51       return true;
     52     }
     53   }
     54   return false;
     55 }
     56 
     57 // Finds out the Tuple Shape of the new instruction after converting the element
     58 // type of the operands of the original instruction from `from_type` to
     59 // `to_type`.
     60 //
     61 // This routine assumes the resulting `shape` of the original instruction is a
     62 // non-nested tuple. This assumption is currently safe as only kTuple, kInfeed,
     63 // kOutfeed, kCall, kCustomCall and kBatchNorm* HLO instructions can produce
     64 // results with tuple shapes, and this routine is only called to convert the
     65 // result shapes of kBatchNorm* HLO instructions, which are non-nested tuples.
     66 Shape GetConvertedTupleShape(const Shape& shape, PrimitiveType from_type,
     67                              PrimitiveType to_type) {
     68   std::vector<Shape> new_tuple_subshapes;
     69   for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
     70     Shape subshape = ShapeUtil::GetTupleElementShape(shape, i);
     71     CHECK(!ShapeUtil::IsTuple(subshape));
     72     if (subshape.element_type() == from_type) {
     73       subshape = ShapeUtil::ChangeElementType(subshape, to_type);
     74     }
     75     new_tuple_subshapes.push_back(subshape);
     76   }
     77   return ShapeUtil::MakeTupleShape(new_tuple_subshapes);
     78 }
     79 
     80 // Converts the elements of the result of `hlo` to produce a new tuple with
     81 // shape `to_shape`.
     82 //
     83 // This routine assumes `hlo` is an instruction that produces a non-nested Tuple
     84 // as a result.
     85 HloInstruction* ConvertTupleElements(HloInstruction* hlo,
     86                                      const Shape& to_shape) {
     87   const Shape& shape = hlo->shape();
     88   HloComputation* computation = hlo->parent();
     89   std::vector<HloInstruction*> tuple_elements;
     90   for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
     91     const Shape& ele_shape = ShapeUtil::GetTupleElementShape(shape, i);
     92     HloInstruction* element = computation->AddInstruction(
     93         HloInstruction::CreateGetTupleElement(ele_shape, hlo, i));
     94     const Shape& to_ele_shape = ShapeUtil::GetTupleElementShape(to_shape, i);
     95     CHECK(!ShapeUtil::IsTuple(ele_shape));
     96     if (ele_shape.element_type() != to_ele_shape.element_type()) {
     97       element = computation->AddInstruction(
     98           HloInstruction::CreateConvert(to_ele_shape, element));
     99     }
    100     tuple_elements.push_back(element);
    101   }
    102   return computation->AddInstruction(
    103       HloInstruction::CreateTuple(tuple_elements));
    104 }
    105 
    106 }  // namespace
    107 
    108 HloElementTypeConverter::HloElementTypeConverter(
    109     PrimitiveType eliminate_type, PrimitiveType replace_with_type)
    110     : eliminate_type_(eliminate_type), replace_with_type_(replace_with_type) {}
    111 
    112 // This routine converts the arithmetic operations in the given module that use
    113 // eliminate_type_ to operations that use replace_with_type_.
    114 StatusOr<bool> HloElementTypeConverter::Run(HloModule* module) {
    115   XLA_VLOG_LINES(
    116       3, "HloElementTypeConverter::Run(), before:\n" + module->ToString());
    117 
    118   if (eliminate_type_ == replace_with_type_) {
    119     return false;
    120   }
    121 
    122   bool changed = false;
    123   for (auto* computation : module->computations()) {
    124     for (auto* hlo : computation->MakeInstructionPostOrder()) {
    125       const auto opcode = hlo->opcode();
    126       // These are ops where it does not make sense to convert them.
    127       if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant ||
    128           opcode == HloOpcode::kTuple || opcode == HloOpcode::kConvert ||
    129           opcode == HloOpcode::kGetTupleElement ||
    130           opcode == HloOpcode::kInfeed || opcode == HloOpcode::kOutfeed) {
    131         continue;
    132       }
    133 
    134       // We cannot change a CustomCall since we have no way of adjusting the
    135       // called binary to expect the updated type.
    136       if (opcode == HloOpcode::kCustomCall) {
    137         continue;
    138       }
    139 
    140       // These are ops with embedded computations where it suffices to convert
    141       // the embedded computations instead of converting the ops themselves.
    142       if (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall ||
    143           opcode == HloOpcode::kFusion || opcode == HloOpcode::kMap ||
    144           opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow ||
    145           opcode == HloOpcode::kSelectAndScatter ||
    146           opcode == HloOpcode::kConditional) {
    147         continue;
    148       }
    149       TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString();
    150 
    151       if (!HasOperandType(hlo, eliminate_type_)) {
    152         // If this CHECK fires, then this was an instruction that does not take
    153         // the elimination type as an operand but it does return it. This pass
    154         // does not have a feature to change the output type in that case, so
    155         // instead of silently failing to eliminate the type, it fails loudly.
    156         TF_RET_CHECK(hlo->shape().element_type() != eliminate_type_);
    157         continue;
    158       }
    159 
    160       // Handle instructions that perform arithmetic operations and contain
    161       // operands with eliminate_type_.
    162       //
    163       // First, convert the operands with eliminate_type_ to operands with
    164       // replace_with_type_.
    165       std::vector<HloInstruction*> new_operands;
    166       for (HloInstruction* operand : hlo->operands()) {
    167         if (operand->shape().element_type() == eliminate_type_) {
    168           operand = ToElementType(operand, replace_with_type_);
    169         }
    170         new_operands.push_back(operand);
    171       }
    172 
    173       // Then find out the result type of the new instruction with the same
    174       // opcode but using the converted operands, create the new instruction,
    175       // and convert the result of the new instruction back to match the result
    176       // type of the original instruction.
    177       HloInstruction* new_hlo;
    178       if (hlo->shape().element_type() == eliminate_type_) {
    179         Shape shape =
    180             ShapeUtil::ChangeElementType(hlo->shape(), replace_with_type_);
    181         new_hlo = computation->AddInstruction(
    182             hlo->CloneWithNewOperands(shape, new_operands, hlo->GetModule()));
    183         new_hlo = ToElementType(new_hlo, eliminate_type_);
    184       } else if (ShapeUtil::IsTuple(hlo->shape())) {
    185         Shape old_shape = hlo->shape();
    186         Shape new_shape = GetConvertedTupleShape(hlo->shape(), eliminate_type_,
    187                                                  replace_with_type_);
    188         new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands(
    189             new_shape, new_operands, hlo->GetModule()));
    190         // Convert the elements of the result of `new_hlo` to produce a new
    191         // tuple with shape `old_shape`.
    192         new_hlo = ConvertTupleElements(new_hlo, old_shape);
    193       } else {
    194         new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands(
    195             hlo->shape(), new_operands, hlo->GetModule()));
    196       }
    197 
    198       TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, new_hlo));
    199       changed = true;
    200     }
    201   }
    202   XLA_VLOG_LINES(
    203       2, "HloElementTypeConverter::Run(), after:\n" + module->ToString());
    204   return changed;
    205 }
    206 
    207 }  // namespace xla
    208