/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/user_computation.h"

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace op = xla::testing::opcode_matchers;

namespace xla {
namespace {

using UserComputationTest = ::testing::Test;

TEST_F(UserComputationTest, SimpleComputation) {
  const Shape kScalarShape = ShapeUtil::MakeShape(F32, {});
  const Shape kVectorShape = ShapeUtil::MakeShape(F32, {2});

  // Build a simple three-operation computation:
  //
  //   %constant = Constant({123, 42})
  //   %param = Param(0)
  //   %outfeed = Outfeed(%constant)
  //
  // Build the computation at two different versions and check invariants.
  ComputationHandle handle;
  handle.set_handle(123);
  UserComputation computation("TheComputation", handle);

  ConstantRequest constant_request;
  *constant_request.mutable_literal() =
      Literal::CreateR1<float>({123.0f, 42.0f})->ToProto();
  TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle constant_handle,
                          computation.AddConstantInstruction(constant_request));

  ParameterRequest param_request;
  *param_request.mutable_shape() = kScalarShape;
  param_request.set_parameter(0);
  param_request.set_name("param0");
  TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle param_handle,
                          computation.AddParameterInstruction(param_request));
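  // Attach metadata to the parameter; lowering should carry it through to the
  // resulting HloInstruction (verified on the root instruction at the end of
  // this test).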
  OpMetadata metadata;
  metadata.set_op_name("meta");
  TF_ASSERT_OK(computation.SetOpMetadata(param_handle, metadata));

  OutfeedRequest outfeed_request;
  *outfeed_request.mutable_operand() = constant_handle;
  *outfeed_request.mutable_shape() = kVectorShape;
  outfeed_request.set_outfeed_config("abc");
  TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle outfeed_handle,
                          computation.AddOutfeedInstruction(outfeed_request));

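  // The resolver maps handles of embedded computations to their lowered
  // HloComputations. This computation embeds no subcomputations, so returning
  // nullptr is sufficient here.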
  auto hlo_resolver = [](const VersionedComputationHandle& handle) {
    return nullptr;
  };
  {
    // Test the computation at the latest version. In this case, the most
    // recently added operation is an outfeed. However, the outfeed is not the
    // root because outfeeds cannot be the root of a computation.
    VersionedComputationHandle latest_version =
        computation.GetVersionedHandle();

    // Program shape should have a single scalar parameter and scalar
    // result. The outfeed instruction should not affect the program shape.
    TF_ASSERT_OK_AND_ASSIGN(
        std::shared_ptr<const ProgramShape> program_shape,
        computation.ComputeProgramShape(latest_version.version));
    ASSERT_EQ(1, program_shape->parameters_size());
    EXPECT_TRUE(
        ShapeUtil::Compatible(kScalarShape, program_shape->parameters(0)));
    EXPECT_TRUE(ShapeUtil::Compatible(kScalarShape, program_shape->result()));

    // Build the HLO computation.
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<HloComputation> hlo_computation,
        computation.BuildHloComputation(latest_version.version, hlo_resolver,
                                        DebugOptions()));
    // There should be one HloInstruction per UserComputation operation.
    EXPECT_EQ(3, hlo_computation->instruction_count());
    // The root of the computation should be the parameter instruction (not
    // the outfeed).
    EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter());
  }

  {
    // Test the computation at the version right after the parameter
    // instruction is added.
    VersionedComputationHandle version_at_param =
        computation.GetVersionedHandleAtOperation(param_handle);

    // Program shape should have a single scalar parameter and a scalar result.
    TF_ASSERT_OK_AND_ASSIGN(
        std::shared_ptr<const ProgramShape> program_shape,
        computation.ComputeProgramShape(version_at_param.version));
    ASSERT_EQ(1, program_shape->parameters_size());
    EXPECT_TRUE(
        ShapeUtil::Compatible(kScalarShape, program_shape->parameters(0)));
    EXPECT_TRUE(ShapeUtil::Compatible(kScalarShape, program_shape->result()));

    // There should be two instructions, one for the constant and one for the
    // parameter. The outfeed instruction should not be included.
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<HloComputation> hlo_computation,
        computation.BuildHloComputation(version_at_param.version, hlo_resolver,
                                        DebugOptions()));
    EXPECT_EQ(2, hlo_computation->instruction_count());
    EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter());
  }
  {
    // Test the computation at the latest version, but lowered with
    // include_unreachable_instructions set to false.
    VersionedComputationHandle latest_version =
        computation.GetVersionedHandle();

    // Build the HLO computation.
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<HloComputation> hlo_computation,
        computation.BuildHloComputation(
            latest_version.version, hlo_resolver, DebugOptions(),
            /*include_unreachable_instructions=*/false));
    // There is only one reachable instruction, the parameter.
    EXPECT_EQ(1, hlo_computation->instruction_count());
    // The root of the computation should be the parameter instruction (not
    // the outfeed).
    EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter());
    EXPECT_EQ(hlo_computation->root_instruction()->metadata().op_name(),
              "meta");
  }
}

TEST_F(UserComputationTest, EliminateScalarBroadcast) {
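  // Ask lowering to replace implicit broadcasts with explicit broadcast
  // instructions.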
  auto debug_options = DebugOptions();
  debug_options.set_xla_eliminate_hlo_implicit_broadcast(true);

  // Build a binary computation with scalar broadcast.
  //
  //  %a = Constant({123, 42})
  //  %b = Constant(1)
  //  %add = Add(%a, %b)
  ComputationHandle handle;
  handle.set_handle(123);
  UserComputation computation("TheComputation", handle);

  ConstantRequest a_request;
  *a_request.mutable_literal() =
      Literal::CreateR1<float>({123.0f, 42.0f})->ToProto();
  TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle,
                          computation.AddConstantInstruction(a_request));

  ConstantRequest b_request;
  *b_request.mutable_literal() = Literal::CreateR0<float>(1.0f)->ToProto();
  TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle,
                          computation.AddConstantInstruction(b_request));

  BinaryOpRequest add;
  add.set_binop(BINOP_ADD);
  *add.mutable_lhs() = a_handle;
  *add.mutable_rhs() = b_handle;
  TF_ASSERT_OK(computation.AddBinaryInstruction(add).status());

  auto hlo_resolver = [](const VersionedComputationHandle& handle) {
    return nullptr;
  };
  VersionedComputationHandle latest_version = computation.GetVersionedHandle();

  // Build the HLO computation.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloComputation> hlo_computation,
      computation.BuildHloComputation(latest_version.version, hlo_resolver,
                                      debug_options));
  // The binary operation has an implicit scalar broadcast, which should be
  // converted into an explicit broadcast instruction plus a binary
  // instruction.
  EXPECT_EQ(4, hlo_computation->instruction_count());
  EXPECT_THAT(hlo_computation->root_instruction(), op::Add());
  LOG(INFO) << hlo_computation->root_instruction()->ToString();
  const auto& operands = hlo_computation->root_instruction()->operands();
  ASSERT_EQ(2, operands.size());
  EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kBroadcast ||
              operands[1]->opcode() == HloOpcode::kBroadcast);
}

TEST_F(UserComputationTest, CheckImplicitBroadcastToExplicitBroadcast) {
  auto debug_options = DebugOptions();
  debug_options.set_xla_eliminate_hlo_implicit_broadcast(true);

  // Build a binary computation with degenerate broadcast.
  //
  //  %a = Param({1, 2, 3});
  //  %b = Param({1, 2, 1});
  //  %add = Add(%a, %b, {});
  ComputationHandle handle;
  handle.set_handle(123);
  UserComputation computation("TheComputation", handle);

  ParameterRequest a_request;
  *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 3});
  a_request.set_name("a");
  a_request.set_parameter(0);
  TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle,
                          computation.AddParameterInstruction(a_request));

  ParameterRequest b_request;
  *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 1});
  b_request.set_name("b");
  b_request.set_parameter(1);
  TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle,
                          computation.AddParameterInstruction(b_request));

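  // Assign a maximal sharding (the whole op placed on a single device) to %b
  // so we can check below that the reshape and broadcast introduced by
  // lowering carry sharding as well.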
  const int64 kDevice = 7;
  OpSharding sharding;
  sharding.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
  sharding.add_tile_assignment_dimensions(1);
  sharding.add_tile_assignment_devices(kDevice);

  TF_EXPECT_OK(computation.SetOpSharding(b_handle, sharding));

  BinaryOpRequest add;
  add.set_binop(BINOP_ADD);
  *add.mutable_lhs() = a_handle;
  *add.mutable_rhs() = b_handle;
  TF_ASSERT_OK(computation.AddBinaryInstruction(add).status());

  auto hlo_resolver = [](const VersionedComputationHandle& handle) {
    return nullptr;
  };
  VersionedComputationHandle latest_version = computation.GetVersionedHandle();

  // Build the HLO computation.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloComputation> hlo_computation,
      computation.BuildHloComputation(latest_version.version, hlo_resolver,
                                      debug_options));

  //    b         a
  //    |         |
  // reshape      |
  //    |         |
  // broadcast    |
  //     \       /
  //        add
  EXPECT_EQ(5, hlo_computation->instruction_count());
  ASSERT_THAT(
      hlo_computation->root_instruction(),
      op::Add(op::Parameter(), op::Broadcast(op::Reshape(op::Parameter()))));

  const HloInstruction* broadcast =
      hlo_computation->root_instruction()->operand(1);
  EXPECT_TRUE(broadcast->has_sharding());

  const HloInstruction* reshape = broadcast->operand(0);
  EXPECT_TRUE(reshape->has_sharding());
}

TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) {
  auto debug_options = DebugOptions();
  debug_options.set_xla_eliminate_hlo_implicit_broadcast(true);

  // Build a binary computation with in-dim broadcast and degenerate broadcast.
  //
  //  %a = Param({2, 3});
  //  %b = Param({2, 1, 4});
  //  %add = Add(%a, %b, {0, 1});
  ComputationHandle handle;
  handle.set_handle(123);
  UserComputation computation("TheComputation", handle);

  ParameterRequest a_request;
  *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 3});
  a_request.set_name("a");
  a_request.set_parameter(0);
  TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle,
                          computation.AddParameterInstruction(a_request));

  ParameterRequest b_request;
  *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 1, 4});
  b_request.set_name("b");
  b_request.set_parameter(1);
  TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle,
                          computation.AddParameterInstruction(b_request));

  BinaryOpRequest add;
  add.set_binop(BINOP_ADD);
  *add.mutable_lhs() = a_handle;
  *add.mutable_rhs() = b_handle;
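  // Broadcast dimensions {0, 1} map %a's dimensions onto dimensions 0 and 1
  // of %b; because %b has size 1 in dimension 1, a degenerate broadcast is
  // also required (see the expected HLO below).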
  add.add_broadcast_dimensions(0);
  add.add_broadcast_dimensions(1);
  TF_ASSERT_OK(computation.AddBinaryInstruction(add).status());

  auto hlo_resolver = [](const VersionedComputationHandle& handle) {
    return nullptr;
  };
  VersionedComputationHandle latest_version = computation.GetVersionedHandle();

  // Build the HLO computation.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloComputation> hlo_computation,
      computation.BuildHloComputation(latest_version.version, hlo_resolver,
                                      debug_options));

  // The binary operation has an in-dim broadcast and a degenerate broadcast;
  // lowering should first perform the in-dim broadcast and then convert the
  // degenerate broadcast into a reshape and a broadcast.
  //
  //    b         a
  //    |         |
  // broadcast reshape
  //    |         |
  //    |     broadcast
  //     \        /
  //        add
  EXPECT_EQ(6, hlo_computation->instruction_count());
  EXPECT_THAT(hlo_computation->root_instruction(), op::Add());
  const auto& operands = hlo_computation->root_instruction()->operands();
  ASSERT_EQ(2, operands.size());
  EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kBroadcast &&
              operands[1]->opcode() == HloOpcode::kBroadcast);
}

}  // namespace
}  // namespace xla