1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/call_inliner.h" 17 18 #include <memory> 19 #include <utility> 20 21 #include "absl/memory/memory.h" 22 #include "tensorflow/compiler/xla/layout_util.h" 23 #include "tensorflow/compiler/xla/literal.h" 24 #include "tensorflow/compiler/xla/service/hlo_computation.h" 25 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 26 #include "tensorflow/compiler/xla/service/hlo_matchers.h" 27 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 28 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" 29 #include "tensorflow/compiler/xla/shape_util.h" 30 #include "tensorflow/compiler/xla/test.h" 31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 32 #include "tensorflow/compiler/xla/types.h" 33 #include "tensorflow/compiler/xla/xla_data.pb.h" 34 #include "tensorflow/core/lib/core/status_test_util.h" 35 36 namespace op = xla::testing::opcode_matchers; 37 38 namespace xla { 39 namespace { 40 41 // Tests for call inlining that are most tractable at the HLO level (vs 42 // ComputationBuilder API in call_test.cc). 43 using CallInlinerTest = HloTestBase; 44 45 TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { 46 // "inner" computation just has a control dependency from the "zero" value to 47 // the "one" value. 48 HloComputation::Builder inner(TestName() + ".inner"); 49 HloInstruction* zero = inner.AddInstruction( 50 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(24.0f))); 51 HloInstruction* one = inner.AddInstruction( 52 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))); 53 TF_ASSERT_OK(zero->AddControlDependencyTo(one)); 54 auto module = CreateNewVerifiedModule(); 55 HloComputation* inner_computation = 56 module->AddEmbeddedComputation(inner.Build()); 57 58 // "outer" computation just calls the "inner" computation. 59 HloComputation::Builder outer(TestName() + ".outer"); 60 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 61 outer.AddInstruction( 62 HloInstruction::CreateCall(r0f32, {}, inner_computation)); 63 64 auto computation = module->AddEntryComputation(outer.Build()); 65 66 CallInliner call_inliner; 67 TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); 68 ASSERT_TRUE(mutated); 69 EXPECT_THAT(computation->root_instruction(), op::Constant()); 70 EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement<float>(), 71 42); 72 ASSERT_EQ(1, computation->root_instruction()->control_predecessors().size()); 73 auto prior = computation->root_instruction()->control_predecessors()[0]; 74 EXPECT_THAT(prior, op::Constant()); 75 EXPECT_EQ(prior->literal().GetFirstElement<float>(), 24); 76 } 77 78 // Tests for referential transparency (a function that calls a function that 79 // returns false should be identical to just returning false). 80 TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { 81 const Shape pred = ShapeUtil::MakeShape(PRED, {}); 82 auto module = CreateNewVerifiedModule(); 83 84 // Create a lambda that calls a function that returns the false predicate. 85 // Note we also use this lambda twice by reference, just to make the test a 86 // little trickier. 87 HloComputation::Builder just_false(TestName() + ".false"); 88 just_false.AddInstruction( 89 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))); 90 HloComputation* false_computation = 91 module->AddEmbeddedComputation(just_false.Build()); 92 93 HloComputation::Builder call_false_builder(TestName() + ".call_false"); 94 call_false_builder.AddInstruction( 95 HloInstruction::CreateParameter(0, pred, "param")); 96 call_false_builder.AddInstruction( 97 HloInstruction::CreateCall(pred, {}, false_computation)); 98 HloComputation* call_false = 99 module->AddEmbeddedComputation(call_false_builder.Build()); 100 101 HloComputation::Builder outer(TestName() + ".outer"); 102 HloInstruction* init_value = outer.AddInstruction( 103 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))); 104 outer.AddInstruction( 105 HloInstruction::CreateWhile(pred, call_false, call_false, init_value)); 106 107 auto computation = module->AddEntryComputation(outer.Build()); 108 109 CallInliner call_inliner; 110 TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); 111 ASSERT_TRUE(mutated); 112 EXPECT_THAT( 113 computation->root_instruction()->while_condition()->root_instruction(), 114 op::Constant()); 115 EXPECT_THAT(computation->root_instruction()->while_body()->root_instruction(), 116 op::Constant()); 117 } 118 119 // Check CallInliner::Inline, which inlines a specific call without running the 120 // whole pass. 121 TEST_F(CallInlinerTest, InlineWithoutRunningPass) { 122 const Shape pred = ShapeUtil::MakeShape(PRED, {}); 123 auto module = CreateNewVerifiedModule(); 124 125 HloComputation::Builder just_false(TestName() + ".false"); 126 auto* true_constant = just_false.AddInstruction( 127 HloInstruction::CreateConstant(LiteralUtil::CreateR1<bool>({true}))); 128 auto* false_constant = just_false.AddInstruction( 129 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))); 130 TF_ASSERT_OK(false_constant->AddControlDependencyTo(true_constant)); 131 HloComputation* false_computation = 132 module->AddEmbeddedComputation(just_false.Build()); 133 134 HloComputation::Builder call_false_builder(TestName() + ".call_false"); 135 HloInstruction* call = call_false_builder.AddInstruction( 136 HloInstruction::CreateCall(pred, {}, false_computation)); 137 auto computation = module->AddEntryComputation(call_false_builder.Build()); 138 139 TF_ASSERT_OK(CallInliner::Inline(call).status()); 140 EXPECT_THAT(computation->root_instruction(), op::Constant()); 141 EXPECT_THAT(computation->root_instruction()->control_successors(), 142 ElementsAre(op::Constant())); 143 } 144 145 TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) { 146 const Shape f32 = ShapeUtil::MakeShape(F32, {}); 147 auto module = CreateNewVerifiedModule(); 148 149 HloComputation::Builder outfeeder(TestName() + ".outfeeder"); 150 auto value = outfeeder.AddInstruction( 151 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0))); 152 auto token = outfeeder.AddInstruction(HloInstruction::CreateToken()); 153 outfeeder.AddInstruction( 154 HloInstruction::CreateOutfeed(f32, value, token, /*outfeed_config=*/"")); 155 156 auto outfeed_computation = module->AddEmbeddedComputation(outfeeder.Build()); 157 158 HloComputation::Builder outer(TestName() + ".outer"); 159 outer.AddInstruction(HloInstruction::CreateCall( 160 outfeed_computation->root_instruction()->shape(), /*operands=*/{}, 161 outfeed_computation)); 162 163 module->AddEntryComputation(outer.Build()); 164 165 CallInliner call_inliner; 166 TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); 167 ASSERT_TRUE(mutated); 168 } 169 170 } // namespace 171 } // namespace xla 172