1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <initializer_list> 17 #include <memory> 18 #include <vector> 19 20 #include "tensorflow/compiler/xla/client/client_library.h" 21 #include "tensorflow/compiler/xla/client/local_client.h" 22 #include "tensorflow/compiler/xla/client/xla_builder.h" 23 #include "tensorflow/compiler/xla/layout_util.h" 24 #include "tensorflow/compiler/xla/literal.h" 25 #include "tensorflow/compiler/xla/service/device_memory_allocator.h" 26 #include "tensorflow/compiler/xla/service/local_service.h" 27 #include "tensorflow/compiler/xla/service/platform_util.h" 28 #include "tensorflow/compiler/xla/service/shaped_buffer.h" 29 #include "tensorflow/compiler/xla/service/transfer_manager.h" 30 #include "tensorflow/compiler/xla/shape_util.h" 31 #include "tensorflow/compiler/xla/statusor.h" 32 #include "tensorflow/compiler/xla/test.h" 33 #include "tensorflow/compiler/xla/test_helpers.h" 34 #include "tensorflow/compiler/xla/tests/literal_test_util.h" 35 #include "tensorflow/compiler/xla/tests/local_client_test_base.h" 36 #include "tensorflow/compiler/xla/tests/test_macros.h" 37 #include "tensorflow/compiler/xla/tests/test_utils.h" 38 #include "tensorflow/compiler/xla/xla_data.pb.h" 39 #include "tensorflow/core/platform/env.h" 40 #include "tensorflow/core/platform/logging.h" 41 #include 
"tensorflow/core/platform/stream_executor_no_cuda.h" 42 #include "tensorflow/core/platform/test.h" 43 #include "tensorflow/core/platform/test_benchmark.h" 44 45 namespace xla { 46 namespace { 47 48 using ::testing::ContainsRegex; 49 50 class LocalClientExecuteTest : public LocalClientTestBase { 51 protected: 52 ErrorSpec error_spec_{0.0001}; 53 }; 54 55 XLA_TEST_F(LocalClientExecuteTest, Constant) { 56 XlaBuilder builder(TestName()); 57 ConstantR0<float>(&builder, 123.0f); 58 59 ScopedShapedBuffer result = 60 ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); 61 LiteralTestUtil::ExpectR0Near<float>(123.f, ShapedBufferToLiteral(result), 62 error_spec_); 63 } 64 65 XLA_TEST_F(LocalClientExecuteTest, AddScalars) { 66 XlaBuilder builder(TestName()); 67 auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); 68 auto y = ConstantR0<float>(&builder, 123.0f); 69 Add(x, y); 70 71 auto x_value = LiteralToShapedBuffer(LiteralUtil::CreateR0<float>(42.0f)); 72 ScopedShapedBuffer result = 73 ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_value}); 74 LiteralTestUtil::ExpectR0Near<float>(165.f, ShapedBufferToLiteral(result), 75 error_spec_); 76 } 77 78 XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) { 79 XlaBuilder builder(TestName()); 80 auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {0}), "x"); 81 auto y = ConstantR1<float>(&builder, {}); 82 Add(x, y); 83 84 auto x_array = LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({})); 85 ScopedShapedBuffer result = 86 ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array}); 87 LiteralTestUtil::ExpectR1Near<float>({}, ShapedBufferToLiteral(result), 88 error_spec_); 89 } 90 91 XLA_TEST_F(LocalClientExecuteTest, AddVectors) { 92 XlaBuilder builder(TestName()); 93 auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x"); 94 auto y = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f}); 95 Add(x, y); 96 97 auto x_array = 98 LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 
1.0f, 2.0f})); 99 ScopedShapedBuffer result = 100 ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array}); 101 LiteralTestUtil::ExpectR1Near<float>( 102 {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_); 103 } 104 105 XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) { 106 XlaBuilder builder(TestName()); 107 auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x"); 108 auto y = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f}); 109 Add(x, y); 110 111 auto x_array = 112 LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f})); 113 ExecutionProfile profile; 114 ScopedShapedBuffer result = ExecuteLocallyOrDie( 115 builder.Build().ValueOrDie(), {&x_array}, DefaultExecutableBuildOptions(), 116 DefaultExecutableRunOptions().set_execution_profile(&profile)); 117 118 LiteralTestUtil::ExpectR1Near<float>( 119 {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_); 120 EXPECT_GT(profile.compute_and_transfer_time_ns(), 0); 121 } 122 123 XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { 124 XlaBuilder builder(TestName()); 125 auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); 126 auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); 127 Add(x, y); 128 auto computation = builder.Build().ConsumeValueOrDie(); 129 130 // Create x as a col-major array. 131 auto x_array = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout( 132 {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}))); 133 EXPECT_TRUE(LayoutUtil::Equal(x_array.on_device_shape().layout(), 134 LayoutUtil::MakeLayout({0, 1}))); 135 136 // Create y as a row-major array. 
137 auto y_array = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout( 138 {{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0}))); 139 EXPECT_TRUE(LayoutUtil::Equal(y_array.on_device_shape().layout(), 140 LayoutUtil::MakeLayout({1, 0}))); 141 142 ScopedShapedBuffer result_colmaj = 143 ExecuteLocallyOrDie(computation, {&x_array, &y_array}); 144 LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}}, 145 ShapedBufferToLiteral(result_colmaj), 146 error_spec_); 147 148 // Run with the parameter values in a different order. 149 ScopedShapedBuffer result_param_swap = 150 ExecuteLocallyOrDie(computation, {&y_array, &x_array}); 151 LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}}, 152 ShapedBufferToLiteral(result_param_swap), 153 error_spec_); 154 } 155 156 XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { 157 XlaBuilder builder(TestName()); 158 auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); 159 auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); 160 Add(x, y); 161 auto computation = builder.Build().ConsumeValueOrDie(); 162 163 auto x_array = LiteralToShapedBuffer( 164 LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}})); 165 auto y_array = LiteralToShapedBuffer( 166 LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}})); 167 168 // Run with col-major result layout. 169 ScopedShapedBuffer result_colmaj = ExecuteLocallyOrDie( 170 computation, {&x_array, &y_array}, 171 DefaultExecutableBuildOptions().set_result_layout( 172 ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {0, 1})), 173 DefaultExecutableRunOptions()); 174 EXPECT_TRUE(LayoutUtil::Equal(result_colmaj.on_device_shape().layout(), 175 LayoutUtil::MakeLayout({0, 1}))); 176 LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}}, 177 ShapedBufferToLiteral(result_colmaj), 178 error_spec_); 179 180 // Run with row-major result layout. 
181 ScopedShapedBuffer result_rowmaj = ExecuteLocallyOrDie( 182 computation, {&x_array, &y_array}, 183 DefaultExecutableBuildOptions().set_result_layout( 184 ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {1, 0})), 185 DefaultExecutableRunOptions()); 186 EXPECT_TRUE(LayoutUtil::Equal(result_rowmaj.on_device_shape().layout(), 187 LayoutUtil::MakeLayout({1, 0}))); 188 LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}}, 189 ShapedBufferToLiteral(result_rowmaj), 190 error_spec_); 191 } 192 193 XLA_TEST_F(LocalClientExecuteTest, TupleResult) { 194 XlaBuilder builder(TestName()); 195 auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); 196 auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); 197 Tuple(&builder, {x, y, x}); 198 auto computation = builder.Build().ConsumeValueOrDie(); 199 200 auto x_array = LiteralToShapedBuffer( 201 LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}})); 202 auto y_array = LiteralToShapedBuffer( 203 LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}})); 204 205 ScopedShapedBuffer result = 206 ExecuteLocallyOrDie(computation, {&x_array, &y_array}); 207 208 EXPECT_TRUE(result.on_host_shape().IsTuple()); 209 EXPECT_EQ(3, ShapeUtil::TupleElementCount(result.on_host_shape())); 210 211 Literal result_literal = ShapedBufferToLiteral(result); 212 LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}, 213 LiteralSlice(result_literal, {0})); 214 LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}, 215 LiteralSlice(result_literal, {1})); 216 LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}, 217 LiteralSlice(result_literal, {2})); 218 } 219 220 XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { 221 XlaBuilder builder(TestName()); 222 auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); 223 auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); 224 auto inner_tuple = 
Tuple(&builder, {x, y, x}); 225 Tuple(&builder, {inner_tuple, x}); 226 auto computation = builder.Build().ConsumeValueOrDie(); 227 228 auto x_array = LiteralToShapedBuffer( 229 LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}})); 230 auto y_array = LiteralToShapedBuffer( 231 LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}})); 232 233 ScopedShapedBuffer result = 234 ExecuteLocallyOrDie(computation, {&x_array, &y_array}); 235 236 EXPECT_TRUE(result.on_host_shape().IsTuple()); 237 EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); 238 239 Literal result_literal = ShapedBufferToLiteral(result); 240 LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}, 241 LiteralSlice(result_literal, {1})); 242 LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}, 243 LiteralSlice(result_literal, {0, 0})); 244 LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}, 245 LiteralSlice(result_literal, {0, 1})); 246 LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}, 247 LiteralSlice(result_literal, {0, 2})); 248 } 249 250 XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { 251 // Verify setting the result layout of a computation with a tuple output. 
252 XlaBuilder builder(TestName()); 253 auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); 254 auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); 255 Tuple(&builder, {x, y}); 256 257 auto array = LiteralToShapedBuffer( 258 LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}})); 259 260 ExecutableBuildOptions options = DefaultExecutableBuildOptions(); 261 Shape shape_with_layout = ShapeUtil::MakeTupleShape( 262 {ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, 263 /*minor_to_major=*/{0, 1}), 264 ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, 265 /*minor_to_major=*/{1, 0})}); 266 options.set_result_layout(shape_with_layout); 267 ScopedShapedBuffer result = 268 ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&array, &array}, 269 options, DefaultExecutableRunOptions()); 270 271 Literal result_literal = ShapedBufferToLiteral(result); 272 LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}, 273 LiteralSlice(result_literal, {0})); 274 LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}, 275 LiteralSlice(result_literal, {1})); 276 } 277 278 XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { 279 const Shape array_shape = ShapeUtil::MakeShape(F32, {2, 2}); 280 const Shape vector_shape = ShapeUtil::MakeShape(F32, {3}); 281 282 const Shape tuple_shape0 = 283 ShapeUtil::MakeTupleShape({array_shape, vector_shape}); 284 const Shape tuple_shape1 = 285 ShapeUtil::MakeTupleShape({vector_shape, array_shape}); 286 287 // Computation adds the respective array and vector elements from each tuple 288 // argument and returns the results as a tuple. 
289 XlaBuilder builder(TestName()); 290 auto x = Parameter(&builder, 0, tuple_shape0, "x"); 291 auto y = Parameter(&builder, 1, tuple_shape1, "y"); 292 auto x_0 = GetTupleElement(x, 0); 293 auto x_1 = GetTupleElement(x, 1); 294 auto y_0 = GetTupleElement(y, 0); 295 auto y_1 = GetTupleElement(y, 1); 296 auto array_sum = Add(x_0, y_1); 297 auto vector_diff = Sub(x_1, y_0); 298 Tuple(&builder, {array_sum, vector_diff}); 299 auto computation = builder.Build().ConsumeValueOrDie(); 300 301 auto x_literal = LiteralUtil::MakeTupleFromSlices( 302 {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), 303 LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0})}); 304 auto y_literal = LiteralUtil::MakeTupleFromSlices( 305 {LiteralUtil::CreateR1<float>({2.0, 4.0, 6.0}), 306 LiteralUtil::CreateR2<float>({{55.0, 44.0}, {33.0, 22.0}})}); 307 308 auto x_buffer = LiteralToShapedBuffer(x_literal); 309 auto y_buffer = LiteralToShapedBuffer(y_literal); 310 311 ScopedShapedBuffer result = 312 ExecuteLocallyOrDie(computation, {&x_buffer, &y_buffer}); 313 314 EXPECT_TRUE(result.on_host_shape().IsTuple()); 315 EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); 316 317 Literal result_literal = ShapedBufferToLiteral(result); 318 LiteralTestUtil::ExpectR2Equal<float>({{56.0f, 46.0f}, {36.0f, 26.0f}}, 319 LiteralSlice(result_literal, {0})); 320 LiteralTestUtil::ExpectR1Equal<float>({40.0f, 71.0f, 117.0f}, 321 LiteralSlice(result_literal, {1})); 322 } 323 324 XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { 325 const Shape array_shape = ShapeUtil::MakeShape(F32, {2, 2}); 326 const Shape vector_shape = ShapeUtil::MakeShape(F32, {3}); 327 328 const Shape inner_tuple_shape = 329 ShapeUtil::MakeTupleShape({array_shape, vector_shape}); 330 const Shape nested_tuple_shape = 331 ShapeUtil::MakeTupleShape({inner_tuple_shape, vector_shape}); 332 333 // Computation negates the array element and sums the two vector elements in 334 // the nested tuple. 
The resulting array and vector are returned as a tuple. 335 XlaBuilder builder(TestName()); 336 auto param = Parameter(&builder, 0, nested_tuple_shape, "param"); 337 auto inner_tuple = GetTupleElement(param, 0); 338 auto inner_array = GetTupleElement(inner_tuple, 0); 339 auto inner_vector = GetTupleElement(inner_tuple, 1); 340 auto outer_vector = GetTupleElement(param, 1); 341 342 auto negate_array = Neg(inner_array); 343 auto vector_sum = Add(inner_vector, outer_vector); 344 Tuple(&builder, {negate_array, vector_sum}); 345 auto computation = builder.Build().ConsumeValueOrDie(); 346 347 auto arg_literal = LiteralUtil::MakeTupleFromSlices( 348 {LiteralUtil::MakeTupleFromSlices( 349 {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), 350 LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0})}), 351 LiteralUtil::CreateR1<float>({222.0, -2.0, 10.0})}); 352 auto arg_buffer = LiteralToShapedBuffer(arg_literal); 353 354 ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); 355 356 Literal result_literal = ShapedBufferToLiteral(result); 357 LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4}}, 358 LiteralSlice(result_literal, {0})); 359 LiteralTestUtil::ExpectR1Equal<float>({264.0, 73.0, 133.0}, 360 LiteralSlice(result_literal, {1})); 361 } 362 363 XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { 364 // Construct a computation which takes and returns the same shape (a 365 // tuple). Feed the result of the computation back into the input. This 366 // provides additional verification that the returned tuple is properly 367 // constructed. 
368 const Shape array_shape = ShapeUtil::MakeShape(F32, {2, 2}); 369 const Shape tuple_shape = 370 ShapeUtil::MakeTupleShape({array_shape, array_shape}); 371 372 XlaBuilder builder(TestName()); 373 auto param = Parameter(&builder, 0, tuple_shape, "param"); 374 auto element_0 = GetTupleElement(param, 0); 375 auto element_1 = GetTupleElement(param, 1); 376 Tuple(&builder, {Neg(element_0), Add(element_1, element_1)}); 377 auto computation = builder.Build().ConsumeValueOrDie(); 378 379 auto arg_literal = LiteralUtil::MakeTupleFromSlices( 380 {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), 381 LiteralUtil::CreateR2<float>({{11.0, 3.0}, {4.0, 5.0}})}); 382 auto arg_buffer = LiteralToShapedBuffer(arg_literal); 383 384 ScopedShapedBuffer result_0 = ExecuteLocallyOrDie(computation, {&arg_buffer}); 385 Literal result_0_literal = ShapedBufferToLiteral(result_0); 386 LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4.0}}, 387 LiteralSlice(result_0_literal, {0})); 388 LiteralTestUtil::ExpectR2Equal<float>({{22.0, 6.0}, {8.0, 10}}, 389 LiteralSlice(result_0_literal, {1})); 390 391 ScopedShapedBuffer result_1 = ExecuteLocallyOrDie(computation, {&result_0}); 392 Literal result_1_literal = ShapedBufferToLiteral(result_1); 393 LiteralTestUtil::ExpectR2Equal<float>({{1.0, 2.0}, {3.0, 4.0}}, 394 LiteralSlice(result_1_literal, {0})); 395 LiteralTestUtil::ExpectR2Equal<float>({{44.0, 12.0}, {16.0, 20}}, 396 LiteralSlice(result_1_literal, {1})); 397 } 398 399 XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { 400 // Construct a computation which takes a tuple parameter with a very large 401 // number of elements. 402 403 // A larger number of elements would make for a better, more strenuous test, 404 // but: 405 // TODO(b/66959878): On cpu a large number of elements results in long 406 // compilation time. 407 // TODO(b/66954197): On gpu a large number of elements OOMs. 408 const int kElementCount = 100; 409 410 // Each element is a 2-element vector. 
411 const Shape element_shape = ShapeUtil::MakeShape(F32, {2}); 412 std::vector<Shape> element_shapes(kElementCount, element_shape); 413 const Shape tuple_shape = ShapeUtil::MakeTupleShape(element_shapes); 414 415 XlaBuilder builder(TestName()); 416 auto param = Parameter(&builder, 0, tuple_shape, "param"); 417 418 // Add each element's tuple index value to every element. 419 std::vector<XlaOp> result_elements; 420 for (int i = 0; i < kElementCount; ++i) { 421 auto element = GetTupleElement(param, i); 422 result_elements.push_back(Add(element, ConstantR0<float>(&builder, i))); 423 } 424 Tuple(&builder, result_elements); 425 auto computation = builder.Build().ConsumeValueOrDie(); 426 427 // Feed in a tuple where each two-element vector element is {tuple_index, 428 // -tuple_index}. 429 std::vector<Literal> arg_elements; 430 for (int i = 0; i < kElementCount; ++i) { 431 arg_elements.push_back(LiteralUtil::CreateR1<float>({1.0f * i, -1.0f * i})); 432 } 433 Literal arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_elements)); 434 auto arg_buffer = LiteralToShapedBuffer(arg_literal); 435 436 ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); 437 Literal result_literal = ShapedBufferToLiteral(result); 438 439 for (int i = 0; i < kElementCount; ++i) { 440 LiteralTestUtil::ExpectR1Near<float>( 441 {2.0f * i, 0.0f}, LiteralSlice(result_literal, {i}), error_spec_); 442 } 443 } 444 445 XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) { 446 // Construct and run a computation which takes a two-level nested tuple 447 // parameter with a large fanout. 448 const int kFanout = 40; 449 450 // Tuple shape is full two-level tree with the given fanout. 
451 const Shape element_shape = ShapeUtil::MakeShape(F32, {}); 452 std::vector<Shape> element_shapes(kFanout, element_shape); 453 const Shape inner_tuple_shape = ShapeUtil::MakeTupleShape(element_shapes); 454 std::vector<Shape> inner_tuple_shapes(kFanout, inner_tuple_shape); 455 const Shape tuple_shape = ShapeUtil::MakeTupleShape(inner_tuple_shapes); 456 457 XlaBuilder builder(TestName()); 458 auto param = Parameter(&builder, 0, tuple_shape, "param"); 459 460 // The computation increments each leaf value by an amount equal to the leaf's 461 // ordinal position in a traversal of the tuple. 462 std::vector<XlaOp> result_elements; 463 for (int i = 0; i < kFanout; ++i) { 464 auto outer_element = GetTupleElement(param, i); 465 std::vector<XlaOp> inner_result_elements; 466 for (int j = 0; j < kFanout; ++j) { 467 auto inner_element = GetTupleElement(outer_element, j); 468 inner_result_elements.push_back( 469 Add(inner_element, ConstantR0<float>(&builder, i * kFanout + j))); 470 } 471 result_elements.push_back(Tuple(&builder, inner_result_elements)); 472 } 473 Tuple(&builder, result_elements); 474 auto computation = builder.Build().ConsumeValueOrDie(); 475 476 // Construct the argument to pass to the computation. 
477 std::vector<Literal> outer_tuple_elements; 478 for (int i = 0; i < kFanout; ++i) { 479 std::vector<Literal> inner_tuple_elements; 480 for (int j = 0; j < kFanout; ++j) { 481 inner_tuple_elements.push_back(LiteralUtil::CreateR0<float>(i + j)); 482 } 483 outer_tuple_elements.push_back( 484 LiteralUtil::MakeTupleOwned(std::move(inner_tuple_elements))); 485 } 486 auto arg_literal = 487 LiteralUtil::MakeTupleOwned(std::move(outer_tuple_elements)); 488 auto arg_buffer = LiteralToShapedBuffer(arg_literal); 489 490 ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); 491 Literal result_literal = ShapedBufferToLiteral(result); 492 493 for (int i = 0; i < kFanout; ++i) { 494 for (int j = 0; j < kFanout; ++j) { 495 LiteralTestUtil::ExpectR0Near<float>(i + j + i * kFanout + j, 496 LiteralSlice(result_literal, {i, j}), 497 error_spec_); 498 } 499 } 500 } 501 502 XLA_TEST_F(LocalClientExecuteTest, DeepTuple) { 503 // Construct and run a computation which takes a very deep tuple. The tuple 504 // has no fan out and a single scalar element at the bottom. 505 const int kTupleDepth = 100; 506 507 // Tuple shape is full two-level tree with the given fanout. 508 Shape shape = ShapeUtil::MakeShape(F32, {}); 509 for (int i = 0; i < kTupleDepth; ++i) { 510 shape = ShapeUtil::MakeTupleShape({shape}); 511 } 512 513 XlaBuilder builder(TestName()); 514 auto element = Parameter(&builder, 0, shape, "param"); 515 for (int i = 0; i < kTupleDepth; ++i) { 516 element = GetTupleElement(element, 0); 517 } 518 519 auto output = Add(element, ConstantR0<float>(&builder, 42.0)); 520 for (int i = 0; i < kTupleDepth; ++i) { 521 output = Tuple(&builder, {output}); 522 } 523 auto computation = builder.Build().ConsumeValueOrDie(); 524 525 // Construct the argument to pass to the computation. 
526 Literal arg_literal = LiteralUtil::CreateR0<float>(123.0); 527 for (int i = 0; i < kTupleDepth; ++i) { 528 std::vector<Literal> arg_vector; 529 arg_vector.push_back(std::move(arg_literal)); 530 arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_vector)); 531 } 532 auto arg_buffer = LiteralToShapedBuffer(arg_literal); 533 534 ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); 535 Literal result_literal = ShapedBufferToLiteral(result); 536 537 ShapeIndex index; 538 for (int i = 0; i < kTupleDepth; ++i) { 539 index.push_back(0); 540 } 541 LiteralTestUtil::ExpectR0Equal<float>(165.0, 542 LiteralSlice(result_literal, index)); 543 } 544 545 XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { 546 // Test passing in an invalid number of arguments. 547 XlaBuilder builder(TestName()); 548 auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x"); 549 auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {3}), "y"); 550 Add(x, y); 551 552 auto x_array = 553 LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f})); 554 auto execute_status = 555 ExecuteLocally(builder.Build().ValueOrDie(), {&x_array}); 556 557 EXPECT_FALSE(execute_status.ok()); 558 EXPECT_THAT(execute_status.status().error_message(), 559 ContainsRegex("Invalid number of arguments")); 560 } 561 562 XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) { 563 // Test passing in an argument with the wrong shape. 
564 XlaBuilder builder(TestName()); 565 auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x"); 566 Neg(x); 567 568 auto x_array = LiteralToShapedBuffer( 569 LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}})); 570 auto execute_status = 571 ExecuteLocally(builder.Build().ValueOrDie(), {&x_array}); 572 573 EXPECT_FALSE(execute_status.ok()); 574 EXPECT_THAT(execute_status.status().error_message(), 575 ContainsRegex("Invalid argument shape")) 576 << execute_status.status(); 577 } 578 579 XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) { 580 // Test passing in an invalid result layout parameter. 581 XlaBuilder builder(TestName()); 582 auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); 583 Neg(x); 584 585 auto x_array = LiteralToShapedBuffer( 586 LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}})); 587 auto execute_status = ExecuteLocally( 588 builder.Build().ValueOrDie(), {&x_array}, 589 DefaultExecutableBuildOptions().set_result_layout( 590 ShapeUtil::MakeShapeWithLayout(F32, 591 /*dimensions=*/{1, 2, 3, 4}, 592 /*minor_to_major=*/{0, 1, 2, 3})), 593 DefaultExecutableRunOptions()); 594 595 EXPECT_FALSE(execute_status.ok()); 596 EXPECT_THAT(execute_status.status().error_message(), 597 ContainsRegex("not compatible with result shape")) 598 << execute_status.status(); 599 } 600 601 XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) { 602 // Try to run a trivial computation on every device on the system. If a 603 // specific device is not supported, check that the right error is returned. 
604 XlaBuilder builder(TestName()); 605 ConstantR0<float>(&builder, 42.0f); 606 auto computation = builder.Build().ConsumeValueOrDie(); 607 for (int d = 0; d < local_client_->device_count(); ++d) { 608 if (!local_client_->device_ordinal_supported(d)) { 609 auto execute_status = 610 ExecuteLocally(computation, {}, 611 DefaultExecutableBuildOptions().set_device_ordinal(d), 612 DefaultExecutableRunOptions().set_device_ordinal(d)); 613 EXPECT_FALSE(execute_status.ok()); 614 EXPECT_THAT(execute_status.status().error_message(), 615 ContainsRegex("device .* not supported")); 616 } else { 617 auto result = ExecuteLocallyOrDie( 618 computation, {}, 619 DefaultExecutableBuildOptions().set_device_ordinal(d), 620 DefaultExecutableRunOptions().set_device_ordinal(d)); 621 EXPECT_EQ(d, result.device_ordinal()); 622 LiteralTestUtil::ExpectR0Equal<float>(42.0f, 623 ShapedBufferToLiteral(result)); 624 } 625 } 626 } 627 628 XLA_TEST_F(LocalClientExecuteTest, InvalidDeviceOrdinalValues) { 629 // Try running computations on devices with device ordinal values which do not 630 // exist. 631 XlaBuilder builder(TestName()); 632 ConstantR0<float>(&builder, 42.0f); 633 auto computation = builder.Build().ConsumeValueOrDie(); 634 635 auto execute_status = 636 ExecuteLocally(computation, {}, 637 DefaultExecutableBuildOptions().set_device_ordinal( 638 local_client_->device_count()), 639 DefaultExecutableRunOptions().set_device_ordinal( 640 local_client_->device_count())); 641 EXPECT_FALSE(execute_status.ok()); 642 EXPECT_THAT(execute_status.status().error_message(), 643 ContainsRegex("Invalid device ordinal value")); 644 } 645 646 XLA_TEST_F(LocalClientExecuteTest, RunOnStream) { 647 // Run a computation on a specific stream on each device on the system. 
648 XlaBuilder builder(TestName()); 649 ConstantR0<float>(&builder, 42.0f); 650 auto computation = builder.Build().ConsumeValueOrDie(); 651 652 for (int d = 0; d < local_client_->device_count(); ++d) { 653 if (!local_client_->device_ordinal_supported(d)) { 654 continue; 655 } 656 se::StreamExecutor* executor = 657 local_client_->platform()->ExecutorForDevice(d).ValueOrDie(); 658 se::Stream stream(executor); 659 stream.Init(); 660 661 auto result = 662 ExecuteLocallyOrDie(computation, {}, DefaultExecutableBuildOptions(), 663 DefaultExecutableRunOptions().set_stream(&stream)); 664 // As a check to verify that the computation ran of the device associated 665 // with the stream. This is a weak check, but stronger verification is hard. 666 EXPECT_EQ(d, result.device_ordinal()); 667 LiteralTestUtil::ExpectR0Equal<float>(42.0f, ShapedBufferToLiteral(result)); 668 } 669 } 670 671 // Disable this test on CPU because we're using the CPU as the platform 672 // which does not match the service platform. 673 XLA_TEST_F(LocalClientExecuteTest, 674 DISABLED_ON_CPU(RunOnStreamForWrongPlatform)) { 675 // Try to run a computation on a stream for a platform (CPU) which does not 676 // match the platform of the service (!= CPU). 
677 se::Platform* wrong_platform = 678 se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId) 679 .ValueOrDie(); 680 se::Stream wrong_stream(wrong_platform->ExecutorForDevice(0).ValueOrDie()); 681 wrong_stream.Init(); 682 683 XlaBuilder builder(TestName()); 684 ConstantR0<float>(&builder, 42.0f); 685 auto execute_status = ExecuteLocally( 686 builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(), 687 DefaultExecutableRunOptions().set_stream(&wrong_stream)); 688 EXPECT_FALSE(execute_status.ok()); 689 EXPECT_THAT(execute_status.status().error_message(), 690 ContainsRegex("stream is for platform .*, but service targets")); 691 } 692 693 XLA_TEST_F(LocalClientExecuteTest, 694 DISABLED_ON_CPU(AllocatorDoesNotMatchPlatform)) { 695 se::Platform* wrong_platform = 696 se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId) 697 .ValueOrDie(); 698 TestAllocator allocator(wrong_platform); 699 700 XlaBuilder builder(TestName()); 701 ConstantR0<float>(&builder, 123.0f); 702 703 auto execute_status = ExecuteLocally( 704 builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(), 705 DefaultExecutableRunOptions().set_allocator(&allocator)); 706 EXPECT_FALSE(execute_status.ok()); 707 EXPECT_THAT(execute_status.status().error_message(), 708 ContainsRegex("allocator platform .* does not match service")); 709 } 710 711 XLA_TEST_F(LocalClientExecuteTest, RunOnUninitializedStream) { 712 // Try to run a computation on a stream that has not been initialized. 713 XlaBuilder builder(TestName()); 714 ConstantR0<float>(&builder, 42.0f); 715 716 LOG(INFO) << "default device = " << local_client_->default_device_ordinal(); 717 se::StreamExecutor* executor = 718 local_client_->platform() 719 ->ExecutorForDevice(local_client_->default_device_ordinal()) 720 .ValueOrDie(); 721 se::Stream stream(executor); 722 // Don't call stream.Init(). 
723 724 auto execute_status = ExecuteLocally( 725 builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(), 726 DefaultExecutableRunOptions().set_stream(&stream)); 727 EXPECT_FALSE(execute_status.ok()); 728 EXPECT_THAT(execute_status.status().error_message(), 729 ContainsRegex("stream is uninitialized or in an error state")); 730 } 731 732 XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) { 733 XlaBuilder builder(TestName()); 734 735 std::initializer_list<float> vec1 = {1.f, 2.f, 3.f}; 736 std::initializer_list<float> vec2 = {2.f, 4.f, 6.f}; 737 auto tuple12 = Tuple(&builder, {ConstantR1<float>(&builder, vec1), 738 ConstantR1<float>(&builder, vec2)}); 739 auto tuple21 = Tuple(&builder, {ConstantR1<float>(&builder, vec2), 740 ConstantR1<float>(&builder, vec1)}); 741 Select(ConstantR0<bool>(&builder, false), tuple12, tuple21); 742 743 ScopedShapedBuffer result = 744 ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); 745 Literal tuple_literal = ShapedBufferToLiteral(result); 746 LiteralTestUtil::ExpectR1Equal<float>({2.0f, 4.0f, 6.0f}, 747 LiteralSlice(tuple_literal, {0})); 748 LiteralTestUtil::ExpectR1Equal<float>({1.0f, 2.0f, 3.0f}, 749 LiteralSlice(tuple_literal, {1})); 750 } 751 752 XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { 753 XlaBuilder builder(TestName()); 754 auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x"); 755 auto y = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f}); 756 Add(x, y); 757 758 Shape argument_layout = 759 ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{3}, {0}); 760 auto executable_status = 761 local_client_->Compile(builder.Build().ValueOrDie(), {&argument_layout}, 762 ExecutableBuildOptions()); 763 ASSERT_IS_OK(executable_status); 764 std::unique_ptr<LocalExecutable> executable = 765 executable_status.ConsumeValueOrDie(); 766 767 auto x_array = 768 LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f})); 769 ScopedShapedBuffer result = 770 executable->Run({&x_array}, 
DefaultExecutableRunOptions()) 771 .ConsumeValueOrDie(); 772 ASSERT_IS_OK(local_client_->mutable_backend() 773 ->BorrowStream(0) 774 .ValueOrDie() 775 ->BlockHostUntilDone()); 776 777 LiteralTestUtil::ExpectR1Near<float>( 778 {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_); 779 } 780 781 XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) { 782 // Test copying Literals to the device as ShapedBuffers, then copying them 783 // back again to Literals. 784 auto test_to_device_and_back = [this](const Literal& literal) { 785 TF_ASSERT_OK_AND_ASSIGN( 786 auto shaped_buffer, 787 local_client_->LiteralToShapedBuffer( 788 literal, local_client_->default_device_ordinal(), allocator_)); 789 TF_ASSERT_OK_AND_ASSIGN( 790 auto transferred_literal, 791 local_client_->ShapedBufferToLiteral(shaped_buffer)); 792 EXPECT_EQ(literal, transferred_literal); 793 }; 794 795 // Array shapes. 796 test_to_device_and_back(LiteralUtil::CreateR0<float>(42.0)); 797 test_to_device_and_back(LiteralUtil::CreateR0<bool>(true)); 798 test_to_device_and_back(LiteralUtil::CreateR1<float>({1.0, 42.0, 744.4})); 799 test_to_device_and_back( 800 LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); 801 test_to_device_and_back(LiteralUtil::CreateR2<int32>({{2, 1}, {4444, 56}})); 802 803 // Null shape (empty tuple). 804 test_to_device_and_back(LiteralUtil::MakeTuple({})); 805 806 // Non-nested tuples. 807 test_to_device_and_back(LiteralUtil::MakeTupleFromSlices( 808 {LiteralUtil::CreateR0<float>(12223.0)})); 809 test_to_device_and_back(LiteralUtil::MakeTupleFromSlices( 810 {LiteralUtil::CreateR1<float>({1.0, -42.0}), 811 LiteralUtil::CreateR0<float>(123456.0)})); 812 813 // Nested tuple. 
814 test_to_device_and_back(LiteralUtil::MakeTupleFromSlices( 815 {LiteralUtil::MakeTupleFromSlices( 816 {LiteralUtil::CreateR1<float>({1.0, -42.0}), 817 LiteralUtil::CreateR0<float>(123456.0)}), 818 LiteralUtil::CreateR0<bool>(false)})); 819 } 820 821 XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { 822 // Test copying Literals to the device as ShapedBuffers, then copying them 823 // back again to Literals for 64-bit values. 824 auto test_to_device_and_back = [this](const Literal& literal) { 825 TF_ASSERT_OK_AND_ASSIGN( 826 auto shaped_buffer, 827 local_client_->LiteralToShapedBuffer( 828 literal, local_client_->default_device_ordinal(), allocator_)); 829 TF_ASSERT_OK_AND_ASSIGN( 830 auto transferred_literal, 831 local_client_->ShapedBufferToLiteral(shaped_buffer)); 832 EXPECT_EQ(literal, transferred_literal); 833 }; 834 835 test_to_device_and_back( 836 LiteralUtil::CreateR2<double>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); 837 test_to_device_and_back(LiteralUtil::CreateR2<int64>({{2, 1}, {4444, 56}})); 838 test_to_device_and_back( 839 LiteralUtil::CreateR2<uint64>({{20000000000ULL, 1}, {4444, 56}})); 840 test_to_device_and_back(LiteralUtil::MakeTupleFromSlices( 841 {LiteralUtil::CreateR1<double>({1.0, -42.0}), 842 LiteralUtil::CreateR0<int64>(123456789000LL)})); 843 } 844 845 // Disabled on interpreter backend since infeed HLO is unsupported. 
846 XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedTest)) { 847 XlaBuilder builder(TestName()); 848 const Shape shape = ShapeUtil::MakeShape(F32, {3}); 849 auto in = Infeed(&builder, shape); 850 auto constant = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f}); 851 Add(in, constant); 852 853 Literal result; 854 std::unique_ptr<tensorflow::Thread> thread( 855 tensorflow::Env::Default()->StartThread( 856 tensorflow::ThreadOptions(), "execute_thread", [&] { 857 result = ShapedBufferToLiteral(ExecuteLocallyOrDie( 858 builder.Build().ValueOrDie(), /*arguments=*/{})); 859 })); 860 861 ASSERT_IS_OK(local_client_->TransferToInfeedLocal( 862 LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}), 863 local_client_->default_device_ordinal())); 864 865 // Join the thread. 866 thread.reset(); 867 868 LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result); 869 } 870 871 // Disabled on interpreter backend since infeed/outfeed HLOs are unsupported. 872 XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedOutfeedTest)) { 873 XlaBuilder builder(TestName()); 874 const Shape shape = ShapeUtil::MakeShape(F32, {3}); 875 auto in = Infeed(&builder, shape); 876 auto constant = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f}); 877 auto sum = Add(in, constant); 878 Outfeed(sum, shape, /*outfeed_config=*/""); 879 880 std::unique_ptr<tensorflow::Thread> thread( 881 tensorflow::Env::Default()->StartThread( 882 tensorflow::ThreadOptions(), "execute_thread", 883 [&] { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); })); 884 885 ASSERT_IS_OK(local_client_->TransferToInfeedLocal( 886 LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}), 887 local_client_->default_device_ordinal())); 888 889 TF_ASSERT_OK_AND_ASSIGN(Literal result, 890 local_client_->TransferFromOutfeedLocal( 891 shape, local_client_->default_device_ordinal())); 892 893 LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result); 894 } 895 896 // Benchmark that measures the overhead of 
the LocalClient API when running a 897 // trivial computation 898 void BM_LocalClientOverhead(int num_iters) { 899 tensorflow::testing::StopTiming(); 900 901 se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie(); 902 auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie(); 903 StreamExecutorMemoryAllocator allocator(platform, executors); 904 LocalClient* client = 905 ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie(); 906 auto* transfer_manager = 907 TransferManager::GetForPlatform(platform).ValueOrDie(); 908 int device_ordinal = client->default_device_ordinal(); 909 910 // Use a tiny add operation as the computation. 911 XlaBuilder builder("Add"); 912 auto shape = ShapeUtil::MakeShape(F32, {2, 3}); 913 auto x = Parameter(&builder, 0, shape, "x"); 914 Add(x, x); 915 auto computation = builder.Build().ConsumeValueOrDie(); 916 917 auto buffer = 918 transfer_manager 919 ->AllocateScopedShapedBuffer(shape, &allocator, /*device_ordinal=*/0) 920 .ConsumeValueOrDie(); 921 auto literal = LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, 0, 0}}); 922 auto stream = 923 client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie(); 924 ASSERT_IS_OK( 925 transfer_manager->TransferLiteralToDevice(stream.get(), literal, buffer)); 926 927 const int kWarmups = 2; 928 929 auto executable_status = client->Compile( 930 computation, {&buffer.on_host_shape()}, ExecutableBuildOptions()); 931 ASSERT_IS_OK(executable_status); 932 std::unique_ptr<LocalExecutable> executable = 933 executable_status.ConsumeValueOrDie(); 934 935 ExecutableRunOptions run_options; 936 run_options.set_allocator(&allocator).set_stream(stream.get()); 937 938 for (int i = 0; i < kWarmups; ++i) { 939 auto result = executable->Run({&buffer}, run_options); 940 ASSERT_IS_OK(result); 941 } 942 943 tensorflow::testing::StartTiming(); 944 for (int i = 0; i < num_iters; ++i) { 945 auto result = executable->Run({&buffer}, run_options); 946 ASSERT_IS_OK(result); 947 } 948 } 949 
// Hands BM_LocalClientOverhead to the BENCHMARK macro.
// NOTE(review): presumably this registers the function with the TensorFlow
// benchmark runner from tensorflow/core/platform/test_benchmark.h -- confirm.
BENCHMARK(BM_LocalClientOverhead);

}  // namespace
}  // namespace xla