1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/call_inliner.h" 17 18 #include <memory> 19 #include <utility> 20 21 #include "tensorflow/compiler/xla/layout_util.h" 22 #include "tensorflow/compiler/xla/literal_util.h" 23 #include "tensorflow/compiler/xla/ptr_util.h" 24 #include "tensorflow/compiler/xla/service/hlo_computation.h" 25 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 26 #include "tensorflow/compiler/xla/service/hlo_matchers.h" 27 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 28 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" 29 #include "tensorflow/compiler/xla/shape_util.h" 30 #include "tensorflow/compiler/xla/test.h" 31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 32 #include "tensorflow/compiler/xla/types.h" 33 #include "tensorflow/compiler/xla/xla_data.pb.h" 34 #include "tensorflow/core/lib/core/status_test_util.h" 35 #include "tensorflow/core/lib/strings/str_util.h" 36 37 namespace op = xla::testing::opcode_matchers; 38 39 namespace xla { 40 namespace { 41 42 // Tests for call inlining that are most tractable at the HLO level (vs 43 // ComputationBuilder API in call_test.cc). 44 using CallInlinerTest = HloTestBase; 45 46 TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { 47 // "inner" computation just has a control dependency from the "zero" value to 48 // the "one" value. 49 HloComputation::Builder inner(TestName() + ".inner"); 50 HloInstruction* zero = inner.AddInstruction( 51 HloInstruction::CreateConstant(Literal::CreateR0<float>(24.0f))); 52 HloInstruction* one = inner.AddInstruction( 53 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); 54 TF_ASSERT_OK(zero->AddControlDependencyTo(one)); 55 auto module = CreateNewModule(); 56 HloComputation* inner_computation = 57 module->AddEmbeddedComputation(inner.Build()); 58 59 // "outer" computation just calls the "inner" computation. 60 HloComputation::Builder outer(TestName() + ".outer"); 61 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 62 outer.AddInstruction( 63 HloInstruction::CreateCall(r0f32, {}, inner_computation)); 64 65 auto computation = module->AddEntryComputation(outer.Build()); 66 67 CallInliner call_inliner; 68 TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); 69 ASSERT_TRUE(mutated); 70 EXPECT_THAT(computation->root_instruction(), op::Constant()); 71 EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement<float>(), 72 42); 73 ASSERT_EQ(1, computation->root_instruction()->control_predecessors().size()); 74 auto prior = computation->root_instruction()->control_predecessors()[0]; 75 EXPECT_THAT(prior, op::Constant()); 76 EXPECT_EQ(prior->literal().GetFirstElement<float>(), 24); 77 } 78 79 // Tests for referential transparency (a function that calls a function that 80 // returns false should be identical to just returning false). 81 TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { 82 const Shape pred = ShapeUtil::MakeShape(PRED, {}); 83 auto module = CreateNewModule(); 84 85 // Create a lambda that calls a function that returns the false predicate. 86 // Note we also use this lambda twice by reference, just to make the test a 87 // little trickier. 88 HloComputation::Builder just_false(TestName() + ".false"); 89 just_false.AddInstruction( 90 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 91 HloComputation* false_computation = 92 module->AddEmbeddedComputation(just_false.Build()); 93 94 HloComputation::Builder call_false_builder(TestName() + ".call_false"); 95 call_false_builder.AddInstruction( 96 HloInstruction::CreateCall(pred, {}, false_computation)); 97 HloComputation* call_false = 98 module->AddEmbeddedComputation(call_false_builder.Build()); 99 100 HloComputation::Builder outer(TestName() + ".outer"); 101 HloInstruction* init_value = outer.AddInstruction( 102 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 103 outer.AddInstruction( 104 HloInstruction::CreateWhile(pred, call_false, call_false, init_value)); 105 106 auto computation = module->AddEntryComputation(outer.Build()); 107 108 CallInliner call_inliner; 109 TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); 110 ASSERT_TRUE(mutated); 111 EXPECT_THAT( 112 computation->root_instruction()->while_condition()->root_instruction(), 113 op::Constant()); 114 EXPECT_THAT(computation->root_instruction()->while_body()->root_instruction(), 115 op::Constant()); 116 } 117 118 // Check CallInliner::Inline, which inlines a specific call without running the 119 // whole pass. 120 TEST_F(CallInlinerTest, InlineWithoutRunningPass) { 121 const Shape pred = ShapeUtil::MakeShape(PRED, {}); 122 auto module = CreateNewModule(); 123 124 HloComputation::Builder just_false(TestName() + ".false"); 125 auto* true_constant = just_false.AddInstruction( 126 HloInstruction::CreateConstant(Literal::CreateR1<bool>({true}))); 127 auto* false_constant = just_false.AddInstruction( 128 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 129 TF_ASSERT_OK(false_constant->AddControlDependencyTo(true_constant)); 130 HloComputation* false_computation = 131 module->AddEmbeddedComputation(just_false.Build()); 132 133 HloComputation::Builder call_false_builder(TestName() + ".call_false"); 134 HloInstruction* call = call_false_builder.AddInstruction( 135 HloInstruction::CreateCall(pred, {}, false_computation)); 136 auto computation = module->AddEntryComputation(call_false_builder.Build()); 137 138 TF_ASSERT_OK(CallInliner::Inline(call).status()); 139 EXPECT_THAT(computation->root_instruction(), op::Constant()); 140 EXPECT_THAT(computation->root_instruction()->control_successors(), 141 ElementsAre(op::Constant())); 142 } 143 144 TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) { 145 const Shape f32 = ShapeUtil::MakeShape(F32, {}); 146 auto module = CreateNewModule(); 147 148 HloComputation::Builder outfeeder(TestName() + ".outfeeder"); 149 auto value = outfeeder.AddInstruction( 150 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0))); 151 outfeeder.AddInstruction( 152 HloInstruction::CreateOutfeed(f32, value, /*outfeed_config=*/"")); 153 154 auto outfeed_computation = module->AddEmbeddedComputation(outfeeder.Build()); 155 156 HloComputation::Builder outer(TestName() + ".outer"); 157 outer.AddInstruction(HloInstruction::CreateCall( 158 ShapeUtil::MakeNil(), /*operands=*/{}, outfeed_computation)); 159 160 module->AddEntryComputation(outer.Build()); 161 162 CallInliner call_inliner; 163 TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); 164 ASSERT_TRUE(mutated); 165 } 166 167 } // namespace 168 } // namespace xla 169