1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/inliner.h" 17 18 #include <memory> 19 #include <utility> 20 21 #include "tensorflow/compiler/xla/literal_util.h" 22 #include "tensorflow/compiler/xla/ptr_util.h" 23 #include "tensorflow/compiler/xla/service/hlo_computation.h" 24 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 25 #include "tensorflow/compiler/xla/service/hlo_matchers.h" 26 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 27 #include "tensorflow/compiler/xla/shape_util.h" 28 #include "tensorflow/compiler/xla/test.h" 29 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 30 #include "tensorflow/compiler/xla/tests/literal_test_util.h" 31 #include "tensorflow/compiler/xla/xla_data.pb.h" 32 33 namespace op = xla::testing::opcode_matchers; 34 35 namespace xla { 36 namespace { 37 38 using InlinerTest = HloTestBase; 39 40 // Test that `map` with `max` is transformed to `max` 41 TEST_F(InlinerTest, MapMax) { 42 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 43 44 auto max_builder = HloComputation::Builder(TestName()); 45 auto param1 = max_builder.AddInstruction( 46 HloInstruction::CreateParameter(0, r0f32, "x")); 47 auto param2 = max_builder.AddInstruction( 48 HloInstruction::CreateParameter(1, r0f32, "y")); 49 max_builder.AddInstruction(HloInstruction::CreateBinary( 50 param1->shape(), HloOpcode::kMaximum, param1, param2)); 51 auto max_f32 = max_builder.Build(); 52 53 auto builder = HloComputation::Builder("MapMaxFunction"); 54 auto lhs = builder.AddInstruction( 55 HloInstruction::CreateConstant(Literal::CreateR1<float>({1, 2, 3, 4}))); 56 auto rhs = builder.AddInstruction( 57 HloInstruction::CreateConstant(Literal::CreateR1<float>({4, 3, 2, 1}))); 58 builder.AddInstruction( 59 HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get())); 60 61 auto computation = builder.Build(); 62 auto hlo_module = CreateNewModule(); 63 hlo_module->AddEmbeddedComputation(std::move(max_f32)); 64 hlo_module->AddEntryComputation(std::move(computation)); 65 66 Inliner inliner; 67 EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); 68 EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), 69 op::Maximum(lhs, rhs)); 70 71 // Verify execution on CPU. 72 auto result = ExecuteAndTransfer(std::move(hlo_module), {}); 73 auto expected = Literal::CreateR1<float>({4, 3, 3, 4}); 74 LiteralTestUtil::ExpectEqual(*result, *expected); 75 } 76 77 // Test that `constant` function is changed to `broadcast`. 78 TEST_F(InlinerTest, MapConstant) { 79 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 80 81 auto const2_builder = HloComputation::Builder(TestName()); 82 auto param1 = const2_builder.AddInstruction( 83 HloInstruction::CreateParameter(0, r0f32, "x")); 84 (void)param1; 85 const2_builder.AddInstruction( 86 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0f))); 87 auto const2_f32 = const2_builder.Build(); 88 89 auto builder = HloComputation::Builder("MapConstFunction"); 90 auto lhs = builder.AddInstruction(HloInstruction::CreateConstant( 91 Literal::CreateR2<float>({{1, 2, 3, 4}, {5, 6, 7, 8}}))); 92 builder.AddInstruction( 93 HloInstruction::CreateMap(lhs->shape(), {lhs}, const2_f32.get())); 94 95 auto computation = builder.Build(); 96 auto hlo_module = CreateNewModule(); 97 hlo_module->AddEmbeddedComputation(std::move(const2_f32)); 98 hlo_module->AddEntryComputation(std::move(computation)); 99 HloInstruction* root = hlo_module->entry_computation()->root_instruction(); 100 Inliner inliner; 101 EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); 102 root = hlo_module->entry_computation()->root_instruction(); 103 EXPECT_THAT(root, op::Broadcast(op::Constant())); 104 105 // Verify execution on CPU. 106 auto result = ExecuteAndTransfer(std::move(hlo_module), {}); 107 auto expected = Literal::CreateR2<float>({{2, 2, 2, 2}, {2, 2, 2, 2}}); 108 LiteralTestUtil::ExpectEqual(*result, *expected); 109 } 110 111 TEST_F(InlinerTest, MapSubtractOppositeOrder) { 112 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 113 114 // Note that the parameter ordinals are in the opposite order to their 115 // position as operands 116 auto max_builder = HloComputation::Builder(TestName()); 117 auto param1 = max_builder.AddInstruction( 118 HloInstruction::CreateParameter(1, r0f32, "x")); 119 auto param2 = max_builder.AddInstruction( 120 HloInstruction::CreateParameter(0, r0f32, "y")); 121 max_builder.AddInstruction(HloInstruction::CreateBinary( 122 param1->shape(), HloOpcode::kSubtract, param1, param2)); 123 auto max_f32 = max_builder.Build(); 124 125 auto builder = HloComputation::Builder("MapSubFunction"); 126 auto lhs = builder.AddInstruction( 127 HloInstruction::CreateConstant(Literal::CreateR1<float>({1, 2, 3, 4}))); 128 auto rhs = builder.AddInstruction( 129 HloInstruction::CreateConstant(Literal::CreateR1<float>({4, 3, 2, 1}))); 130 builder.AddInstruction( 131 HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get())); 132 133 auto computation = builder.Build(); 134 auto hlo_module = CreateNewModule(); 135 hlo_module->AddEmbeddedComputation(std::move(max_f32)); 136 hlo_module->AddEntryComputation(std::move(computation)); 137 138 Inliner inliner; 139 EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); 140 EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), 141 op::Subtract(rhs, lhs)); 142 143 // Verify execution on CPU. 144 auto result = ExecuteAndTransfer(std::move(hlo_module), {}); 145 auto expected = Literal::CreateR1<float>({3, 1, -1, -3}); 146 LiteralTestUtil::ExpectEqual(*result, *expected); 147 } 148 149 150 } // namespace 151 } // namespace xla 152