/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <unistd.h>
#include <memory>

#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace {

class InfeedTest : public ClientLibraryTestBase {
 protected:
  // Transfers the given literal to the infeed interface of the device, and
  // checks that the data returned by the Infeed HLO is the same as the
  // literal.
  void TestInfeedRoundTrip(const Literal& literal) {
    // TODO(b/31037751) Explicitly reset the Infeed state so that the
    // test is not affected by the state from the previous tests by
    // adding ClearInfeed if necessary when it is implemented. For now
    // don't use ResetDevice since it is not implemented on CPU.
    ASSERT_IS_OK(client_->TransferToInfeed(literal));
    XlaBuilder builder(TestName());
    Infeed(&builder, literal.shape());
    if (literal.shape().IsTuple()) {
      // TODO(b/30609564): Use ComputeAndCompareLiteral instead.
      ComputeAndCompareTuple(&builder, literal, {});
    } else {
      ComputeAndCompareLiteral(&builder, literal, {});
    }
  }
};

TEST_F(InfeedTest, SingleInfeedR0Bool) {
  TestInfeedRoundTrip(LiteralUtil::CreateR0<bool>(true));
}

TEST_F(InfeedTest, SingleInfeedR1U32) {
  TestInfeedRoundTrip(LiteralUtil::CreateR1<uint32>({1, 2, 3}));
}

TEST_F(InfeedTest, SingleInfeedR2F32) {
  TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
}

TEST_F(InfeedTest, SingleInfeedR3F32) {
  TestInfeedRoundTrip(
      LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
                             {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
}

TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) {
  const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2});
  const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0});

  TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
      {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
       {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
      r3_dim0minor));

  TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
      {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
       {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
      r3_dim0major));
}

TEST_F(InfeedTest, SingleInfeedR4S32) {
  TestInfeedRoundTrip(LiteralUtil::CreateR4(
      {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
       {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
}

TEST_F(InfeedTest, SingleInfeedTuple) {
  TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR1<uint32>({1, 2, 3}),
       LiteralUtil::CreateR0<bool>(false)}));
}

TEST_F(InfeedTest, SingleInfeedEmptyTuple) {
  TestInfeedRoundTrip(LiteralUtil::MakeTuple({}));
}

// Tests the Infeed operation used in a while loop, as in the code below. The
// computation is launched asynchronously, and then the infeed data is
// transferred.
//
// float acc = 0.0f;
// while (acc < 40.0f) {
//   acc += reduce_add(Infeed());
// }
// return acc;
//
// TODO(b/30671675) enable this test once asynchronous execution is
// implemented for CPU.
TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) {
  XlaBuilder builder(TestName());
  const auto infeed_shape = ShapeUtil::MakeShape(F32, {3});
  const auto result_shape = ShapeUtil::MakeShape(F32, {});

  // Create a computation for the condition: repeat while (prev < 40.0f) holds.
  XlaComputation condition;
  {
    XlaBuilder builder("condition");
    auto prev = Parameter(&builder, 0, result_shape, "prev");
    Gt(ConstantR0<float>(&builder, 40.0f), prev);
    condition = builder.Build().ConsumeValueOrDie();
  }
  // Create a computation for the body: add the reduced value of the Infeed
  // data to the result variable.
  XlaComputation body;
  {
    XlaBuilder builder("body");
    auto prev = Parameter(&builder, 0, result_shape, "prev");
    auto infeed = Infeed(&builder, infeed_shape);
    auto addend = Reduce(infeed, ConstantR0<float>(&builder, 0.0f),
                         CreateScalarAddComputation(F32, &builder), {0});
    Add(prev, addend);
    body = builder.Build().ConsumeValueOrDie();
  }
  // Create a While node with computations for the condition and the body.
  auto init = ConstantR0<float>(&builder, 0.0f);
  While(condition, body, init);

  // Build and asynchronously launch the computation.
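  // The computation cannot finish until the infeed data sent below has been
  // consumed, so Execute() runs on a separate thread; deleting that thread
  // later joins it.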
  auto computation = builder.Build().ConsumeValueOrDie();
  std::unique_ptr<GlobalData> result;
  tensorflow::Thread* computation_thread =
      tensorflow::Env::Default()->StartThread(
          tensorflow::ThreadOptions{}, "computation_thread", [&] {
            result = client_->Execute(computation, {}, &execution_options_)
                         .ValueOrDie();
          });

  // Send 5 infeed buffers of shape F32[3].
  ASSERT_IS_OK(
      client_->TransferToInfeed(LiteralUtil::CreateR1<float>({1, 2, 3})));
  ASSERT_IS_OK(
      client_->TransferToInfeed(LiteralUtil::CreateR1<float>({4, 5, 6})));
  ASSERT_IS_OK(
      client_->TransferToInfeed(LiteralUtil::CreateR1<float>({7, 8, 9})));
  ASSERT_IS_OK(
      client_->TransferToInfeed(LiteralUtil::CreateR1<float>({10, 11, 12})));
  ASSERT_IS_OK(
      client_->TransferToInfeed(LiteralUtil::CreateR1<float>({13, 14, 15})));

  delete computation_thread;  // Joins the thread.
  auto result_literal = client_->Transfer(*result).ConsumeValueOrDie();

  // Only the first 3 infeed buffers should be accumulated (6 + 15 + 24 = 45);
  // at that point the accumulator is no longer less than 40.0f.
  LiteralTestUtil::ExpectR0Near<float>(45.0f, result_literal, ErrorSpec{1e-7});
}

// Tests two Infeed operations with a total order. The order is enforced by
// using the result of the first while loop as the initial value of the second
// while loop. The shapes of both Infeeds are Tuples, where the first tuple
// element (R1F32) is the data to reduce and accumulate, and the second tuple
// element (PRED) indicates whether the loop should continue. The computation
// is launched asynchronously, and then the infeed data is transferred.
//
// float acc = 0.0f;
// continue = true;
// while (continue) {
//   (data, continue) = Infeed(shape1);
//   acc += reduce_add(data)
// }
// continue = true;
// while (continue) {
//   (data, continue) = Infeed(shape2);
//   acc += reduce_add(data)
// }
// return acc;
//
// TODO(b/30671675) enable this test once asynchronous execution is
// implemented for CPU.
TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
  XlaBuilder builder(TestName());
  const auto infeed1_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeShape(PRED, {})});
  const auto infeed2_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(PRED, {})});
  const auto result_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(PRED, {})});

  // Create a computation for the condition: repeat until the second tuple
  // element is false.
  XlaComputation condition;
  {
    XlaBuilder builder("condition");
    auto prev = Parameter(&builder, 0, result_shape, "prev");
    GetTupleElement(prev, 1);
    condition = builder.Build().ConsumeValueOrDie();
  }

  // A lambda that builds the body computation of a while loop with the given
  // infeed shape, and returns ownership of the computation.
  //
  // The body adds the reduced value of the Infeed data (first tuple element)
  // to the previous accumulator, and returns the accumulator and the continue
  // flag (second tuple element) as a tuple.
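  //
  // In pseudocode: body(prev) = (prev[0] + reduce_add(infeed[0]), infeed[1]).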
  const auto build_body = [&result_shape](const Shape& infeed_shape) {
    XlaComputation body;
    XlaBuilder builder("body");
    auto prev = Parameter(&builder, 0, result_shape, "prev");
    auto infeed = Infeed(&builder, infeed_shape);
    auto addend =
        Reduce(GetTupleElement(infeed, 0), ConstantR0<float>(&builder, 0.0f),
               CreateScalarAddComputation(F32, &builder), {0});
    auto result = Add(GetTupleElement(prev, 0), addend);
    Tuple(&builder, {result, GetTupleElement(infeed, 1)});
    return builder.Build().ConsumeValueOrDie();
  };

  // Create the first while loop with infeed1_shape.
  auto init = Tuple(&builder, {ConstantR0<float>(&builder, 0.0f),
                               ConstantR0<bool>(&builder, true)});
  auto while1 = While(condition, build_body(infeed1_shape), init);
  auto result1 = Tuple(
      &builder, {GetTupleElement(while1, 0), ConstantR0<bool>(&builder, true)});

  // Create the second while loop with infeed2_shape. Note that the result from
  // the first while loop is used as the initial value.
  auto while2 = While(condition, build_body(infeed2_shape), result1);
  GetTupleElement(while2, 0);

  // Build the computation.
  auto computation = builder.Build().ConsumeValueOrDie();

  // Send the first 4 infeed buffers of shape Tuple(F32[2], PRED).
  ASSERT_IS_OK(client_->TransferToInfeed(
      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({1, 2}),
                                        LiteralUtil::CreateR0<bool>(true)})));
  ASSERT_IS_OK(client_->TransferToInfeed(
      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({3, 4}),
                                        LiteralUtil::CreateR0<bool>(true)})));
  ASSERT_IS_OK(client_->TransferToInfeed(
      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({5, 6}),
                                        LiteralUtil::CreateR0<bool>(true)})));
  ASSERT_IS_OK(client_->TransferToInfeed(
      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({7, 8}),
                                        LiteralUtil::CreateR0<bool>(false)})));

  // Asynchronously launch the execution on the device.
  std::unique_ptr<GlobalData> result;
  tensorflow::Thread* computation_thread =
      tensorflow::Env::Default()->StartThread(
          tensorflow::ThreadOptions{}, "computation_thread", [&] {
            result = client_->Execute(computation, {}, &execution_options_)
                         .ValueOrDie();
          });

  // Wait for a second to ensure that the execution is blocked waiting on the
  // infeed data, and then send the remaining infeed data of shape
  // Tuple(F32[3], PRED).
  sleep(1);
  ASSERT_IS_OK(client_->TransferToInfeed(
      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({1, 2, 3}),
                                        LiteralUtil::CreateR0<bool>(true)})));
  ASSERT_IS_OK(client_->TransferToInfeed(
      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({7, 8, 9}),
                                        LiteralUtil::CreateR0<bool>(false)})));
  ASSERT_IS_OK(client_->TransferToInfeed(
      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({4, 5, 6}),
                                        LiteralUtil::CreateR0<bool>(true)})));

  // Wait for the execution to be done, and transfer the result.
  delete computation_thread;  // Joins the thread.
  auto result_literal = client_->Transfer(*result).ConsumeValueOrDie();

  // Only the first 6 infeed buffers should be accumulated: the first while
  // loop consumes the four Tuple(F32[2], PRED) buffers (3 + 7 + 11 + 15 = 36),
  // and the second consumes the first two Tuple(F32[3], PRED) buffers
  // (6 + 24 = 30), for a total of 66.
  LiteralTestUtil::ExpectR0Near<float>(66.0f, result_literal, ErrorSpec{1e-7});
}

}  // namespace
}  // namespace xla