// Home | History | Annotate | Download | only in service  (code-viewer navigation residue)
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/dot_decomposer.h"
     17 
     18 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     20 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     21 #include "tensorflow/compiler/xla/shape_util.h"
     22 #include "tensorflow/compiler/xla/status_macros.h"
     23 #include "tensorflow/compiler/xla/types.h"
     24 #include "tensorflow/core/platform/logging.h"
     25 
     26 namespace xla {
     27 
     28 namespace {
     29 
     30 // TODO(b/69062148) Remove this code when all backends support BatchDot
     31 // natively.
     32 Status DecomposeBatchDot(HloInstruction* dot) {
     33   auto computation = dot->parent();
     34   const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
     35   HloInstruction* lhs = dot->mutable_operand(0);
     36   HloInstruction* rhs = dot->mutable_operand(1);
     37   const Shape& lhs_shape = lhs->shape();
     38   const Shape& rhs_shape = rhs->shape();
     39   const Shape& dot_shape = dot->shape();
     40 
     41   // ShapeInference should guarantee that lhs/rhs batch dimensions match.
     42   CHECK_EQ(dnums.lhs_batch_dimensions_size(),
     43            dnums.rhs_batch_dimensions_size());
     44   const int64 num_batch_dims = dnums.lhs_batch_dimensions_size();
     45   // Calculate total batch size (note that ShapeInference requires that
     46   // the batch dimensions are most-major).
     47   int64 batch_size = 1;
     48   for (int i = 0; i < num_batch_dims; ++i) {
     49     CHECK_EQ(lhs_shape.dimensions(dnums.lhs_batch_dimensions(i)),
     50              rhs_shape.dimensions(dnums.rhs_batch_dimensions(i)));
     51     batch_size *= lhs_shape.dimensions(dnums.lhs_batch_dimensions(i));
     52   }
     53 
     54   // Set lhs/rhs_transpose.
     55   CHECK_EQ(1, dnums.lhs_contracting_dimensions_size());
     56   const int64 lhs_contracting_dim_number = dnums.lhs_contracting_dimensions(0);
     57   const bool lhs_transpose = (lhs_contracting_dim_number - num_batch_dims) == 0;
     58 
     59   CHECK_EQ(1, dnums.rhs_contracting_dimensions_size());
     60   const int64 rhs_contracting_dim_number = dnums.rhs_contracting_dimensions(0);
     61   const bool rhs_transpose = (rhs_contracting_dim_number - num_batch_dims) == 1;
     62 
     63   // Compute R3 and R3 shapes for lhs.
     64   PrimitiveType lhs_type = lhs_shape.element_type();
     65   const int64 lhs_rows = lhs_shape.dimensions(num_batch_dims + 0);
     66   const int64 lhs_cols = lhs_shape.dimensions(num_batch_dims + 1);
     67   Shape lhs_shape_r3 =
     68       ShapeUtil::MakeShape(lhs_type, {batch_size, lhs_rows, lhs_cols});
     69   Shape lhs_slice_shape_r3 =
     70       ShapeUtil::MakeShape(lhs_type, {1, lhs_rows, lhs_cols});
     71   Shape lhs_slice_shape_r2 =
     72       ShapeUtil::MakeShape(lhs_type, {lhs_rows, lhs_cols});
     73 
     74   // Compute R3 and R3 shapes for rhs.
     75   PrimitiveType rhs_type = rhs_shape.element_type();
     76   const int64 rhs_rows = rhs_shape.dimensions(num_batch_dims + 0);
     77   const int64 rhs_cols = rhs_shape.dimensions(num_batch_dims + 1);
     78   Shape rhs_shape_r3 =
     79       ShapeUtil::MakeShape(rhs_type, {batch_size, rhs_rows, rhs_cols});
     80   Shape rhs_slice_shape_r3 =
     81       ShapeUtil::MakeShape(rhs_type, {1, rhs_rows, rhs_cols});
     82   Shape rhs_slice_shape_r2 =
     83       ShapeUtil::MakeShape(rhs_type, {rhs_rows, rhs_cols});
     84 
     85   // Compute R3 and R3 shapes for dot output.
     86   PrimitiveType dot_type = dot_shape.element_type();
     87   const int64 dot_rows = dot_shape.dimensions(num_batch_dims + 0);
     88   const int64 dot_cols = dot_shape.dimensions(num_batch_dims + 1);
     89   Shape dot_shape_r2 = ShapeUtil::MakeShape(dot_type, {dot_rows, dot_cols});
     90   Shape dot_shape_r3 = ShapeUtil::MakeShape(dot_type, {1, dot_rows, dot_cols});
     91   Shape concat_shape_r3 =
     92       ShapeUtil::MakeShape(dot_type, {batch_size, dot_rows, dot_cols});
     93 
     94   // Reshape lhs/rhs into R3.
     95   auto lhs_r3 = computation->AddInstruction(
     96       HloInstruction::CreateReshape(lhs_shape_r3, lhs));
     97   auto rhs_r3 = computation->AddInstruction(
     98       HloInstruction::CreateReshape(rhs_shape_r3, rhs));
     99 
    100   // Loop through batch size, slicing out required lhs/rhs to compute each Dot.
    101   std::vector<HloInstruction*> output_slices(batch_size);
    102   for (int64 i = 0; i < batch_size; ++i) {
    103     // Slice R3 shape from 'lhs' and reshape to R2.
    104     auto lhs_slice_r3 = computation->AddInstruction(
    105         HloInstruction::CreateSlice(lhs_slice_shape_r3, lhs_r3, {i, 0, 0},
    106                                     {i + 1, lhs_rows, lhs_cols}, {1, 1, 1}));
    107     auto lhs_slice_r2 = computation->AddInstruction(
    108         HloInstruction::CreateReshape(lhs_slice_shape_r2, lhs_slice_r3));
    109 
    110     // Slice R3 shape from 'rhs' and reshape to R2.
    111     auto rhs_slice_r3 = computation->AddInstruction(
    112         HloInstruction::CreateSlice(rhs_slice_shape_r3, rhs_r3, {i, 0, 0},
    113                                     {i + 1, rhs_rows, rhs_cols}, {1, 1, 1}));
    114     auto rhs_slice_r2 = computation->AddInstruction(
    115         HloInstruction::CreateReshape(rhs_slice_shape_r2, rhs_slice_r3));
    116 
    117     // Transpose lhs/rhs (if needed).
    118     if (lhs_transpose) {
    119       Shape lhs_slice_shape_r2_transpose =
    120           ShapeUtil::MakeShape(lhs_type, {lhs_cols, lhs_rows});
    121       lhs_slice_r2 =
    122           computation->AddInstruction(HloInstruction::CreateTranspose(
    123               lhs_slice_shape_r2_transpose, lhs_slice_r2, {1, 0}));
    124     }
    125     if (rhs_transpose) {
    126       Shape rhs_slice_shape_r2_transpose =
    127           ShapeUtil::MakeShape(rhs_type, {rhs_cols, rhs_rows});
    128       rhs_slice_r2 =
    129           computation->AddInstruction(HloInstruction::CreateTranspose(
    130               rhs_slice_shape_r2_transpose, rhs_slice_r2, {1, 0}));
    131     }
    132 
    133     // Compute Dot of lhs/rhs R2 slices.
    134     DotDimensionNumbers dot_dnums;
    135     dot_dnums.add_lhs_contracting_dimensions(1);
    136     dot_dnums.add_rhs_contracting_dimensions(0);
    137     auto dot_r2 = computation->AddInstruction(HloInstruction::CreateDot(
    138         dot_shape_r2, lhs_slice_r2, rhs_slice_r2, dot_dnums));
    139 
    140     // Reshape Dot to R3 so we can concat along batch dimension.
    141     auto dot_r3 = computation->AddInstruction(
    142         HloInstruction::CreateReshape(dot_shape_r3, dot_r2));
    143 
    144     output_slices[i] = dot_r3;
    145   }
    146 
    147   // Concatenate slices from 'output_slices' along batch dimension.
    148   auto concat = computation->AddInstruction(
    149       HloInstruction::CreateConcatenate(concat_shape_r3, output_slices, 0));
    150   // Reshape output 'new_dot' to original dimensions.
    151   auto new_dot = computation->AddInstruction(
    152       HloInstruction::CreateReshape(dot_shape, concat));
    153 
    154   // Replace all uses of 'dot' in 'computation' with 'new_dot'.
    155   return computation->ReplaceInstruction(dot, new_dot);
    156 }
    157 
    158 }  // namespace
    159 
    160 StatusOr<bool> DotDecomposer::Run(HloModule* module) {
    161   XLA_VLOG_LINES(2, "DotDecomposer ENTRY\n" + module->ToString());
    162   // Gather all batch Dot operations.
    163   std::vector<HloInstruction*> batch_dots;
    164   for (auto* computation : module->MakeNonfusionComputations()) {
    165     for (auto* instruction : computation->instructions()) {
    166       if (instruction->opcode() != HloOpcode::kDot) {
    167         continue;
    168       }
    169       const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers();
    170       if (dnums.lhs_batch_dimensions_size() > 0 && decompose_batch_dot_) {
    171         batch_dots.push_back(instruction);
    172       }
    173     }
    174   }
    175   // Decompose each batch Dot in 'batch_dots'.
    176   bool changed = false;
    177   for (auto* dot : batch_dots) {
    178     TF_RETURN_IF_ERROR(DecomposeBatchDot(dot));
    179     changed = true;
    180   }
    181   XLA_VLOG_LINES(2, "DotDecompose EXIT\n" + module->ToString());
    182   return changed;
    183 }
    184 
    185 }  // namespace xla
    186