Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/transpose_folding.h"
     17 
     18 #include <memory>
     19 #include <unordered_set>
     20 #include <vector>
     21 
     22 #include "tensorflow/compiler/xla/client/computation_builder.h"
     23 #include "tensorflow/compiler/xla/literal_util.h"
     24 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
     25 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     27 #include "tensorflow/compiler/xla/service/hlo_module.h"
     28 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     29 #include "tensorflow/compiler/xla/service/shape_inference.h"
     30 #include "tensorflow/compiler/xla/shape_util.h"
     31 #include "tensorflow/compiler/xla/test.h"
     32 #include "tensorflow/compiler/xla/test_helpers.h"
     33 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     34 #include "tensorflow/compiler/xla/xla_data.pb.h"
     35 #include "tensorflow/core/platform/logging.h"
     36 
     37 namespace xla {
     38 namespace {
     39 
     40 class TransposeFoldingTest : public HloTestBase {
     41  protected:
     42   void FoldTranspose(HloModule* module) {
     43     TransposeFolding transpose_folding(
     44         [](const HloInstruction& dot,
     45            const TransposeFolding::OperandIndices& candidate_operands) {
     46           return candidate_operands;
     47         },
     48         [](const HloInstruction& convolution,
     49            const TransposeFolding::OperandIndices& candidate_operands) {
     50           return candidate_operands;
     51         });
     52     EXPECT_IS_OK(transpose_folding.Run(module).status());
     53   }
     54 };
     55 
     56 TEST_F(TransposeFoldingTest, FoldDotTranspose) {
     57   auto builder = HloComputation::Builder("entry_computation");
     58   HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
     59       /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}),
     60       /*name=*/"x"));
     61   HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
     62       /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3}),
     63       /*name=*/"y"));
     64   HloInstruction* transpose_y =
     65       builder.AddInstruction(HloInstruction::CreateTranspose(
     66           ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0}));
     67   DotDimensionNumbers dot_dnums;
     68   dot_dnums.add_lhs_contracting_dimensions(1);
     69   dot_dnums.add_rhs_contracting_dimensions(0);
     70   HloInstruction* dot = builder.AddInstruction(
     71       HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x,
     72                                 /*rhs=*/transpose_y, dot_dnums));
     73 
     74   HloModule module("test_module");
     75   HloComputation* entry_computation =
     76       module.AddEntryComputation(builder.Build(dot));
     77   FoldTranspose(&module);
     78 
     79   // Instructions after folding: x, y, and the fusion.
     80   std::unordered_set<HloInstruction*> instruction_set(
     81       entry_computation->instructions().begin(),
     82       entry_computation->instructions().end());
     83   CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
     84   CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
     85   CHECK_EQ(1, instruction_set.size())
     86       << "entry_computation should contain exactly 3 instructions.";
     87   HloInstruction* fusion = *instruction_set.begin();
     88   EXPECT_EQ(HloOpcode::kFusion, fusion->opcode());
     89 
     90   // The fusion instruction should contain two parameters, one transpose and
     91   // one dot.
     92   EXPECT_EQ(4, fusion->fused_instruction_count());
     93 }
     94 
     95 TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) {
     96   auto builder = HloComputation::Builder("entry_computation");
     97   // 2x1
     98   HloInstruction* const0 = builder.AddInstruction(
     99       HloInstruction::CreateConstant(Literal::CreateR2<float>({{1}, {2}})));
    100   // 3x2
    101   HloInstruction* const1 =
    102       builder.AddInstruction(HloInstruction::CreateConstant(
    103           Literal::CreateR2<float>({{1, 2}, {3, 4}, {5, 6}})));
    104   HloInstruction* transpose0 =
    105       builder.AddInstruction(HloInstruction::CreateTranspose(
    106           ShapeUtil::MakeShape(F32, {1, 2}), const0, {1, 0}));
    107   HloInstruction* transpose1 =
    108       builder.AddInstruction(HloInstruction::CreateTranspose(
    109           ShapeUtil::MakeShape(F32, {2, 3}), const1, {1, 0}));
    110   DotDimensionNumbers dot_dnums;
    111   dot_dnums.add_lhs_contracting_dimensions(1);
    112   dot_dnums.add_rhs_contracting_dimensions(0);
    113   HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
    114       ShapeUtil::MakeShape(F32, {1, 3}),
    115       /*lhs=*/transpose0, /*rhs=*/transpose1, dot_dnums));
    116 
    117   HloModule module("test_module");
    118   HloComputation* entry_computation =
    119       module.AddEntryComputation(builder.Build(dot));
    120   FoldTranspose(&module);
    121 
    122   for (auto* instruction : entry_computation->instructions()) {
    123     if (instruction->opcode() == HloOpcode::kFusion) {
    124       CHECK_EQ(2, instruction->operand_count());
    125       EXPECT_EQ(const0, instruction->operand(0));
    126       EXPECT_EQ(const1, instruction->operand(1));
    127     }
    128   }
    129 
    130   // The created fusion instruction should contain two parameters, two
    131   // transposes (one for each parameter) and one dot.
    132   EXPECT_EQ(5,
    133             entry_computation->root_instruction()->fused_instruction_count());
    134 }
    135 
    136 TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) {
    137   auto builder = HloComputation::Builder("entry");
    138   // (1.0 + 2.0) * (2.0 - 3.0)
    139   HloInstruction* const1 = builder.AddInstruction(
    140       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    141   HloInstruction* const2 = builder.AddInstruction(
    142       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    143   HloInstruction* const3 = builder.AddInstruction(
    144       HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
    145   HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
    146       const1->shape(), HloOpcode::kAdd, const1, const2));
    147   HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
    148       const2->shape(), HloOpcode::kSubtract, const2, const3));
    149   HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary(
    150       add->shape(), HloOpcode::kMultiply, add, sub));
    151 
    152   HloModule module("fuse_with_constant_operands");
    153   HloComputation* entry_computation =
    154       module.AddEntryComputation(builder.Build(mul));
    155   HloInstruction* call = module.OutlineExpressionFromComputation(
    156       {add, sub, mul}, "", entry_computation);
    157   EXPECT_EQ(call, entry_computation->root_instruction());
    158   HloComputation* callee_computation = call->to_apply();
    159   // The arguments to the call should be const1, const2, and const3.
    160   EXPECT_THAT(call->operands(),
    161               ::testing::UnorderedElementsAre(const1, const2, const3));
    162 
    163   // The callee should contain 3 parameters and 3 binary operators.
    164   EXPECT_EQ(6, callee_computation->instruction_count());
    165 }
    166 
    167 TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) {
    168   auto builder = HloComputation::Builder("entry_computation");
    169   HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
    170       /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}),
    171       /*name=*/"x"));
    172   HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
    173       /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3}),
    174       /*name=*/"y"));
    175   HloInstruction* transpose_y =
    176       builder.AddInstruction(HloInstruction::CreateTranspose(
    177           ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0}));
    178   DotDimensionNumbers dot_dnums;
    179   dot_dnums.add_lhs_contracting_dimensions(1);
    180   dot_dnums.add_rhs_contracting_dimensions(0);
    181   HloInstruction* dot = builder.AddInstruction(
    182       HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x,
    183                                 /*rhs=*/transpose_y, dot_dnums));
    184 
    185   HloModule module("test_module");
    186   HloComputation* entry_computation =
    187       module.AddEntryComputation(builder.Build(dot));
    188 
    189   HloInstruction* call = module.OutlineExpressionFromComputation(
    190       {transpose_y, dot}, "outlined", entry_computation);
    191 
    192   FoldTranspose(&module);
    193 
    194   // Instructions after folding: x, y, and the fusion.
    195   std::unordered_set<HloInstruction*> instruction_set(
    196       entry_computation->instructions().begin(),
    197       entry_computation->instructions().end());
    198   CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
    199   CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
    200   CHECK_EQ(1, instruction_set.erase(call))
    201       << "call is not in entry_computation.";
    202   CHECK(instruction_set.empty())
    203       << "entry_computation should contain exactly 3 instructions.";
    204   HloInstruction* fusion =
    205       call->called_computations().front()->root_instruction();
    206   EXPECT_EQ(HloOpcode::kFusion, fusion->opcode());
    207 
    208   // The fusion instruction should contain two parameters, one transpose and
    209   // one dot.
    210   EXPECT_EQ(4, fusion->fused_instruction_count());
    211 }
    212 
    213 // Test that a two dimension swap of the kernel gets folded into convolution.
    214 TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) {
    215   auto builder = HloComputation::Builder("entry_computation");
    216   HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
    217       /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}),
    218       /*name=*/"x"));
    219   HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
    220       /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}),
    221       /*name=*/"y"));
    222   HloInstruction* transpose_y =
    223       builder.AddInstruction(HloInstruction::CreateTranspose(
    224           ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), y, {1, 0, 2, 3}));
    225   auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers();
    226   Window window;
    227   for (int i = 0; i < 2; ++i) {
    228     WindowDimension* dim = window.add_dimensions();
    229     dim->set_padding_low(0);
    230     dim->set_padding_high(0);
    231     dim->set_base_dilation(1);
    232     dim->set_window_dilation(1);
    233     dim->set_stride(1);
    234     dim->set_size(
    235         transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
    236   }
    237   StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
    238       x->shape(), transpose_y->shape(), window, dnums);
    239   EXPECT_IS_OK(conv_shape);
    240   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
    241       conv_shape.ValueOrDie(), x, transpose_y, window, dnums));
    242 
    243   HloModule module("test_module");
    244   HloComputation* entry_computation =
    245       module.AddEntryComputation(builder.Build(conv));
    246   FoldTranspose(&module);
    247 
    248   // Instructions after folding: x, y, and the convolution.
    249   std::unordered_set<HloInstruction*> instruction_set(
    250       entry_computation->instructions().begin(),
    251       entry_computation->instructions().end());
    252   CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
    253   CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
    254   CHECK_EQ(1, instruction_set.size())
    255       << "entry_computation should contain exactly 3 instructions.";
    256   HloInstruction* new_conv = *instruction_set.begin();
    257   EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode());
    258   EXPECT_EQ(dnums.kernel_input_feature_dimension(),
    259             new_conv->convolution_dimension_numbers()
    260                 .kernel_output_feature_dimension());
    261   EXPECT_EQ(dnums.kernel_output_feature_dimension(),
    262             new_conv->convolution_dimension_numbers()
    263                 .kernel_input_feature_dimension());
    264 }
    265 
    266 // Test that a complex transpose of the kernel gets folded into convolution.
    267 TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) {
    268   auto builder = HloComputation::Builder("entry_computation");
    269   HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
    270       /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}),
    271       /*name=*/"x"));
    272   HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
    273       /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {1, 2, 1, 3}),
    274       /*name=*/"y"));
    275   HloInstruction* transpose_y =
    276       builder.AddInstruction(HloInstruction::CreateTranspose(
    277           ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), y, {1, 3, 0, 2}));
    278   auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers();
    279   Window window;
    280   for (int i = 0; i < 2; ++i) {
    281     WindowDimension* dim = window.add_dimensions();
    282     dim->set_padding_low(0);
    283     dim->set_padding_high(0);
    284     dim->set_base_dilation(1);
    285     dim->set_window_dilation(1);
    286     dim->set_stride(1);
    287     dim->set_size(
    288         transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
    289   }
    290   StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
    291       x->shape(), transpose_y->shape(), window, dnums);
    292   EXPECT_IS_OK(conv_shape);
    293   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
    294       conv_shape.ValueOrDie(), x, transpose_y, window, dnums));
    295 
    296   HloModule module("test_module");
    297   HloComputation* entry_computation =
    298       module.AddEntryComputation(builder.Build(conv));
    299   FoldTranspose(&module);
    300 
    301   // Instructions after folding: x, y, and the convolution.
    302   std::unordered_set<HloInstruction*> instruction_set(
    303       entry_computation->instructions().begin(),
    304       entry_computation->instructions().end());
    305   CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
    306   CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
    307   CHECK_EQ(1, instruction_set.size())
    308       << "entry_computation should contain exactly 3 instructions.";
    309   HloInstruction* new_conv = *instruction_set.begin();
    310   EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode());
    311   EXPECT_EQ(dnums.kernel_input_feature_dimension(),
    312             new_conv->convolution_dimension_numbers()
    313                 .kernel_output_feature_dimension());
    314   EXPECT_EQ(dnums.kernel_spatial_dimensions(1),
    315             new_conv->convolution_dimension_numbers()
    316                 .kernel_input_feature_dimension());
    317   EXPECT_EQ(
    318       dnums.kernel_output_feature_dimension(),
    319       new_conv->convolution_dimension_numbers().kernel_spatial_dimensions(0));
    320   EXPECT_EQ(
    321       dnums.kernel_spatial_dimensions(0),
    322       new_conv->convolution_dimension_numbers().kernel_spatial_dimensions(1));
    323 }
    324 
    325 // Test that a transpose of the activations gets folded into convolution.
    326 TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) {
    327   auto builder = HloComputation::Builder("entry_computation");
    328   HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
    329       /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}),
    330       /*name=*/"x"));
    331   HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
    332       /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}),
    333       /*name=*/"y"));
    334   HloInstruction* transpose_x =
    335       builder.AddInstruction(HloInstruction::CreateTranspose(
    336           ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 2, 3}));
    337   auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers();
    338   Window window;
    339   for (int i = 0; i < 2; ++i) {
    340     WindowDimension* dim = window.add_dimensions();
    341     dim->set_padding_low(0);
    342     dim->set_padding_high(0);
    343     dim->set_base_dilation(1);
    344     dim->set_window_dilation(1);
    345     dim->set_stride(1);
    346     dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
    347   }
    348   StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
    349       transpose_x->shape(), y->shape(), window, dnums);
    350   EXPECT_IS_OK(conv_shape);
    351   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
    352       conv_shape.ValueOrDie(), transpose_x, y, window, dnums));
    353 
    354   HloModule module("test_module");
    355   HloComputation* entry_computation =
    356       module.AddEntryComputation(builder.Build(conv));
    357   FoldTranspose(&module);
    358 
    359   // Instructions after folding: x, y, and the convolution.
    360   std::unordered_set<HloInstruction*> instruction_set(
    361       entry_computation->instructions().begin(),
    362       entry_computation->instructions().end());
    363   EXPECT_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
    364   EXPECT_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
    365   EXPECT_EQ(1, instruction_set.size())
    366       << "entry_computation should contain exactly 3 instructions.";
    367   HloInstruction* new_conv = *instruction_set.begin();
    368   EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode());
    369   EXPECT_EQ(dnums.input_feature_dimension(),
    370             new_conv->convolution_dimension_numbers().input_batch_dimension());
    371   EXPECT_EQ(
    372       dnums.input_batch_dimension(),
    373       new_conv->convolution_dimension_numbers().input_feature_dimension());
    374   EXPECT_EQ(
    375       dnums.input_spatial_dimensions(0),
    376       new_conv->convolution_dimension_numbers().input_spatial_dimensions(0));
    377   EXPECT_EQ(
    378       dnums.input_spatial_dimensions(1),
    379       new_conv->convolution_dimension_numbers().input_spatial_dimensions(1));
    380   EXPECT_EQ(
    381       dnums.output_spatial_dimensions(0),
    382       new_conv->convolution_dimension_numbers().output_spatial_dimensions(0));
    383   EXPECT_EQ(
    384       dnums.output_spatial_dimensions(1),
    385       new_conv->convolution_dimension_numbers().output_spatial_dimensions(1));
    386 }
    387 
    388 // Test that a transpose of every dimension in the activations gets folded into
    389 // convolution.
    390 TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) {
    391   auto builder = HloComputation::Builder("entry_computation");
    392   HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
    393       /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}),
    394       /*name=*/"x"));
    395   HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
    396       /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}),
    397       /*name=*/"y"));
    398   HloInstruction* transpose_x =
    399       builder.AddInstruction(HloInstruction::CreateTranspose(
    400           ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 3, 2}));
    401   auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers();
    402   Window window;
    403   for (int i = 0; i < 2; ++i) {
    404     WindowDimension* dim = window.add_dimensions();
    405     dim->set_padding_low(0);
    406     dim->set_padding_high(0);
    407     dim->set_base_dilation(1);
    408     dim->set_window_dilation(1);
    409     dim->set_stride(1);
    410     dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
    411   }
    412   StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
    413       transpose_x->shape(), y->shape(), window, dnums);
    414   EXPECT_IS_OK(conv_shape);
    415   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
    416       conv_shape.ValueOrDie(), transpose_x, y, window, dnums));
    417 
    418   HloModule module("test_module");
    419   HloComputation* entry_computation =
    420       module.AddEntryComputation(builder.Build(conv));
    421   FoldTranspose(&module);
    422 
    423   // Instructions after folding: x, y, and the convolution.
    424   std::unordered_set<HloInstruction*> instruction_set(
    425       entry_computation->instructions().begin(),
    426       entry_computation->instructions().end());
    427   EXPECT_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
    428   EXPECT_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
    429   EXPECT_EQ(1, instruction_set.size())
    430       << "entry_computation should contain exactly 3 instructions.";
    431   HloInstruction* new_conv = *instruction_set.begin();
    432   EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode());
    433   EXPECT_EQ(dnums.input_feature_dimension(),
    434             new_conv->convolution_dimension_numbers().input_batch_dimension());
    435   EXPECT_EQ(
    436       dnums.input_batch_dimension(),
    437       new_conv->convolution_dimension_numbers().input_feature_dimension());
    438   EXPECT_EQ(
    439       dnums.input_spatial_dimensions(0),
    440       new_conv->convolution_dimension_numbers().input_spatial_dimensions(1));
    441   EXPECT_EQ(
    442       dnums.input_spatial_dimensions(1),
    443       new_conv->convolution_dimension_numbers().input_spatial_dimensions(0));
    444   EXPECT_EQ(
    445       dnums.output_spatial_dimensions(0),
    446       new_conv->convolution_dimension_numbers().output_spatial_dimensions(0));
    447   EXPECT_EQ(
    448       dnums.output_spatial_dimensions(1),
    449       new_conv->convolution_dimension_numbers().output_spatial_dimensions(1));
    450 }
    451 
    452 }  // namespace
    453 }  // namespace xla
    454