1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" 17 18 #include "tensorflow/compiler/xla/service/hlo_computation.h" 19 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" 20 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 21 #include "tensorflow/compiler/xla/service/hlo_module.h" 22 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 23 #include "tensorflow/compiler/xla/shape_util.h" 24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 25 #include "tensorflow/compiler/xla/tests/test_utils.h" 26 27 namespace xla { 28 29 class HloSubcomputationUnificationTest : public HloTestBase { 30 protected: 31 HloSubcomputationUnificationTest() {} 32 33 std::unique_ptr<HloComputation> CreateR0S32IdentityComputation() { 34 auto builder = HloComputation::Builder("Identity"); 35 builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32_, "x")); 36 return builder.Build(); 37 } 38 39 std::unique_ptr<HloComputation> CreateR0S32AdditionComputation() { 40 auto builder = HloComputation::Builder("Addition"); 41 auto x = 42 builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32_, "x")); 43 auto y = 44 builder.AddInstruction(HloInstruction::CreateParameter(1, r0s32_, "y")); 45 builder.AddInstruction( 46 HloInstruction::CreateBinary(r0s32_, HloOpcode::kAdd, x, y)); 47 return builder.Build(); 48 } 49 50 std::unique_ptr<HloComputation> CreateR1S32AdditionComputation( 51 const Shape& shape) { 52 auto builder = HloComputation::Builder("Addition"); 53 auto x = 54 builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); 55 auto y = 56 builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "y")); 57 builder.AddInstruction( 58 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, x, y)); 59 return builder.Build(); 60 } 61 62 Shape r0s32_ = ShapeUtil::MakeShape(S32, {}); 63 Shape r0f32_ = ShapeUtil::MakeShape(S32, {}); 64 Shape r1s32_5_ = ShapeUtil::MakeShape(S32, {5}); 65 Shape r1s32_3_ = ShapeUtil::MakeShape(S32, {3}); 66 }; 67 68 TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) { 69 auto module = CreateNewModule(); 70 auto builder = HloComputation::Builder(TestName()); 71 72 auto callee1 = 73 module->AddEmbeddedComputation(CreateR0S32IdentityComputation()); 74 auto callee2 = 75 module->AddEmbeddedComputation(CreateR0S32IdentityComputation()); 76 77 auto constant = builder.AddInstruction( 78 HloInstruction::CreateConstant(Literal::CreateR0<int32>(5))); 79 auto x = builder.AddInstruction( 80 HloInstruction::CreateCall(r0s32_, {constant}, callee1)); 81 auto y = builder.AddInstruction( 82 HloInstruction::CreateCall(r0s32_, {constant}, callee2)); 83 builder.AddInstruction( 84 HloInstruction::CreateBinary(r0s32_, HloOpcode::kAdd, x, y)); 85 86 module->AddEntryComputation(builder.Build()); 87 88 EXPECT_EQ(3, module->computation_count()); 89 EXPECT_NE(x->to_apply(), y->to_apply()); 90 if (VLOG_IS_ON(1)) { 91 hlo_graph_dumper::DumpGraph(*module->entry_computation(), 92 "before unification", 93 module->config().debug_options()); 94 } 95 EXPECT_TRUE(HloSubcomputationUnification().Run(module.get()).ValueOrDie()); 96 if (VLOG_IS_ON(1)) { 97 hlo_graph_dumper::DumpGraph(*module->entry_computation(), 98 "after unification", 99 module->config().debug_options()); 100 } 101 EXPECT_EQ(2, module->computation_count()); 102 EXPECT_EQ(x->to_apply(), y->to_apply()); 103 } 104 105 TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) { 106 auto module = CreateNewModule(); 107 auto builder = HloComputation::Builder(TestName()); 108 109 auto callee1 = 110 module->AddEmbeddedComputation(CreateR0S32AdditionComputation()); 111 auto callee2 = 112 module->AddEmbeddedComputation(CreateR0S32AdditionComputation()); 113 114 auto constant1 = builder.AddInstruction( 115 HloInstruction::CreateConstant(Literal::CreateR0<int32>(5))); 116 auto constant2 = builder.AddInstruction( 117 HloInstruction::CreateConstant(Literal::CreateR0<int32>(3))); 118 auto x = builder.AddInstruction( 119 HloInstruction::CreateCall(r0s32_, {constant1, constant2}, callee1)); 120 auto y = builder.AddInstruction( 121 HloInstruction::CreateCall(r0s32_, {constant1, constant2}, callee2)); 122 builder.AddInstruction( 123 HloInstruction::CreateBinary(r0s32_, HloOpcode::kAdd, x, y)); 124 125 module->AddEntryComputation(builder.Build()); 126 127 EXPECT_EQ(3, module->computation_count()); 128 EXPECT_NE(x->to_apply(), y->to_apply()); 129 if (VLOG_IS_ON(1)) { 130 hlo_graph_dumper::DumpGraph(*module->entry_computation(), 131 "before unification", 132 module->config().debug_options()); 133 } 134 EXPECT_TRUE(HloSubcomputationUnification().Run(module.get()).ValueOrDie()); 135 if (VLOG_IS_ON(1)) { 136 hlo_graph_dumper::DumpGraph(*module->entry_computation(), 137 "after unification", 138 module->config().debug_options()); 139 } 140 EXPECT_EQ(2, module->computation_count()); 141 EXPECT_EQ(x->to_apply(), y->to_apply()); 142 } 143 144 // Do not unify subcomputations with different parameter shapes. 145 TEST_F(HloSubcomputationUnificationTest, DifferentParameterShapes) { 146 auto module = CreateNewModule(); 147 auto builder = HloComputation::Builder(TestName()); 148 149 auto callee1 = 150 module->AddEmbeddedComputation(CreateR1S32AdditionComputation(r1s32_5_)); 151 auto callee2 = 152 module->AddEmbeddedComputation(CreateR1S32AdditionComputation(r1s32_3_)); 153 154 auto param1 = builder.AddInstruction( 155 HloInstruction::CreateParameter(0, r1s32_5_, "param1")); 156 auto param2 = builder.AddInstruction( 157 HloInstruction::CreateParameter(1, r1s32_5_, "param2")); 158 auto x = builder.AddInstruction( 159 HloInstruction::CreateCall(r1s32_5_, {param1, param1}, callee1)); 160 auto y = builder.AddInstruction( 161 HloInstruction::CreateCall(r1s32_3_, {param2, param2}, callee2)); 162 builder.AddInstruction(HloInstruction::CreateConcatenate( 163 ShapeUtil::MakeShape(S32, {8}), {x, y}, 0)); 164 165 module->AddEntryComputation(builder.Build()); 166 167 EXPECT_EQ(3, module->computation_count()); 168 EXPECT_NE(x->to_apply(), y->to_apply()); 169 if (VLOG_IS_ON(1)) { 170 hlo_graph_dumper::DumpGraph(*module->entry_computation(), 171 "before unification", 172 module->config().debug_options()); 173 } 174 EXPECT_FALSE(HloSubcomputationUnification().Run(module.get()).ValueOrDie()); 175 if (VLOG_IS_ON(1)) { 176 hlo_graph_dumper::DumpGraph(*module->entry_computation(), 177 "after unification", 178 module->config().debug_options()); 179 } 180 EXPECT_EQ(3, module->computation_count()); 181 EXPECT_NE(x->to_apply(), y->to_apply()); 182 } 183 184 // Regression test for b/31466798. Checks that entry_computation is still valid 185 // after unification. 186 TEST_F(HloSubcomputationUnificationTest, TwoIdenticalComputations) { 187 auto module = CreateNewModule(); 188 for (int i = 0; i < 2; ++i) { 189 HloComputation::Builder builder("pow"); 190 auto x = 191 builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x")); 192 auto y = 193 builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "y")); 194 builder.AddInstruction( 195 HloInstruction::CreateBinary(r0f32_, HloOpcode::kPower, x, y)); 196 if (i == 0) { 197 module->AddEmbeddedComputation(builder.Build()); 198 } else { 199 module->AddEntryComputation(builder.Build()); 200 } 201 } 202 203 EXPECT_TRUE(HloSubcomputationUnification().Run(module.get()).ValueOrDie()); 204 EXPECT_EQ(1, module->computation_count()); 205 EXPECT_EQ(*module->computations().begin(), module->entry_computation()); 206 } 207 208 } // namespace xla 209