Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
     17 
     18 #include <memory>
     19 #include <utility>
     20 
     21 #include "tensorflow/compiler/xla/layout_util.h"
     22 #include "tensorflow/compiler/xla/literal_util.h"
     23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     25 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
     26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     27 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
     28 #include "tensorflow/compiler/xla/shape_util.h"
     29 #include "tensorflow/compiler/xla/test.h"
     30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     31 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
     32 #include "tensorflow/compiler/xla/types.h"
     33 
     34 namespace op = xla::testing::opcode_matchers;
     35 
     36 namespace xla {
     37 namespace {
     38 
// Test fixture for the HloConstantFolding pass; every test below uses the
// HloTestBase helpers (CreateNewModule(), TestName()) unchanged.
using HloConstantFoldingTest = HloTestBase;
     40 
     41 TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
     42   HloComputation::Builder builder(TestName());
     43   HloInstruction* input = builder.AddInstruction(
     44       HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
     45   builder.AddInstruction(
     46       HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input));
     47 
     48   auto module = CreateNewModule();
     49   auto computation = module->AddEntryComputation(builder.Build());
     50 
     51   EXPECT_THAT(computation->root_instruction(), op::Convert(input));
     52 
     53   HloConstantFolding const_folder;
     54   TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
     55   EXPECT_TRUE(result);
     56 
     57   EXPECT_THAT(computation->root_instruction(), op::Constant());
     58   EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement<int64>(),
     59             42);
     60 }
     61 
     62 TEST_F(HloConstantFoldingTest, ConvertS64ToF32) {
     63   HloComputation::Builder builder(TestName());
     64   HloInstruction* input = builder.AddInstruction(
     65       HloInstruction::CreateConstant(Literal::CreateR0<int64>(42)));
     66   builder.AddInstruction(
     67       HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));
     68 
     69   auto module = CreateNewModule();
     70   auto computation = module->AddEntryComputation(builder.Build());
     71 
     72   EXPECT_THAT(computation->root_instruction(), op::Convert(input));
     73 
     74   HloConstantFolding const_folder;
     75   TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
     76   EXPECT_TRUE(result);
     77 
     78   EXPECT_THAT(computation->root_instruction(), op::Constant());
     79   EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement<float>(),
     80             42.0f);
     81 }
     82 
     83 TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) {
     84   HloComputation::Builder builder(TestName());
     85   HloInstruction* input = builder.AddInstruction(
     86       HloInstruction::CreateConstant(Literal::CreateR1<float>({42.0f, 19.0f})));
     87   builder.AddInstruction(
     88       HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input));
     89 
     90   auto module = CreateNewModule();
     91   auto computation = module->AddEntryComputation(builder.Build());
     92 
     93   EXPECT_THAT(computation->root_instruction(), op::Convert(input));
     94 
     95   HloConstantFolding const_folder;
     96   TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
     97   EXPECT_TRUE(result);
     98 
     99   EXPECT_THAT(computation->root_instruction(), op::Constant());
    100   EXPECT_EQ(computation->root_instruction()->literal().Get<int64>({0}), 42);
    101   EXPECT_EQ(computation->root_instruction()->literal().Get<int64>({1}), 19);
    102 }
    103 
    104 TEST_F(HloConstantFoldingTest, Concatenate) {
    105   const struct TestConfig {
    106     int concat_dimension;
    107     tensorflow::gtl::ArraySlice<int64> dimensions;
    108     tensorflow::gtl::ArraySlice<int64> concat_sizes;
    109   } test_configs[] = {
    110       {1, {11, 0, 7, 5, 9}, {2, 5, 7, 11}},
    111       {3, {1, 4, 17, 0, 8}, {1, 3, 9, 12}},
    112   };
    113 
    114   for (auto& test_config : test_configs) {
    115     HloComputation::Builder builder(TestName());
    116     std::vector<int64> dimensions(test_config.dimensions.begin(),
    117                                   test_config.dimensions.end());
    118     int64 concat_size = 0;
    119     std::vector<HloInstruction*> operands;
    120     for (auto csize : test_config.concat_sizes) {
    121       dimensions[test_config.concat_dimension] = csize;
    122       concat_size += csize;
    123       auto literal = Literal::CreateFromDimensions(F32, dimensions);
    124       HloInstruction* insn = builder.AddInstruction(
    125           HloInstruction::CreateConstant(std::move(literal)));
    126       operands.push_back(insn);
    127     }
    128     dimensions[test_config.concat_dimension] = concat_size;
    129     Shape shape = ShapeUtil::MakeShape(F32, dimensions);
    130     builder.AddInstruction(HloInstruction::CreateConcatenate(
    131         shape, operands, test_config.concat_dimension));
    132     auto module = CreateNewModule();
    133     auto computation = module->AddEntryComputation(builder.Build());
    134 
    135     HloConstantFolding const_folder;
    136     TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
    137     EXPECT_TRUE(result);
    138 
    139     HloInstruction* root = computation->root_instruction();
    140     EXPECT_THAT(root, op::Constant());
    141     EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape));
    142   }
    143 }
    144 
    145 TEST_F(HloConstantFoldingTest, Slice) {
    146   HloComputation::Builder builder(TestName());
    147   const int64 dimensions[] = {11, 8, 7, 5, 9};
    148   const int64 slice_start[] = {4, 2, 3, 1, 5};
    149   const int64 slice_limits[] = {10, 8, 6, 5, 9};
    150   const int64 slice_strides[] = {1, 1, 1, 1, 1};
    151   TF_ASSERT_OK_AND_ASSIGN(auto literal,
    152                           LiteralTestUtil::CreateRandomLiteral<F32>(
    153                               ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
    154   HloInstruction* literal_instruction = builder.AddInstruction(
    155       HloInstruction::CreateConstant(std::move(literal)));
    156   Shape shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4});
    157   builder.AddInstruction(HloInstruction::CreateSlice(
    158       shape, literal_instruction, slice_start, slice_limits, slice_strides));
    159   auto module = CreateNewModule();
    160   auto computation = module->AddEntryComputation(builder.Build());
    161 
    162   HloConstantFolding const_folder;
    163   TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
    164   EXPECT_TRUE(result);
    165 
    166   HloInstruction* root = computation->root_instruction();
    167   EXPECT_THAT(root, op::Constant());
    168   EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape));
    169 }
    170 
    171 TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
    172   HloComputation::Builder builder(TestName());
    173   const int64 dimensions[] = {11, 8, 7, 5, 9};
    174   TF_ASSERT_OK_AND_ASSIGN(auto literal,
    175                           LiteralTestUtil::CreateRandomLiteral<F32>(
    176                               ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
    177   auto literal_clone = literal->Literal::CloneToUnique();
    178   HloInstruction* literal_instruction = builder.AddInstruction(
    179       HloInstruction::CreateConstant(std::move(literal)));
    180   Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
    181   const int64 permutation[] = {1, 2, 0, 4, 3};
    182   builder.AddInstruction(
    183       HloInstruction::CreateTranspose(shape, literal_instruction, permutation));
    184   auto module = CreateNewModule();
    185   auto computation = module->AddEntryComputation(builder.Build());
    186 
    187   HloConstantFolding const_folder;
    188   TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
    189   EXPECT_TRUE(result);
    190 
    191   HloInstruction* root = computation->root_instruction();
    192   EXPECT_THAT(root, op::Constant());
    193   EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), shape));
    194 
    195   using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
    196   bool matched = true;
    197   root->literal().EachCell<NativeT>(
    198       [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT value) {
    199         std::vector<int64> rindexes = Permute(permutation, indices);
    200         matched = matched && (value == literal_clone->Get<NativeT>(rindexes));
    201       });
    202   EXPECT_TRUE(matched);
    203 }
    204 
    205 }  // namespace
    206 }  // namespace xla
    207