Home | History | Annotate | Download | only in tests
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <memory>
     17 
     18 #include "tensorflow/compiler/xla/client/computation.h"
     19 #include "tensorflow/compiler/xla/client/computation_builder.h"
     20 #include "tensorflow/compiler/xla/client/global_data.h"
     21 #include "tensorflow/compiler/xla/client/local_client.h"
     22 #include "tensorflow/compiler/xla/literal_util.h"
     23 #include "tensorflow/compiler/xla/protobuf_util.h"
     24 #include "tensorflow/compiler/xla/service/session.pb.h"
     25 #include "tensorflow/compiler/xla/shape_util.h"
     26 #include "tensorflow/compiler/xla/statusor.h"
     27 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
     28 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
     29 #include "tensorflow/compiler/xla/tests/test_macros.h"
     30 #include "tensorflow/compiler/xla/xla_data.pb.h"
     31 #include "tensorflow/core/platform/test.h"
     32 #include "tensorflow/core/platform/types.h"
     33 
     34 namespace xla {
     35 namespace {
     36 
     37 class ReplayTest : public ClientLibraryTestBase {};
     38 
     39 TEST_F(ReplayTest, TwoPlusTwoReplay) {
     40   // Make 2+2 computation.
     41   ComputationBuilder builder(client_, TestName());
     42   auto two = builder.ConstantR0<int32>(2);
     43   builder.Add(two, two);
     44   Computation computation = builder.Build().ConsumeValueOrDie();
     45 
     46   // Serialize it out.
     47   std::unique_ptr<SessionModule> module =
     48       computation.Snapshot().ConsumeValueOrDie();
     49 
     50   // Replay it.
     51   Computation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie();
     52 
     53   // Check signature is the same.
     54   std::unique_ptr<ProgramShape> original_shape =
     55       client_->GetComputationShape(computation).ConsumeValueOrDie();
     56   std::unique_ptr<ProgramShape> replayed_shape =
     57       client_->GetComputationShape(replayed).ConsumeValueOrDie();
     58   ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
     59 
     60   // Run it.
     61   std::unique_ptr<Literal> literal =
     62       client_
     63           ->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_)
     64           .ConsumeValueOrDie();
     65 
     66   // Expect 4.
     67   LiteralTestUtil::ExpectR0Equal<int32>(4, *literal);
     68 }
     69 
     70 XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) {
     71   // Make computation.
     72   ComputationBuilder builder(client_, TestName());
     73   auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x");
     74   auto y = builder.Parameter(1, ShapeUtil::MakeShape(S32, {}), "y");
     75   builder.Add(x, y);
     76   Computation computation = builder.Build().ConsumeValueOrDie();
     77 
     78   // Serialize it out.
     79   std::unique_ptr<SessionModule> module =
     80       computation.Snapshot().ConsumeValueOrDie();
     81 
     82   // Replay it.
     83   Computation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie();
     84 
     85   // Check signature is the same.
     86   std::unique_ptr<ProgramShape> original_shape =
     87       client_->GetComputationShape(computation).ConsumeValueOrDie();
     88   std::unique_ptr<ProgramShape> replayed_shape =
     89       client_->GetComputationShape(replayed).ConsumeValueOrDie();
     90   ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
     91 
     92   // Run it.
     93   std::unique_ptr<GlobalData> x_data =
     94       client_->TransferToServer(*Literal::CreateR0<int32>(2))
     95           .ConsumeValueOrDie();
     96   std::unique_ptr<GlobalData> y_data =
     97       client_->TransferToServer(*Literal::CreateR0<int32>(3))
     98           .ConsumeValueOrDie();
     99   std::unique_ptr<Literal> literal =
    100       client_
    101           ->ExecuteAndTransfer(replayed,
    102                                /*arguments=*/{x_data.get(), y_data.get()},
    103                                &execution_options_)
    104           .ConsumeValueOrDie();
    105 
    106   // Expect 5.
    107   LiteralTestUtil::ExpectR0Equal<int32>(5, *literal);
    108 }
    109 
    110 TEST_F(ReplayTest, MapPlusTwoOverR1) {
    111   // As above, but with map(+2) over some constant array.
    112   ComputationBuilder plus_two_builder(client_, "plus two");
    113   auto input =
    114       plus_two_builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "input");
    115   plus_two_builder.Add(input, plus_two_builder.ConstantR0<int32>(2));
    116   Computation plus_two = plus_two_builder.Build().ConsumeValueOrDie();
    117 
    118   ComputationBuilder mapper_builder(client_, TestName());
    119   auto original = mapper_builder.ConstantR1<int32>({1, 2, 3});
    120   mapper_builder.Map({original}, plus_two, {0});
    121 
    122   Computation computation = mapper_builder.Build().ConsumeValueOrDie();
    123 
    124   // Serialize it out.
    125   std::unique_ptr<SessionModule> module =
    126       computation.Snapshot().ConsumeValueOrDie();
    127 
    128   // Replay it.
    129   Computation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie();
    130 
    131   // Check signature is the same.
    132   std::unique_ptr<ProgramShape> original_shape =
    133       client_->GetComputationShape(computation).ConsumeValueOrDie();
    134   std::unique_ptr<ProgramShape> replayed_shape =
    135       client_->GetComputationShape(replayed).ConsumeValueOrDie();
    136   ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
    137 
    138   // Destroy the originals.
    139   computation.Reset();
    140   plus_two.Reset();
    141 
    142   // Run it.
    143   std::unique_ptr<Literal> literal =
    144       client_
    145           ->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_)
    146           .ConsumeValueOrDie();
    147 
    148   // Expect result.
    149   LiteralTestUtil::ExpectR1Equal<int32>({3, 4, 5}, *literal);
    150 }
    151 
    152 }  // namespace
    153 }  // namespace xla
    154