1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/tf2xla/xla_compiler.h" 17 #include "tensorflow/cc/framework/ops.h" 18 #include "tensorflow/cc/ops/data_flow_ops.h" 19 #include "tensorflow/cc/ops/function_ops.h" 20 #include "tensorflow/cc/ops/resource_variable_ops.h" 21 #include "tensorflow/cc/ops/standard_ops.h" 22 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 23 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 24 #include "tensorflow/compiler/xla/client/client_library.h" 25 #include "tensorflow/compiler/xla/client/local_client.h" 26 #include "tensorflow/compiler/xla/literal_util.h" 27 #include "tensorflow/compiler/xla/shape_util.h" 28 #include "tensorflow/compiler/xla/tests/literal_test_util.h" 29 #include "tensorflow/core/common_runtime/function.h" 30 #include "tensorflow/core/framework/common_shape_fns.h" 31 #include "tensorflow/core/framework/function.h" 32 #include "tensorflow/core/framework/function_testlib.h" 33 #include "tensorflow/core/framework/resource_mgr.h" 34 #include "tensorflow/core/framework/tensor_testutil.h" 35 #include "tensorflow/core/graph/graph.h" 36 #include "tensorflow/core/graph/graph_constructor.h" 37 #include "tensorflow/core/lib/core/status_test_util.h" 38 #include "tensorflow/core/platform/test.h" 39 #include "tensorflow/core/public/version.h" 40 41 namespace tensorflow { 42 43 class XlaCompilerTest : public ::testing::Test { 44 protected: 45 XlaCompilerTest() : cpu_device_type_(DEVICE_CPU_XLA_JIT) {} 46 47 void SetUp() override { 48 client_ = xla::ClientLibrary::LocalClientOrDie(); 49 50 XlaOpRegistry::RegisterCompilationKernels(); 51 52 FunctionDefLibrary flib; 53 flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib)); 54 } 55 56 XlaCompiler::Options DefaultOptions() { 57 XlaCompiler::Options options; 58 options.device_type = &cpu_device_type_; 59 options.client = client_; 60 options.flib_def = flib_def_.get(); 61 return options; 62 } 63 64 FunctionLibraryDefinition* LocalFlibDef(XlaCompiler* compiler) { 65 return compiler->local_flib_def_.get(); 66 } 67 68 DeviceType cpu_device_type_; 69 xla::Client* client_; 70 std::unique_ptr<FunctionLibraryDefinition> flib_def_; 71 }; 72 73 namespace { 74 75 // Helper class to test the ability to pass resources through to XLA 76 // compiled kernels. 77 class DummyResourceForTest : public ResourceBase { 78 public: 79 string DebugString() override { return "dummy"; } 80 void Increment() { ++value_; } 81 int Get() { return value_; } 82 83 private: 84 int value_ = 0; 85 }; 86 87 class DummyReadResourceOp : public XlaOpKernel { 88 public: 89 explicit DummyReadResourceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 90 void Compile(XlaOpKernelContext* ctx) override { 91 ResourceMgr* rm = ctx->op_kernel_context()->resource_manager(); 92 OP_REQUIRES(ctx, rm, errors::Internal("No resource manager.")); 93 DummyResourceForTest* dummy; 94 OP_REQUIRES_OK(ctx, rm->Lookup<DummyResourceForTest>( 95 rm->default_container(), "dummy", &dummy)); 96 dummy->Increment(); 97 dummy->Unref(); 98 99 ctx->SetOutput(0, ctx->Input(0)); 100 ctx->SetOutput(1, ctx->Input(0)); 101 } 102 }; 103 104 class DummyReadResourceCC { 105 public: 106 DummyReadResourceCC(const Scope& scope, const Input& value) { 107 if (!scope.ok()) return; 108 auto _value = ops::AsNodeOut(scope, value); 109 if (!scope.ok()) return; 110 Node* ret; 111 const auto unique_name = scope.GetUniqueNameForOp("DummyReadResource"); 112 auto builder = NodeBuilder(unique_name, "DummyReadResource").Input(_value); 113 scope.UpdateBuilder(&builder); 114 scope.UpdateStatus(builder.Finalize(scope.graph(), &ret)); 115 if (!scope.ok()) return; 116 scope.UpdateStatus(scope.DoShapeInference(ret)); 117 if (!scope.ok()) return; 118 this->output1_ = Output(ret, 0); 119 this->output2_ = Output(ret, 1); 120 } 121 122 Output output1_; 123 Output output2_; 124 }; 125 126 REGISTER_OP("DummyReadResource") 127 .Input("input: int32") 128 .Output("output1: int32") 129 .Output("output2: int32") 130 .SetShapeFn(shape_inference::UnknownShape) 131 .Doc(R"doc( 132 A dummy Op. 133 134 input: dummy input. 135 output1: dummy output. 136 output2: dummy output. 137 )doc"); 138 139 REGISTER_XLA_OP(Name("DummyReadResource"), DummyReadResourceOp); 140 141 // DummyDuplicateOp is present purely to test multiple REGISTER_XLA_OP calls 142 // on the same Op name below. 143 class DummyDuplicateOp : public XlaOpKernel { 144 public: 145 explicit DummyDuplicateOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 146 void Compile(XlaOpKernelContext* ctx) override { 147 ctx->SetOutput(0, ctx->Input(0)); 148 } 149 }; 150 151 REGISTER_OP("DummyDuplicateOp") 152 .Input("input: int32") 153 .Output("output: int32") 154 .Doc(R"doc( 155 A dummy Op. 156 157 input: dummy input. 158 output: dummy output. 159 )doc"); 160 161 REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_CPU_XLA_JIT), 162 DummyDuplicateOp); 163 REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_GPU_XLA_JIT), 164 DummyDuplicateOp); 165 166 167 // Tests compilation and execution of an empty graph. 168 TEST_F(XlaCompilerTest, EmptyReturnValues) { 169 XlaCompiler compiler(DefaultOptions()); 170 171 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); 172 XlaCompiler::CompilationResult result; 173 TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", 174 std::move(graph), 175 /*args=*/{}, &result)); 176 177 TF_ASSERT_OK(client_->Execute(*result.computation, {}).status()); 178 } 179 180 // Tests compilation and execution of a graph that adds two tensors. 181 TEST_F(XlaCompilerTest, Simple) { 182 // Builds a graph that adds two Tensors. 183 Scope scope = Scope::NewRootScope().ExitOnError(); 184 auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); 185 auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1); 186 auto c = ops::Add(scope.WithOpName("C"), a, b); 187 auto d = ops::_Retval(scope.WithOpName("D"), c, 0); 188 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); 189 TF_ASSERT_OK(scope.ToGraph(graph.get())); 190 191 // Builds a description of the arguments. 192 std::vector<XlaCompiler::Argument> args(2); 193 args[0].kind = XlaCompiler::Argument::kParameter; 194 args[0].type = DT_INT32; 195 args[0].shape = TensorShape({2}); 196 args[1].kind = XlaCompiler::Argument::kParameter; 197 args[1].type = DT_INT32; 198 args[1].shape = TensorShape({2}); 199 200 // Compiles the graph. 201 XlaCompiler compiler(DefaultOptions()); 202 203 XlaCompiler::CompilationResult result; 204 TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", 205 std::move(graph), args, &result)); 206 207 // Tests that the generated computation works. 208 std::unique_ptr<xla::Literal> param0_literal = 209 xla::Literal::CreateR1<int32>({7, 42}); 210 std::unique_ptr<xla::Literal> param1_literal = 211 xla::Literal::CreateR1<int32>({-3, 101}); 212 std::unique_ptr<xla::GlobalData> param0_data = 213 client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); 214 std::unique_ptr<xla::GlobalData> param1_data = 215 client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); 216 217 std::unique_ptr<xla::GlobalData> actual = 218 client_ 219 ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) 220 .ConsumeValueOrDie(); 221 std::unique_ptr<xla::Literal> actual_literal = 222 client_->Transfer(*actual).ConsumeValueOrDie(); 223 224 std::unique_ptr<xla::Literal> expected0 = 225 xla::Literal::CreateR1<int32>({4, 143}); 226 std::unique_ptr<xla::Literal> expected_literal = 227 xla::Literal::MakeTuple({expected0.get()}); 228 xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); 229 } 230 231 TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { 232 // Builds a graph that adds reshapes a tensor, but with the shape not 233 // statically known. 234 Scope scope = Scope::NewRootScope().ExitOnError(); 235 auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); 236 auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1); 237 auto c = ops::Reshape(scope.WithOpName("C"), a, b); 238 auto d = ops::_Retval(scope.WithOpName("D"), c, 0); 239 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); 240 TF_ASSERT_OK(scope.ToGraph(graph.get())); 241 242 // Builds a description of the arguments. 243 std::vector<XlaCompiler::Argument> args(2); 244 args[0].kind = XlaCompiler::Argument::kParameter; 245 args[0].type = DT_INT32; 246 args[0].shape = TensorShape({2}); 247 args[1].kind = XlaCompiler::Argument::kParameter; 248 args[1].type = DT_INT32; 249 args[1].shape = TensorShape({2}); 250 251 // Compiles the graph. 252 XlaCompiler compiler(DefaultOptions()); 253 254 XlaCompiler::CompilationResult result; 255 Status status = 256 compiler.CompileGraph(XlaCompiler::CompileOptions(), "reshape", 257 std::move(graph), args, &result); 258 EXPECT_FALSE(status.ok()); 259 EXPECT_TRUE( 260 StringPiece(status.error_message()).contains("depends on a parameter")) 261 << status.error_message(); 262 EXPECT_TRUE( 263 StringPiece(status.error_message()).contains("[[Node: C = Reshape")) 264 << status.error_message(); 265 } 266 267 // Tests handling of compile-time constant outputs. 268 TEST_F(XlaCompilerTest, ConstantOutputs) { 269 // Builds a graph with one compile-time constant output and one data-dependent 270 // output, i.e., 271 // func(a) { b=7; c=-a; return b, c; } 272 Scope scope = Scope::NewRootScope().ExitOnError(); 273 auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); 274 auto b = ops::Const<int32>(scope.WithOpName("B"), 7); 275 auto c = ops::Neg(scope.WithOpName("C"), a); 276 auto d = ops::_Retval(scope.WithOpName("D"), b, 0); 277 auto e = ops::_Retval(scope.WithOpName("E"), c, 1); 278 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); 279 TF_ASSERT_OK(scope.ToGraph(graph.get())); 280 281 // Builds a description of the arguments. 282 std::vector<XlaCompiler::Argument> args(1); 283 args[0].kind = XlaCompiler::Argument::kParameter; 284 args[0].type = DT_INT32; 285 args[0].shape = TensorShape({2}); 286 287 XlaCompiler::Options options = DefaultOptions(); 288 XlaCompiler compiler(options); 289 { 290 // Compiles the graph, with resolve_compile_time_constants enabled. 291 292 std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global())); 293 CopyGraph(*graph, graph_copy.get()); 294 295 XlaCompiler::CompileOptions compile_options; 296 compile_options.resolve_compile_time_constants = true; 297 XlaCompiler::CompilationResult result; 298 TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants", 299 std::move(graph_copy), args, &result)); 300 301 ASSERT_EQ(2, result.outputs.size()); 302 EXPECT_TRUE(result.outputs[0].is_constant); 303 test::ExpectTensorEqual<int32>(result.outputs[0].constant_value, 304 test::AsScalar(7)); 305 EXPECT_FALSE(result.outputs[1].is_constant); 306 307 // Tests that the generated computation works. 308 std::unique_ptr<xla::Literal> param0_literal = 309 xla::Literal::CreateR1<int32>({7, 42}); 310 std::unique_ptr<xla::GlobalData> param0_data = 311 client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); 312 313 std::unique_ptr<xla::GlobalData> actual = 314 client_->Execute(*result.computation, {param0_data.get()}) 315 .ConsumeValueOrDie(); 316 std::unique_ptr<xla::Literal> actual_literal = 317 client_->Transfer(*actual).ConsumeValueOrDie(); 318 319 std::unique_ptr<xla::Literal> expected0 = 320 xla::Literal::CreateR1<int32>({-7, -42}); 321 std::unique_ptr<xla::Literal> expected_literal = 322 xla::Literal::MakeTuple({expected0.get()}); 323 xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); 324 } 325 326 { 327 // Compiles the graph, with resolve_compile_time_constants disabled. 328 std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global())); 329 CopyGraph(*graph, graph_copy.get()); 330 331 XlaCompiler::CompileOptions compile_options; 332 compile_options.resolve_compile_time_constants = false; 333 XlaCompiler::CompilationResult result; 334 TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants", 335 std::move(graph_copy), args, &result)); 336 337 ASSERT_EQ(2, result.outputs.size()); 338 EXPECT_FALSE(result.outputs[0].is_constant); 339 EXPECT_FALSE(result.outputs[1].is_constant); 340 341 // Tests that the generated computation works. 342 std::unique_ptr<xla::Literal> param0_literal = 343 xla::Literal::CreateR1<int32>({7, 42}); 344 std::unique_ptr<xla::GlobalData> param0_data = 345 client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); 346 347 std::unique_ptr<xla::GlobalData> actual = 348 client_->Execute(*result.computation, {param0_data.get()}) 349 .ConsumeValueOrDie(); 350 std::unique_ptr<xla::Literal> actual_literal = 351 client_->Transfer(*actual).ConsumeValueOrDie(); 352 353 std::unique_ptr<xla::Literal> expected0 = xla::Literal::CreateR0<int32>(7); 354 std::unique_ptr<xla::Literal> expected1 = 355 xla::Literal::CreateR1<int32>({-7, -42}); 356 std::unique_ptr<xla::Literal> expected = 357 xla::Literal::MakeTuple({expected0.get(), expected1.get()}); 358 xla::LiteralTestUtil::ExpectEqual(*expected, *actual_literal); 359 } 360 } 361 362 // Tests compilation and execution of a graph that adds two tensors. 363 TEST_F(XlaCompilerTest, ResourceManager) { 364 // Builds a graph that calls the dummy resource Op. 365 Scope scope = Scope::NewRootScope().ExitOnError(); 366 auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); 367 auto b = DummyReadResourceCC(scope.WithOpName("B"), a); 368 auto c = ops::Add(scope.WithOpName("C"), b.output2_, b.output1_); 369 auto d = ops::_Retval(scope.WithOpName("D"), c, 0); 370 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); 371 TF_ASSERT_OK(scope.ToGraph(graph.get())); 372 373 // Builds a description of the argument. 374 std::vector<XlaCompiler::Argument> args(1); 375 args[0].kind = XlaCompiler::Argument::kParameter; 376 args[0].type = DT_INT32; 377 args[0].shape = TensorShape({2}); 378 379 DummyResourceForTest* resource = new DummyResourceForTest(); 380 381 // Compiles the graph. 382 auto options = DefaultOptions(); 383 std::function<Status(ResourceMgr*)> populate_function = 384 [resource](ResourceMgr* rm) { 385 resource->Ref(); 386 return rm->Create(rm->default_container(), "dummy", resource); 387 }; 388 options.populate_resource_manager = &populate_function; 389 XlaCompiler compiler(options); 390 391 EXPECT_EQ(0, resource->Get()); 392 393 XlaCompiler::CompilationResult result; 394 TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy", 395 std::move(graph), args, &result)); 396 397 EXPECT_EQ(1, resource->Get()); 398 399 resource->Unref(); 400 } 401 402 // Tests compilation and execution of a graph that adds two tensors. 403 TEST_F(XlaCompilerTest, DeterministicCompilation) { 404 // Builds a graph that contains a node with two output edges. The compiler 405 // should always traverse them in the same order. 406 const int64 test_count = 2; 407 408 std::vector<XlaCompiler::CompilationResult> results(test_count); 409 410 for (int64 i = 0; i < test_count; ++i) { 411 Scope scope = Scope::NewRootScope().ExitOnError(); 412 auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); 413 auto b = ops::Neg(scope.WithOpName("B"), a); 414 auto c = ops::Neg(scope.WithOpName("C"), a); 415 auto d = ops::Add(scope.WithOpName("D"), b, c); 416 auto e = ops::_Retval(scope.WithOpName("E"), d, 0); 417 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); 418 TF_ASSERT_OK(scope.ToGraph(graph.get())); 419 420 // Builds a description of the argument. 421 std::vector<XlaCompiler::Argument> args(1); 422 args[0].kind = XlaCompiler::Argument::kParameter; 423 args[0].type = DT_INT32; 424 args[0].shape = TensorShape({2}); 425 426 // Compiles the graph. 427 auto options = DefaultOptions(); 428 XlaCompiler compiler(options); 429 430 TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy", 431 std::move(graph), args, &results[i])); 432 } 433 434 for (int64 i = 1; i < test_count; ++i) { 435 auto m1 = 436 results[i - 1].computation->Snapshot().ValueOrDie()->entry().requests(); 437 auto m2 = 438 results[i].computation->Snapshot().ValueOrDie()->entry().requests(); 439 // Check if every entry is the same. 440 for (auto& entry1 : m1) { 441 int64 key = entry1.first; 442 auto value1 = entry1.second; 443 auto entry2 = m2.find(key); 444 auto value2 = entry2->second; 445 EXPECT_TRUE(entry2 != m2.end()); 446 string str1, str2; 447 value1.AppendToString(&str1); 448 value2.AppendToString(&str2); 449 EXPECT_EQ(str1, str2); 450 } 451 } 452 } 453 454 // Tests a computation that receives a TensorArray resource as input and 455 // updates it. 456 TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { 457 Scope scope = Scope::NewRootScope().ExitOnError(); 458 auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0); 459 auto flow = ops::Const<float>(scope, {}); 460 auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad1"); 461 auto grad2 = ops::TensorArrayGrad(scope, arg, grad1.flow_out, "grad2"); 462 auto index = ops::Const<int32>(scope, 1); 463 auto write = ops::TensorArrayWrite(scope, grad1.grad_handle, index, index, 464 grad2.flow_out); 465 auto read = ops::TensorArrayRead(scope, arg, index, write.flow_out, DT_INT32); 466 auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0); 467 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); 468 TF_ASSERT_OK(scope.ToGraph(graph.get())); 469 470 // Builds a description of the arguments. 471 std::vector<XlaCompiler::Argument> args(1); 472 args[0].kind = XlaCompiler::Argument::kResource; 473 args[0].resource_kind = XlaResource::kTensorArray; 474 args[0].initialized = true; 475 args[0].type = DT_INT32; 476 args[0].shape = TensorShape({}); 477 args[0].tensor_array_size = 2; 478 args[0].tensor_array_gradients = {"grad2"}; 479 480 // Compiles the graph. 481 XlaCompiler compiler(DefaultOptions()); 482 483 XlaCompiler::CompilationResult result; 484 TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", 485 std::move(graph), args, &result)); 486 487 ASSERT_EQ(1, result.resource_updates.size()); 488 const XlaCompiler::ResourceUpdate& update = result.resource_updates[0]; 489 EXPECT_EQ(0, update.input_index); 490 EXPECT_EQ(DT_INT32, update.type); 491 EXPECT_EQ((std::set<string>{"grad1", "grad2"}), 492 update.tensor_array_gradients_accessed); 493 494 // Tests that the generated computation works. 495 std::unique_ptr<xla::Literal> input_base = 496 xla::Literal::CreateR1<int32>({7, 42}); 497 std::unique_ptr<xla::Literal> input_grad2 = 498 xla::Literal::CreateR1<int32>({-3, 101}); 499 std::unique_ptr<xla::Literal> input = 500 xla::Literal::MakeTuple({input_base.get(), input_grad2.get()}); 501 std::unique_ptr<xla::GlobalData> param0_data = 502 client_->TransferToServer(*input).ConsumeValueOrDie(); 503 504 std::unique_ptr<xla::GlobalData> actual = 505 client_->Execute(*result.computation, {param0_data.get()}) 506 .ConsumeValueOrDie(); 507 std::unique_ptr<xla::Literal> actual_literal = 508 client_->Transfer(*actual).ConsumeValueOrDie(); 509 510 std::unique_ptr<xla::Literal> output_read = xla::Literal::CreateR0<int32>(42); 511 std::unique_ptr<xla::Literal> output_base = 512 xla::Literal::CreateR1<int32>({7, 42}); 513 std::unique_ptr<xla::Literal> output_grad1 = 514 xla::Literal::CreateR1<int32>({0, 1}); 515 std::unique_ptr<xla::Literal> output_grad2 = 516 xla::Literal::CreateR1<int32>({-3, 101}); 517 std::unique_ptr<xla::Literal> output_resource = xla::Literal::MakeTuple( 518 {output_base.get(), output_grad1.get(), output_grad2.get()}); 519 std::unique_ptr<xla::Literal> expected_literal = 520 xla::Literal::MakeTuple({output_read.get(), output_resource.get()}); 521 xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); 522 } 523 524 // Tests compilation and execution of a graph that adds two tensors. 525 TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) { 526 Scope scope = Scope::NewRootScope().ExitOnError(); 527 auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0); 528 auto flow = ops::Const<float>(scope, {}); 529 auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad1"); 530 auto index = ops::Const<int32>(scope, 1); 531 auto read = ops::TensorArrayRead(scope, arg, index, grad1.flow_out, DT_INT32); 532 auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0); 533 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); 534 TF_ASSERT_OK(scope.ToGraph(graph.get())); 535 536 // Builds a description of the arguments. 537 std::vector<XlaCompiler::Argument> args(1); 538 args[0].kind = XlaCompiler::Argument::kResource; 539 args[0].resource_kind = XlaResource::kTensorArray; 540 args[0].initialized = true; 541 args[0].type = DT_INT32; 542 args[0].shape = TensorShape({}); 543 args[0].tensor_array_size = 2; 544 args[0].tensor_array_gradients = {"grad1"}; 545 546 // Compiles the graph. 547 XlaCompiler compiler(DefaultOptions()); 548 549 XlaCompiler::CompilationResult result; 550 TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", 551 std::move(graph), args, &result)); 552 553 EXPECT_EQ(0, result.resource_updates.size()); 554 } 555 556 // Tests compilation and execution of a graph that adds two tensors. 557 TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) { 558 Scope scope = Scope::NewRootScope().ExitOnError(); 559 auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0); 560 auto flow = ops::Const<float>(scope, {}); 561 auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad2"); 562 auto index = ops::Const<int32>(scope, 1); 563 auto read = ops::TensorArrayRead(scope, arg, index, grad1.flow_out, DT_INT32); 564 auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0); 565 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); 566 TF_ASSERT_OK(scope.ToGraph(graph.get())); 567 568 // Builds a description of the arguments. 569 std::vector<XlaCompiler::Argument> args(1); 570 args[0].kind = XlaCompiler::Argument::kResource; 571 args[0].resource_kind = XlaResource::kTensorArray; 572 args[0].initialized = true; 573 args[0].type = DT_INT32; 574 args[0].shape = TensorShape({}); 575 args[0].tensor_array_size = 2; 576 args[0].tensor_array_gradients = {"grad1"}; 577 578 // Compiles the graph. 579 XlaCompiler compiler(DefaultOptions()); 580 581 XlaCompiler::CompilationResult result; 582 TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", 583 std::move(graph), args, &result)); 584 585 EXPECT_EQ(1, result.resource_updates.size()); 586 } 587 588 // Tests CompileFunction with undefined function fails. 589 TEST_F(XlaCompilerTest, UndefinedFunctionFails) { 590 XlaCompiler compiler(DefaultOptions()); 591 592 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); 593 XlaCompiler::CompilationResult result; 594 NameAttrList name_attr; 595 name_attr.set_name("Function_NotDefined_"); 596 Status status = 597 compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr, 598 /*args=*/{}, &result); 599 EXPECT_FALSE(status.ok()); 600 EXPECT_TRUE(StringPiece(status.error_message()).contains("is not defined.")) 601 << status.error_message(); 602 } 603 604 FunctionDef FillFn() { 605 return FunctionDefHelper::Define( 606 // Name 607 "FillFn", 608 // Args 609 {"x: T", "dims: int32"}, 610 // Return values 611 {"y: T"}, 612 // Attr def 613 {"T: {float, double, int32, int64}"}, 614 // Nodes 615 {{{"y"}, "Fill", {"dims", "x"}, {{"T", "$T"}}}}); 616 } 617 618 TEST_F(XlaCompilerTest, FunctionCallWithConstants) { 619 // Certain operations in a function, "Fill" for example, requires the 620 // operator's argument to be a compile-time constant instead of a parameter. 621 // This testcase tests if XlaCompiler can handle such operators inside 622 // function calls. 623 XlaCompiler compiler(DefaultOptions()); 624 625 FunctionDefLibrary flib; 626 *flib.add_function() = FillFn(); 627 628 TF_ASSERT_OK(flib_def_->AddFunctionDef(FillFn())); 629 630 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); 631 632 Scope scope = Scope::NewRootScope().ExitOnError(); 633 auto value = ops::Const<int32>(scope.WithOpName("value"), 1, {}); 634 auto shape = ops::Const<int32>(scope.WithOpName("shape"), {5}, {1}); 635 TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib)); 636 637 NodeDef def; 638 TF_ASSERT_OK(NodeDefBuilder("fill", "FillFn", flib_def_.get()) 639 .Input(value.name(), 0, DT_INT32) 640 .Input(shape.name(), 1, DT_INT32) 641 .Finalize(&def)); 642 Status status; 643 Node* fill = scope.graph()->AddNode(def, &status); 644 TF_ASSERT_OK(status); 645 TF_ASSERT_OK(scope.DoShapeInference(fill)); 646 scope.graph()->AddEdge(value.node(), 0, fill, 0); 647 scope.graph()->AddEdge(shape.node(), 0, fill, 1); 648 649 auto retval = ops::_Retval(scope.WithOpName("retval"), Output(fill), 0); 650 651 TF_ASSERT_OK(scope.ToGraph(graph.get())); 652 653 // Builds a description of the argument. 654 std::vector<XlaCompiler::Argument> args; 655 656 XlaCompiler::CompilationResult result; 657 TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill", 658 std::move(graph), args, &result)); 659 } 660 661 // Tests CompileFunction with a local function lookup failing, fails with 662 // informative error about both lookups. 663 TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) { 664 XlaCompiler compiler(DefaultOptions()); 665 666 auto local_flib_def = LocalFlibDef(&compiler); 667 TF_ASSERT_OK(local_flib_def->AddFunctionDef(test::function::XTimesTwo())); 668 669 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); 670 XlaCompiler::CompilationResult result; 671 NameAttrList name_attr; 672 name_attr.set_name("XTimesTwo"); 673 Status status = 674 compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr, 675 /*args=*/{}, &result); 676 677 ASSERT_FALSE(status.ok()); 678 // Flib lookup failure. 679 EXPECT_TRUE(StringPiece(status.error_message()).contains("is not defined.")) 680 << status.error_message(); 681 // Local flib lookup failure. 682 EXPECT_TRUE( 683 StringPiece(status.error_message()).contains("Attr T is not found")) 684 << status.error_message(); 685 } 686 687 // Tests a simple graph that reads and writes a variable. 688 TEST_F(XlaCompilerTest, Variables) { 689 Scope scope = Scope::NewRootScope().ExitOnError(); 690 auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); 691 auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); 692 auto write = ops::AssignAddVariableOp(scope, var, a); 693 auto read = ops::ReadVariableOp( 694 scope.WithControlDependencies(std::vector<Operation>{write}), var, 695 DT_INT32); 696 auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1)); 697 auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0); 698 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); 699 TF_ASSERT_OK(scope.ToGraph(graph.get())); 700 701 // Builds a description of the arguments. 702 std::vector<XlaCompiler::Argument> args(2); 703 args[0].kind = XlaCompiler::Argument::kParameter; 704 args[0].type = DT_INT32; 705 args[0].shape = TensorShape({2}); 706 args[1].kind = XlaCompiler::Argument::kResource; 707 args[1].resource_kind = XlaResource::kVariable; 708 args[1].initialized = true; 709 args[1].type = DT_INT32; 710 args[1].shape = TensorShape({2}); 711 712 // Compiles the graph. 713 XlaCompiler compiler(DefaultOptions()); 714 715 XlaCompiler::CompilationResult result; 716 TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", 717 std::move(graph), args, &result)); 718 719 // Tests that the generated computation works. 720 std::unique_ptr<xla::Literal> param0_literal = 721 xla::Literal::CreateR1<int32>({7, 42}); 722 std::unique_ptr<xla::Literal> param1_literal = 723 xla::Literal::CreateR1<int32>({-3, 101}); 724 std::unique_ptr<xla::GlobalData> param0_data = 725 client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); 726 std::unique_ptr<xla::GlobalData> param1_data = 727 client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); 728 729 std::unique_ptr<xla::GlobalData> actual = 730 client_ 731 ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) 732 .ConsumeValueOrDie(); 733 std::unique_ptr<xla::Literal> actual_literal = 734 client_->Transfer(*actual).ConsumeValueOrDie(); 735 736 std::unique_ptr<xla::Literal> expected0 = 737 xla::Literal::CreateR1<int32>({5, 144}); 738 std::unique_ptr<xla::Literal> expected1 = 739 xla::Literal::CreateR1<int32>({4, 143}); 740 std::unique_ptr<xla::Literal> expected_literal = 741 xla::Literal::MakeTuple({expected0.get(), expected1.get()}); 742 xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); 743 } 744 745 // Tests a simple graph that reads and writes a variable, with a 746 // variable_representation_shape_fn passed to the compiler that flattens all 747 // variable tensors to vectors. 748 TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { 749 Scope scope = Scope::NewRootScope().ExitOnError(); 750 auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); 751 auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); 752 auto write = ops::AssignAddVariableOp(scope, var, a); 753 auto read = ops::ReadVariableOp( 754 scope.WithControlDependencies(std::vector<Operation>{write}), var, 755 DT_INT32); 756 auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1)); 757 auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0); 758 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); 759 TF_ASSERT_OK(scope.ToGraph(graph.get())); 760 761 // Builds a description of the arguments. 762 std::vector<XlaCompiler::Argument> args(2); 763 args[0].kind = XlaCompiler::Argument::kParameter; 764 args[0].type = DT_INT32; 765 args[0].shape = TensorShape({2, 2}); 766 args[1].kind = XlaCompiler::Argument::kResource; 767 args[1].resource_kind = XlaResource::kVariable; 768 args[1].initialized = true; 769 args[1].type = DT_INT32; 770 args[1].shape = TensorShape({2, 2}); 771 772 // Compiles the graph. 773 XlaCompiler::Options options = DefaultOptions(); 774 options.variable_representation_shape_fn = [](const TensorShape& shape, 775 DataType type) { 776 return TensorShape({shape.num_elements()}); 777 }; 778 XlaCompiler compiler(options); 779 780 XlaCompiler::CompilationResult result; 781 TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", 782 std::move(graph), args, &result)); 783 784 // Tests that the generated computation works. 785 std::unique_ptr<xla::Literal> param0_literal = 786 xla::Literal::CreateR2<int32>({{4, 55}, {1, -3}}); 787 std::unique_ptr<xla::Literal> param1_literal = 788 xla::Literal::CreateR1<int32>({22, 11, 33, 404}); 789 std::unique_ptr<xla::GlobalData> param0_data = 790 client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); 791 std::unique_ptr<xla::GlobalData> param1_data = 792 client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); 793 794 std::unique_ptr<xla::GlobalData> actual = 795 client_ 796 ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) 797 .ConsumeValueOrDie(); 798 std::unique_ptr<xla::Literal> actual_literal = 799 client_->Transfer(*actual).ConsumeValueOrDie(); 800 801 std::unique_ptr<xla::Literal> expected0 = 802 xla::Literal::CreateR2<int32>({{27, 67}, {35, 402}}); 803 std::unique_ptr<xla::Literal> expected1 = 804 xla::Literal::CreateR1<int32>({26, 66, 34, 401}); 805 std::unique_ptr<xla::Literal> expected_literal = 806 xla::Literal::MakeTuple({expected0.get(), expected1.get()}); 807 xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); 808 } 809 810 } // namespace 811 } // namespace tensorflow 812