1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" 17 18 #include <memory> 19 #include <utility> 20 21 #include "tensorflow/compiler/xla/layout_util.h" 22 #include "tensorflow/compiler/xla/literal_util.h" 23 #include "tensorflow/compiler/xla/service/hlo_computation.h" 24 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 25 #include "tensorflow/compiler/xla/service/hlo_matchers.h" 26 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 27 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" 28 #include "tensorflow/compiler/xla/shape_util.h" 29 #include "tensorflow/compiler/xla/test.h" 30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 31 #include "tensorflow/compiler/xla/tests/literal_test_util.h" 32 #include "tensorflow/compiler/xla/types.h" 33 34 namespace op = xla::testing::opcode_matchers; 35 36 namespace xla { 37 namespace { 38 39 using HloConstantFoldingTest = HloTestBase; 40 41 TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { 42 HloComputation::Builder builder(TestName()); 43 HloInstruction* input = builder.AddInstruction( 44 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); 45 builder.AddInstruction( 46 HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input)); 47 48 auto module = CreateNewModule(); 49 auto computation = module->AddEntryComputation(builder.Build()); 50 51 EXPECT_THAT(computation->root_instruction(), op::Convert(input)); 52 53 HloConstantFolding const_folder; 54 TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); 55 EXPECT_TRUE(result); 56 57 EXPECT_THAT(computation->root_instruction(), op::Constant()); 58 EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement<int64>(), 59 42); 60 } 61 62 TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { 63 HloComputation::Builder builder(TestName()); 64 HloInstruction* input = builder.AddInstruction( 65 HloInstruction::CreateConstant(Literal::CreateR0<int64>(42))); 66 builder.AddInstruction( 67 HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); 68 69 auto module = CreateNewModule(); 70 auto computation = module->AddEntryComputation(builder.Build()); 71 72 EXPECT_THAT(computation->root_instruction(), op::Convert(input)); 73 74 HloConstantFolding const_folder; 75 TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); 76 EXPECT_TRUE(result); 77 78 EXPECT_THAT(computation->root_instruction(), op::Constant()); 79 EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement<float>(), 80 42.0f); 81 } 82 83 TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { 84 HloComputation::Builder builder(TestName()); 85 HloInstruction* input = builder.AddInstruction( 86 HloInstruction::CreateConstant(Literal::CreateR1<float>({42.0f, 19.0f}))); 87 builder.AddInstruction( 88 HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input)); 89 90 auto module = CreateNewModule(); 91 auto computation = module->AddEntryComputation(builder.Build()); 92 93 EXPECT_THAT(computation->root_instruction(), op::Convert(input)); 94 95 HloConstantFolding const_folder; 96 TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); 97 EXPECT_TRUE(result); 98 99 EXPECT_THAT(computation->root_instruction(), op::Constant()); 100 EXPECT_EQ(computation->root_instruction()->literal().Get<int64>({0}), 42); 101 EXPECT_EQ(computation->root_instruction()->literal().Get<int64>({1}), 19); 102 } 103 104 TEST_F(HloConstantFoldingTest, Concatenate) { 105 const struct TestConfig { 106 int concat_dimension; 107 tensorflow::gtl::ArraySlice<int64> dimensions; 108 tensorflow::gtl::ArraySlice<int64> concat_sizes; 109 } test_configs[] = { 110 {1, {11, 0, 7, 5, 9}, {2, 5, 7, 11}}, 111 {3, {1, 4, 17, 0, 8}, {1, 3, 9, 12}}, 112 }; 113 114 for (auto& test_config : test_configs) { 115 HloComputation::Builder builder(TestName()); 116 std::vector<int64> dimensions(test_config.dimensions.begin(), 117 test_config.dimensions.end()); 118 int64 concat_size = 0; 119 std::vector<HloInstruction*> operands; 120 for (auto csize : test_config.concat_sizes) { 121 dimensions[test_config.concat_dimension] = csize; 122 concat_size += csize; 123 auto literal = Literal::CreateFromDimensions(F32, dimensions); 124 HloInstruction* insn = builder.AddInstruction( 125 HloInstruction::CreateConstant(std::move(literal))); 126 operands.push_back(insn); 127 } 128 dimensions[test_config.concat_dimension] = concat_size; 129 Shape shape = ShapeUtil::MakeShape(F32, dimensions); 130 builder.AddInstruction(HloInstruction::CreateConcatenate( 131 shape, operands, test_config.concat_dimension)); 132 auto module = CreateNewModule(); 133 auto computation = module->AddEntryComputation(builder.Build()); 134 135 HloConstantFolding const_folder; 136 TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); 137 EXPECT_TRUE(result); 138 139 HloInstruction* root = computation->root_instruction(); 140 EXPECT_THAT(root, op::Constant()); 141 EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape)); 142 } 143 } 144 145 TEST_F(HloConstantFoldingTest, Slice) { 146 HloComputation::Builder builder(TestName()); 147 const int64 dimensions[] = {11, 8, 7, 5, 9}; 148 const int64 slice_start[] = {4, 2, 3, 1, 5}; 149 const int64 slice_limits[] = {10, 8, 6, 5, 9}; 150 const int64 slice_strides[] = {1, 1, 1, 1, 1}; 151 TF_ASSERT_OK_AND_ASSIGN(auto literal, 152 LiteralTestUtil::CreateRandomLiteral<F32>( 153 ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); 154 HloInstruction* literal_instruction = builder.AddInstruction( 155 HloInstruction::CreateConstant(std::move(literal))); 156 Shape shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4}); 157 builder.AddInstruction(HloInstruction::CreateSlice( 158 shape, literal_instruction, slice_start, slice_limits, slice_strides)); 159 auto module = CreateNewModule(); 160 auto computation = module->AddEntryComputation(builder.Build()); 161 162 HloConstantFolding const_folder; 163 TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); 164 EXPECT_TRUE(result); 165 166 HloInstruction* root = computation->root_instruction(); 167 EXPECT_THAT(root, op::Constant()); 168 EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape)); 169 } 170 171 TEST_F(HloConstantFoldingTest, TransposeConstantFold) { 172 HloComputation::Builder builder(TestName()); 173 const int64 dimensions[] = {11, 8, 7, 5, 9}; 174 TF_ASSERT_OK_AND_ASSIGN(auto literal, 175 LiteralTestUtil::CreateRandomLiteral<F32>( 176 ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); 177 auto literal_clone = literal->Literal::CloneToUnique(); 178 HloInstruction* literal_instruction = builder.AddInstruction( 179 HloInstruction::CreateConstant(std::move(literal))); 180 Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5}); 181 const int64 permutation[] = {1, 2, 0, 4, 3}; 182 builder.AddInstruction( 183 HloInstruction::CreateTranspose(shape, literal_instruction, permutation)); 184 auto module = CreateNewModule(); 185 auto computation = module->AddEntryComputation(builder.Build()); 186 187 HloConstantFolding const_folder; 188 TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); 189 EXPECT_TRUE(result); 190 191 HloInstruction* root = computation->root_instruction(); 192 EXPECT_THAT(root, op::Constant()); 193 EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), shape)); 194 195 using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type; 196 bool matched = true; 197 root->literal().EachCell<NativeT>( 198 [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT value) { 199 std::vector<int64> rindexes = Permute(permutation, indices); 200 matched = matched && (value == literal_clone->Get<NativeT>(rindexes)); 201 }); 202 EXPECT_TRUE(matched); 203 } 204 205 } // namespace 206 } // namespace xla 207