1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <memory> 17 18 #include "tensorflow/compiler/xla/client/computation.h" 19 #include "tensorflow/compiler/xla/client/computation_builder.h" 20 #include "tensorflow/compiler/xla/client/global_data.h" 21 #include "tensorflow/compiler/xla/client/local_client.h" 22 #include "tensorflow/compiler/xla/literal_util.h" 23 #include "tensorflow/compiler/xla/protobuf_util.h" 24 #include "tensorflow/compiler/xla/service/session.pb.h" 25 #include "tensorflow/compiler/xla/shape_util.h" 26 #include "tensorflow/compiler/xla/statusor.h" 27 #include "tensorflow/compiler/xla/tests/client_library_test_base.h" 28 #include "tensorflow/compiler/xla/tests/literal_test_util.h" 29 #include "tensorflow/compiler/xla/tests/test_macros.h" 30 #include "tensorflow/compiler/xla/xla_data.pb.h" 31 #include "tensorflow/core/platform/test.h" 32 #include "tensorflow/core/platform/types.h" 33 34 namespace xla { 35 namespace { 36 37 class ReplayTest : public ClientLibraryTestBase {}; 38 39 TEST_F(ReplayTest, TwoPlusTwoReplay) { 40 // Make 2+2 computation. 41 ComputationBuilder builder(client_, TestName()); 42 auto two = builder.ConstantR0<int32>(2); 43 builder.Add(two, two); 44 Computation computation = builder.Build().ConsumeValueOrDie(); 45 46 // Serialize it out. 47 std::unique_ptr<SessionModule> module = 48 computation.Snapshot().ConsumeValueOrDie(); 49 50 // Replay it. 51 Computation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie(); 52 53 // Check signature is the same. 54 std::unique_ptr<ProgramShape> original_shape = 55 client_->GetComputationShape(computation).ConsumeValueOrDie(); 56 std::unique_ptr<ProgramShape> replayed_shape = 57 client_->GetComputationShape(replayed).ConsumeValueOrDie(); 58 ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); 59 60 // Run it. 61 std::unique_ptr<Literal> literal = 62 client_ 63 ->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_) 64 .ConsumeValueOrDie(); 65 66 // Expect 4. 67 LiteralTestUtil::ExpectR0Equal<int32>(4, *literal); 68 } 69 70 XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { 71 // Make computation. 72 ComputationBuilder builder(client_, TestName()); 73 auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x"); 74 auto y = builder.Parameter(1, ShapeUtil::MakeShape(S32, {}), "y"); 75 builder.Add(x, y); 76 Computation computation = builder.Build().ConsumeValueOrDie(); 77 78 // Serialize it out. 79 std::unique_ptr<SessionModule> module = 80 computation.Snapshot().ConsumeValueOrDie(); 81 82 // Replay it. 83 Computation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie(); 84 85 // Check signature is the same. 86 std::unique_ptr<ProgramShape> original_shape = 87 client_->GetComputationShape(computation).ConsumeValueOrDie(); 88 std::unique_ptr<ProgramShape> replayed_shape = 89 client_->GetComputationShape(replayed).ConsumeValueOrDie(); 90 ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); 91 92 // Run it. 93 std::unique_ptr<GlobalData> x_data = 94 client_->TransferToServer(*Literal::CreateR0<int32>(2)) 95 .ConsumeValueOrDie(); 96 std::unique_ptr<GlobalData> y_data = 97 client_->TransferToServer(*Literal::CreateR0<int32>(3)) 98 .ConsumeValueOrDie(); 99 std::unique_ptr<Literal> literal = 100 client_ 101 ->ExecuteAndTransfer(replayed, 102 /*arguments=*/{x_data.get(), y_data.get()}, 103 &execution_options_) 104 .ConsumeValueOrDie(); 105 106 // Expect 5. 107 LiteralTestUtil::ExpectR0Equal<int32>(5, *literal); 108 } 109 110 TEST_F(ReplayTest, MapPlusTwoOverR1) { 111 // As above, but with map(+2) over some constant array. 112 ComputationBuilder plus_two_builder(client_, "plus two"); 113 auto input = 114 plus_two_builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "input"); 115 plus_two_builder.Add(input, plus_two_builder.ConstantR0<int32>(2)); 116 Computation plus_two = plus_two_builder.Build().ConsumeValueOrDie(); 117 118 ComputationBuilder mapper_builder(client_, TestName()); 119 auto original = mapper_builder.ConstantR1<int32>({1, 2, 3}); 120 mapper_builder.Map({original}, plus_two, {0}); 121 122 Computation computation = mapper_builder.Build().ConsumeValueOrDie(); 123 124 // Serialize it out. 125 std::unique_ptr<SessionModule> module = 126 computation.Snapshot().ConsumeValueOrDie(); 127 128 // Replay it. 129 Computation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie(); 130 131 // Check signature is the same. 132 std::unique_ptr<ProgramShape> original_shape = 133 client_->GetComputationShape(computation).ConsumeValueOrDie(); 134 std::unique_ptr<ProgramShape> replayed_shape = 135 client_->GetComputationShape(replayed).ConsumeValueOrDie(); 136 ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); 137 138 // Destroy the originals. 139 computation.Reset(); 140 plus_two.Reset(); 141 142 // Run it. 143 std::unique_ptr<Literal> literal = 144 client_ 145 ->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_) 146 .ConsumeValueOrDie(); 147 148 // Expect result. 149 LiteralTestUtil::ExpectR1Equal<int32>({3, 4, 5}, *literal); 150 } 151 152 } // namespace 153 } // namespace xla 154