// Tests for XlaCompiler (tensorflow/compiler/tf2xla/xla_compiler.h).
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
     17 #include "tensorflow/cc/framework/ops.h"
     18 #include "tensorflow/cc/ops/data_flow_ops.h"
     19 #include "tensorflow/cc/ops/function_ops.h"
     20 #include "tensorflow/cc/ops/resource_variable_ops.h"
     21 #include "tensorflow/cc/ops/standard_ops.h"
     22 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
     23 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
     24 #include "tensorflow/compiler/xla/client/client_library.h"
     25 #include "tensorflow/compiler/xla/client/local_client.h"
     26 #include "tensorflow/compiler/xla/literal_util.h"
     27 #include "tensorflow/compiler/xla/shape_util.h"
     28 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
     29 #include "tensorflow/core/common_runtime/function.h"
     30 #include "tensorflow/core/framework/common_shape_fns.h"
     31 #include "tensorflow/core/framework/function.h"
     32 #include "tensorflow/core/framework/function_testlib.h"
     33 #include "tensorflow/core/framework/resource_mgr.h"
     34 #include "tensorflow/core/framework/tensor_testutil.h"
     35 #include "tensorflow/core/graph/graph.h"
     36 #include "tensorflow/core/graph/graph_constructor.h"
     37 #include "tensorflow/core/lib/core/status_test_util.h"
     38 #include "tensorflow/core/platform/test.h"
     39 #include "tensorflow/core/public/version.h"
     40 
     41 namespace tensorflow {
     42 
     43 class XlaCompilerTest : public ::testing::Test {
     44  protected:
     45   XlaCompilerTest() : cpu_device_type_(DEVICE_CPU_XLA_JIT) {}
     46 
     47   void SetUp() override {
     48     client_ = xla::ClientLibrary::LocalClientOrDie();
     49 
     50     XlaOpRegistry::RegisterCompilationKernels();
     51 
     52     FunctionDefLibrary flib;
     53     flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));
     54   }
     55 
     56   XlaCompiler::Options DefaultOptions() {
     57     XlaCompiler::Options options;
     58     options.device_type = &cpu_device_type_;
     59     options.client = client_;
     60     options.flib_def = flib_def_.get();
     61     return options;
     62   }
     63 
     64   FunctionLibraryDefinition* LocalFlibDef(XlaCompiler* compiler) {
     65     return compiler->local_flib_def_.get();
     66   }
     67 
     68   DeviceType cpu_device_type_;
     69   xla::Client* client_;
     70   std::unique_ptr<FunctionLibraryDefinition> flib_def_;
     71 };
     72 
     73 namespace {
     74 
     75 // Helper class to test the ability to pass resources through to XLA
     76 // compiled kernels.
     77 class DummyResourceForTest : public ResourceBase {
     78  public:
     79   string DebugString() override { return "dummy"; }
     80   void Increment() { ++value_; }
     81   int Get() { return value_; }
     82 
     83  private:
     84   int value_ = 0;
     85 };
     86 
     87 class DummyReadResourceOp : public XlaOpKernel {
     88  public:
     89   explicit DummyReadResourceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
     90   void Compile(XlaOpKernelContext* ctx) override {
     91     ResourceMgr* rm = ctx->op_kernel_context()->resource_manager();
     92     OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
     93     DummyResourceForTest* dummy;
     94     OP_REQUIRES_OK(ctx, rm->Lookup<DummyResourceForTest>(
     95                             rm->default_container(), "dummy", &dummy));
     96     dummy->Increment();
     97     dummy->Unref();
     98 
     99     ctx->SetOutput(0, ctx->Input(0));
    100     ctx->SetOutput(1, ctx->Input(0));
    101   }
    102 };
    103 
    104 class DummyReadResourceCC {
    105  public:
    106   DummyReadResourceCC(const Scope& scope, const Input& value) {
    107     if (!scope.ok()) return;
    108     auto _value = ops::AsNodeOut(scope, value);
    109     if (!scope.ok()) return;
    110     Node* ret;
    111     const auto unique_name = scope.GetUniqueNameForOp("DummyReadResource");
    112     auto builder = NodeBuilder(unique_name, "DummyReadResource").Input(_value);
    113     scope.UpdateBuilder(&builder);
    114     scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
    115     if (!scope.ok()) return;
    116     scope.UpdateStatus(scope.DoShapeInference(ret));
    117     if (!scope.ok()) return;
    118     this->output1_ = Output(ret, 0);
    119     this->output2_ = Output(ret, 1);
    120   }
    121 
    122   Output output1_;
    123   Output output2_;
    124 };
    125 
// Declares the "DummyReadResource" op used by DummyReadResourceCC above:
// one int32 input forwarded to two int32 outputs.
REGISTER_OP("DummyReadResource")
    .Input("input: int32")
    .Output("output1: int32")
    .Output("output2: int32")
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
A dummy Op.

input: dummy input.
output1: dummy output.
output2: dummy output.
)doc");

// Registers the XLA kernel implementing the op (no device constraint).
REGISTER_XLA_OP(Name("DummyReadResource"), DummyReadResourceOp);
    141 // DummyDuplicateOp is present purely to test multiple REGISTER_XLA_OP calls
    142 // on the same Op name below.
    143 class DummyDuplicateOp : public XlaOpKernel {
    144  public:
    145   explicit DummyDuplicateOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
    146   void Compile(XlaOpKernelContext* ctx) override {
    147     ctx->SetOutput(0, ctx->Input(0));
    148   }
    149 };
    150 
// Declares the "DummyDuplicateOp" op: single int32 input, single int32 output.
REGISTER_OP("DummyDuplicateOp")
    .Input("input: int32")
    .Output("output: int32")
    .Doc(R"doc(
A dummy Op.

input: dummy input.
output: dummy output.
)doc");

// Registers the same kernel once per device to exercise multiple
// REGISTER_XLA_OP calls for a single Op name.
REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_CPU_XLA_JIT),
                DummyDuplicateOp);
REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_GPU_XLA_JIT),
                DummyDuplicateOp);
    165 
    166 
    167 // Tests compilation and execution of an empty graph.
    168 TEST_F(XlaCompilerTest, EmptyReturnValues) {
    169   XlaCompiler compiler(DefaultOptions());
    170 
    171   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    172   XlaCompiler::CompilationResult result;
    173   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
    174                                      std::move(graph),
    175                                      /*args=*/{}, &result));
    176 
    177   TF_ASSERT_OK(client_->Execute(*result.computation, {}).status());
    178 }
    179 
    180 // Tests compilation and execution of a graph that adds two tensors.
TEST_F(XlaCompilerTest, Simple) {
  // Builds a graph that adds two Tensors: D = A + B, where A and B are int32
  // parameters and D is the single return value.
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
  auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
  auto c = ops::Add(scope.WithOpName("C"), a, b);
  auto d = ops::_Retval(scope.WithOpName("D"), c, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Builds a description of the arguments: two int32 vectors of length 2.
  std::vector<XlaCompiler::Argument> args(2);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({2});
  args[1].kind = XlaCompiler::Argument::kParameter;
  args[1].type = DT_INT32;
  args[1].shape = TensorShape({2});

  // Compiles the graph.
  XlaCompiler compiler(DefaultOptions());

  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                     std::move(graph), args, &result));

  // Tests that the generated computation works: transfers both operands to
  // the XLA service, executes, and reads the result back as a literal.
  std::unique_ptr<xla::Literal> param0_literal =
      xla::Literal::CreateR1<int32>({7, 42});
  std::unique_ptr<xla::Literal> param1_literal =
      xla::Literal::CreateR1<int32>({-3, 101});
  std::unique_ptr<xla::GlobalData> param0_data =
      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
  std::unique_ptr<xla::GlobalData> param1_data =
      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();

  std::unique_ptr<xla::GlobalData> actual =
      client_
          ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
          .ConsumeValueOrDie();
  std::unique_ptr<xla::Literal> actual_literal =
      client_->Transfer(*actual).ConsumeValueOrDie();

  // The computation returns a 1-tuple holding the elementwise sum:
  // {7 + -3, 42 + 101} == {4, 143}.
  std::unique_ptr<xla::Literal> expected0 =
      xla::Literal::CreateR1<int32>({4, 143});
  std::unique_ptr<xla::Literal> expected_literal =
      xla::Literal::MakeTuple({expected0.get()});
  xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
}
    230 
    231 TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
    232   // Builds a graph that adds reshapes a tensor, but with the shape not
    233   // statically known.
    234   Scope scope = Scope::NewRootScope().ExitOnError();
    235   auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
    236   auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
    237   auto c = ops::Reshape(scope.WithOpName("C"), a, b);
    238   auto d = ops::_Retval(scope.WithOpName("D"), c, 0);
    239   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    240   TF_ASSERT_OK(scope.ToGraph(graph.get()));
    241 
    242   // Builds a description of the arguments.
    243   std::vector<XlaCompiler::Argument> args(2);
    244   args[0].kind = XlaCompiler::Argument::kParameter;
    245   args[0].type = DT_INT32;
    246   args[0].shape = TensorShape({2});
    247   args[1].kind = XlaCompiler::Argument::kParameter;
    248   args[1].type = DT_INT32;
    249   args[1].shape = TensorShape({2});
    250 
    251   // Compiles the graph.
    252   XlaCompiler compiler(DefaultOptions());
    253 
    254   XlaCompiler::CompilationResult result;
    255   Status status =
    256       compiler.CompileGraph(XlaCompiler::CompileOptions(), "reshape",
    257                             std::move(graph), args, &result);
    258   EXPECT_FALSE(status.ok());
    259   EXPECT_TRUE(
    260       StringPiece(status.error_message()).contains("depends on a parameter"))
    261       << status.error_message();
    262   EXPECT_TRUE(
    263       StringPiece(status.error_message()).contains("[[Node: C = Reshape"))
    264       << status.error_message();
    265 }
    266 
    267 // Tests handling of compile-time constant outputs.
TEST_F(XlaCompilerTest, ConstantOutputs) {
  // Builds a graph with one compile-time constant output and one data-dependent
  // output, i.e.,
  // func(a) { b=7; c=-a; return b, c; }
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
  auto b = ops::Const<int32>(scope.WithOpName("B"), 7);
  auto c = ops::Neg(scope.WithOpName("C"), a);
  auto d = ops::_Retval(scope.WithOpName("D"), b, 0);
  auto e = ops::_Retval(scope.WithOpName("E"), c, 1);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Builds a description of the arguments: one int32 vector of length 2.
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({2});

  XlaCompiler::Options options = DefaultOptions();
  XlaCompiler compiler(options);
  {
    // Compiles the graph, with resolve_compile_time_constants enabled: the
    // constant output is reported via CompilationResult and omitted from the
    // executed computation's result tuple.

    // The graph is compiled twice below, so compile from a copy each time.
    std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
    CopyGraph(*graph, graph_copy.get());

    XlaCompiler::CompileOptions compile_options;
    compile_options.resolve_compile_time_constants = true;
    XlaCompiler::CompilationResult result;
    TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants",
                                       std::move(graph_copy), args, &result));

    // Output 0 is the resolved constant 7; output 1 stays data-dependent.
    ASSERT_EQ(2, result.outputs.size());
    EXPECT_TRUE(result.outputs[0].is_constant);
    test::ExpectTensorEqual<int32>(result.outputs[0].constant_value,
                                   test::AsScalar(7));
    EXPECT_FALSE(result.outputs[1].is_constant);

    // Tests that the generated computation works.
    std::unique_ptr<xla::Literal> param0_literal =
        xla::Literal::CreateR1<int32>({7, 42});
    std::unique_ptr<xla::GlobalData> param0_data =
        client_->TransferToServer(*param0_literal).ConsumeValueOrDie();

    std::unique_ptr<xla::GlobalData> actual =
        client_->Execute(*result.computation, {param0_data.get()})
            .ConsumeValueOrDie();
    std::unique_ptr<xla::Literal> actual_literal =
        client_->Transfer(*actual).ConsumeValueOrDie();

    // Only the negated input appears in the executed result (a 1-tuple).
    std::unique_ptr<xla::Literal> expected0 =
        xla::Literal::CreateR1<int32>({-7, -42});
    std::unique_ptr<xla::Literal> expected_literal =
        xla::Literal::MakeTuple({expected0.get()});
    xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
  }

  {
    // Compiles the graph, with resolve_compile_time_constants disabled: both
    // outputs are computed by the generated computation.
    std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
    CopyGraph(*graph, graph_copy.get());

    XlaCompiler::CompileOptions compile_options;
    compile_options.resolve_compile_time_constants = false;
    XlaCompiler::CompilationResult result;
    TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants",
                                       std::move(graph_copy), args, &result));

    // Neither output is marked constant this time.
    ASSERT_EQ(2, result.outputs.size());
    EXPECT_FALSE(result.outputs[0].is_constant);
    EXPECT_FALSE(result.outputs[1].is_constant);

    // Tests that the generated computation works.
    std::unique_ptr<xla::Literal> param0_literal =
        xla::Literal::CreateR1<int32>({7, 42});
    std::unique_ptr<xla::GlobalData> param0_data =
        client_->TransferToServer(*param0_literal).ConsumeValueOrDie();

    std::unique_ptr<xla::GlobalData> actual =
        client_->Execute(*result.computation, {param0_data.get()})
            .ConsumeValueOrDie();
    std::unique_ptr<xla::Literal> actual_literal =
        client_->Transfer(*actual).ConsumeValueOrDie();

    // The result is now a 2-tuple: the constant 7 and the negated input.
    std::unique_ptr<xla::Literal> expected0 = xla::Literal::CreateR0<int32>(7);
    std::unique_ptr<xla::Literal> expected1 =
        xla::Literal::CreateR1<int32>({-7, -42});
    std::unique_ptr<xla::Literal> expected =
        xla::Literal::MakeTuple({expected0.get(), expected1.get()});
    xla::LiteralTestUtil::ExpectEqual(*expected, *actual_literal);
  }
}
    361 
// Tests that a resource installed via populate_resource_manager is visible to
// XLA op kernels during compilation (the dummy op increments it once).
    363 TEST_F(XlaCompilerTest, ResourceManager) {
    364   // Builds a graph that calls the dummy resource Op.
    365   Scope scope = Scope::NewRootScope().ExitOnError();
    366   auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
    367   auto b = DummyReadResourceCC(scope.WithOpName("B"), a);
    368   auto c = ops::Add(scope.WithOpName("C"), b.output2_, b.output1_);
    369   auto d = ops::_Retval(scope.WithOpName("D"), c, 0);
    370   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    371   TF_ASSERT_OK(scope.ToGraph(graph.get()));
    372 
    373   // Builds a description of the argument.
    374   std::vector<XlaCompiler::Argument> args(1);
    375   args[0].kind = XlaCompiler::Argument::kParameter;
    376   args[0].type = DT_INT32;
    377   args[0].shape = TensorShape({2});
    378 
    379   DummyResourceForTest* resource = new DummyResourceForTest();
    380 
    381   // Compiles the graph.
    382   auto options = DefaultOptions();
    383   std::function<Status(ResourceMgr*)> populate_function =
    384       [resource](ResourceMgr* rm) {
    385         resource->Ref();
    386         return rm->Create(rm->default_container(), "dummy", resource);
    387       };
    388   options.populate_resource_manager = &populate_function;
    389   XlaCompiler compiler(options);
    390 
    391   EXPECT_EQ(0, resource->Get());
    392 
    393   XlaCompiler::CompilationResult result;
    394   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy",
    395                                      std::move(graph), args, &result));
    396 
    397   EXPECT_EQ(1, resource->Get());
    398 
    399   resource->Unref();
    400 }
    401 
// Tests that compiling the same graph repeatedly produces identical
// computations, i.e. graph traversal order is deterministic.
    403 TEST_F(XlaCompilerTest, DeterministicCompilation) {
    404   // Builds a graph that contains a node with two output edges. The compiler
    405   // should always traverse them in the same order.
    406   const int64 test_count = 2;
    407 
    408   std::vector<XlaCompiler::CompilationResult> results(test_count);
    409 
    410   for (int64 i = 0; i < test_count; ++i) {
    411     Scope scope = Scope::NewRootScope().ExitOnError();
    412     auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
    413     auto b = ops::Neg(scope.WithOpName("B"), a);
    414     auto c = ops::Neg(scope.WithOpName("C"), a);
    415     auto d = ops::Add(scope.WithOpName("D"), b, c);
    416     auto e = ops::_Retval(scope.WithOpName("E"), d, 0);
    417     std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    418     TF_ASSERT_OK(scope.ToGraph(graph.get()));
    419 
    420     // Builds a description of the argument.
    421     std::vector<XlaCompiler::Argument> args(1);
    422     args[0].kind = XlaCompiler::Argument::kParameter;
    423     args[0].type = DT_INT32;
    424     args[0].shape = TensorShape({2});
    425 
    426     // Compiles the graph.
    427     auto options = DefaultOptions();
    428     XlaCompiler compiler(options);
    429 
    430     TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy",
    431                                        std::move(graph), args, &results[i]));
    432   }
    433 
    434   for (int64 i = 1; i < test_count; ++i) {
    435     auto m1 =
    436         results[i - 1].computation->Snapshot().ValueOrDie()->entry().requests();
    437     auto m2 =
    438         results[i].computation->Snapshot().ValueOrDie()->entry().requests();
    439     // Check if every entry is the same.
    440     for (auto& entry1 : m1) {
    441       int64 key = entry1.first;
    442       auto value1 = entry1.second;
    443       auto entry2 = m2.find(key);
    444       auto value2 = entry2->second;
    445       EXPECT_TRUE(entry2 != m2.end());
    446       string str1, str2;
    447       value1.AppendToString(&str1);
    448       value2.AppendToString(&str2);
    449       EXPECT_EQ(str1, str2);
    450     }
    451   }
    452 }
    453 
    454 // Tests a computation that receives a TensorArray resource as input and
    455 // updates it.
TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
  // Builds a graph that takes a TensorArray resource, creates gradients
  // "grad1" and "grad2", writes the value 1 into "grad1" at position 1, and
  // returns a read of the base array at position 1.
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
  auto flow = ops::Const<float>(scope, {});
  auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad1");
  auto grad2 = ops::TensorArrayGrad(scope, arg, grad1.flow_out, "grad2");
  auto index = ops::Const<int32>(scope, 1);
  auto write = ops::TensorArrayWrite(scope, grad1.grad_handle, index, index,
                                     grad2.flow_out);
  auto read = ops::TensorArrayRead(scope, arg, index, write.flow_out, DT_INT32);
  auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Builds a description of the arguments: an initialized int32 TensorArray
  // of size 2 that already carries a "grad2" gradient (but not "grad1").
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kResource;
  args[0].resource_kind = XlaResource::kTensorArray;
  args[0].initialized = true;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({});
  args[0].tensor_array_size = 2;
  args[0].tensor_array_gradients = {"grad2"};

  // Compiles the graph.
  XlaCompiler compiler(DefaultOptions());

  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                     std::move(graph), args, &result));

  // The TensorArray must be written back as a resource update, and both
  // gradients (preexisting "grad2" and newly created "grad1") must be
  // reported as accessed.
  ASSERT_EQ(1, result.resource_updates.size());
  const XlaCompiler::ResourceUpdate& update = result.resource_updates[0];
  EXPECT_EQ(0, update.input_index);
  EXPECT_EQ(DT_INT32, update.type);
  EXPECT_EQ((std::set<string>{"grad1", "grad2"}),
            update.tensor_array_gradients_accessed);

  // Tests that the generated computation works. The resource argument is
  // passed as a tuple of (base array, "grad2" gradient).
  std::unique_ptr<xla::Literal> input_base =
      xla::Literal::CreateR1<int32>({7, 42});
  std::unique_ptr<xla::Literal> input_grad2 =
      xla::Literal::CreateR1<int32>({-3, 101});
  std::unique_ptr<xla::Literal> input =
      xla::Literal::MakeTuple({input_base.get(), input_grad2.get()});
  std::unique_ptr<xla::GlobalData> param0_data =
      client_->TransferToServer(*input).ConsumeValueOrDie();

  std::unique_ptr<xla::GlobalData> actual =
      client_->Execute(*result.computation, {param0_data.get()})
          .ConsumeValueOrDie();
  std::unique_ptr<xla::Literal> actual_literal =
      client_->Transfer(*actual).ConsumeValueOrDie();

  // Expected result: the read value (42), then the updated resource — the
  // unchanged base array, "grad1" holding the written 1 at position 1 (and 0
  // elsewhere), and the unchanged "grad2".
  std::unique_ptr<xla::Literal> output_read = xla::Literal::CreateR0<int32>(42);
  std::unique_ptr<xla::Literal> output_base =
      xla::Literal::CreateR1<int32>({7, 42});
  std::unique_ptr<xla::Literal> output_grad1 =
      xla::Literal::CreateR1<int32>({0, 1});
  std::unique_ptr<xla::Literal> output_grad2 =
      xla::Literal::CreateR1<int32>({-3, 101});
  std::unique_ptr<xla::Literal> output_resource = xla::Literal::MakeTuple(
      {output_base.get(), output_grad1.get(), output_grad2.get()});
  std::unique_ptr<xla::Literal> expected_literal =
      xla::Literal::MakeTuple({output_read.get(), output_resource.get()});
  xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
}
    523 
// Tests that a TensorArray gradient that is read but never written does not
// produce a resource update in the compilation result.
    525 TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) {
    526   Scope scope = Scope::NewRootScope().ExitOnError();
    527   auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
    528   auto flow = ops::Const<float>(scope, {});
    529   auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad1");
    530   auto index = ops::Const<int32>(scope, 1);
    531   auto read = ops::TensorArrayRead(scope, arg, index, grad1.flow_out, DT_INT32);
    532   auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0);
    533   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    534   TF_ASSERT_OK(scope.ToGraph(graph.get()));
    535 
    536   // Builds a description of the arguments.
    537   std::vector<XlaCompiler::Argument> args(1);
    538   args[0].kind = XlaCompiler::Argument::kResource;
    539   args[0].resource_kind = XlaResource::kTensorArray;
    540   args[0].initialized = true;
    541   args[0].type = DT_INT32;
    542   args[0].shape = TensorShape({});
    543   args[0].tensor_array_size = 2;
    544   args[0].tensor_array_gradients = {"grad1"};
    545 
    546   // Compiles the graph.
    547   XlaCompiler compiler(DefaultOptions());
    548 
    549   XlaCompiler::CompilationResult result;
    550   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
    551                                      std::move(graph), args, &result));
    552 
    553   EXPECT_EQ(0, result.resource_updates.size());
    554 }
    555 
// Tests that creating a TensorArray gradient the argument did not already
// carry produces a resource update in the compilation result.
    557 TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) {
    558   Scope scope = Scope::NewRootScope().ExitOnError();
    559   auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
    560   auto flow = ops::Const<float>(scope, {});
    561   auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad2");
    562   auto index = ops::Const<int32>(scope, 1);
    563   auto read = ops::TensorArrayRead(scope, arg, index, grad1.flow_out, DT_INT32);
    564   auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0);
    565   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    566   TF_ASSERT_OK(scope.ToGraph(graph.get()));
    567 
    568   // Builds a description of the arguments.
    569   std::vector<XlaCompiler::Argument> args(1);
    570   args[0].kind = XlaCompiler::Argument::kResource;
    571   args[0].resource_kind = XlaResource::kTensorArray;
    572   args[0].initialized = true;
    573   args[0].type = DT_INT32;
    574   args[0].shape = TensorShape({});
    575   args[0].tensor_array_size = 2;
    576   args[0].tensor_array_gradients = {"grad1"};
    577 
    578   // Compiles the graph.
    579   XlaCompiler compiler(DefaultOptions());
    580 
    581   XlaCompiler::CompilationResult result;
    582   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
    583                                      std::move(graph), args, &result));
    584 
    585   EXPECT_EQ(1, result.resource_updates.size());
    586 }
    587 
    588 // Tests CompileFunction with undefined function fails.
    589 TEST_F(XlaCompilerTest, UndefinedFunctionFails) {
    590   XlaCompiler compiler(DefaultOptions());
    591 
    592   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    593   XlaCompiler::CompilationResult result;
    594   NameAttrList name_attr;
    595   name_attr.set_name("Function_NotDefined_");
    596   Status status =
    597       compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr,
    598                                /*args=*/{}, &result);
    599   EXPECT_FALSE(status.ok());
    600   EXPECT_TRUE(StringPiece(status.error_message()).contains("is not defined."))
    601       << status.error_message();
    602 }
    603 
// Returns a FunctionDef named "FillFn" wrapping the "Fill" op:
// y = Fill(dims, x). Used below to test ops that require compile-time
// constant inputs ("dims") when they appear inside a function call.
FunctionDef FillFn() {
  return FunctionDefHelper::Define(
      // Name
      "FillFn",
      // Args
      {"x: T", "dims: int32"},
      // Return values
      {"y: T"},
      // Attr def
      {"T: {float, double, int32, int64}"},
      // Nodes
      {{{"y"}, "Fill", {"dims", "x"}, {{"T", "$T"}}}});
}
    617 
TEST_F(XlaCompilerTest, FunctionCallWithConstants) {
  // Certain operations in a function, "Fill" for example, requires the
  // operator's argument to be a compile-time constant instead of a parameter.
  // This testcase tests if XlaCompiler can handle such operators inside
  // function calls.
  XlaCompiler compiler(DefaultOptions());

  // Makes FillFn available both to the graph and to the fixture's library.
  FunctionDefLibrary flib;
  *flib.add_function() = FillFn();

  TF_ASSERT_OK(flib_def_->AddFunctionDef(FillFn()));

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  // Builds: fill = FillFn(value=1, shape=[5]); retval = fill.
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto value = ops::Const<int32>(scope.WithOpName("value"), 1, {});
  auto shape = ops::Const<int32>(scope.WithOpName("shape"), {5}, {1});
  TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib));

  // The function-call node is built by hand from a NodeDef and wired up
  // with explicit graph edges.
  NodeDef def;
  TF_ASSERT_OK(NodeDefBuilder("fill", "FillFn", flib_def_.get())
                   .Input(value.name(), 0, DT_INT32)
                   .Input(shape.name(), 1, DT_INT32)
                   .Finalize(&def));
  Status status;
  Node* fill = scope.graph()->AddNode(def, &status);
  TF_ASSERT_OK(status);
  TF_ASSERT_OK(scope.DoShapeInference(fill));
  scope.graph()->AddEdge(value.node(), 0, fill, 0);
  scope.graph()->AddEdge(shape.node(), 0, fill, 1);

  auto retval = ops::_Retval(scope.WithOpName("retval"), Output(fill), 0);

  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Builds a description of the argument (the graph takes none).
  std::vector<XlaCompiler::Argument> args;

  // Successful compilation is the test: Fill's constant "dims" input arrives
  // through a function call and must still be resolved.
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
                                     std::move(graph), args, &result));
}
    660 
    661 // Tests CompileFunction with a local function lookup failing, fails with
    662 // informative error about both lookups.
    663 TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) {
    664   XlaCompiler compiler(DefaultOptions());
    665 
    666   auto local_flib_def = LocalFlibDef(&compiler);
    667   TF_ASSERT_OK(local_flib_def->AddFunctionDef(test::function::XTimesTwo()));
    668 
    669   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    670   XlaCompiler::CompilationResult result;
    671   NameAttrList name_attr;
    672   name_attr.set_name("XTimesTwo");
    673   Status status =
    674       compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr,
    675                                /*args=*/{}, &result);
    676 
    677   ASSERT_FALSE(status.ok());
    678   // Flib lookup failure.
    679   EXPECT_TRUE(StringPiece(status.error_message()).contains("is not defined."))
    680       << status.error_message();
    681   // Local flib lookup failure.
    682   EXPECT_TRUE(
    683       StringPiece(status.error_message()).contains("Attr T is not found"))
    684       << status.error_message();
    685 }
    686 
    687 // Tests a simple graph that reads and writes a variable.
    688 TEST_F(XlaCompilerTest, Variables) {
    689   Scope scope = Scope::NewRootScope().ExitOnError();
    690   auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
    691   auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
    692   auto write = ops::AssignAddVariableOp(scope, var, a);
    693   auto read = ops::ReadVariableOp(
    694       scope.WithControlDependencies(std::vector<Operation>{write}), var,
    695       DT_INT32);
    696   auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
    697   auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
    698   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    699   TF_ASSERT_OK(scope.ToGraph(graph.get()));
    700 
    701   // Builds a description of the arguments.
    702   std::vector<XlaCompiler::Argument> args(2);
    703   args[0].kind = XlaCompiler::Argument::kParameter;
    704   args[0].type = DT_INT32;
    705   args[0].shape = TensorShape({2});
    706   args[1].kind = XlaCompiler::Argument::kResource;
    707   args[1].resource_kind = XlaResource::kVariable;
    708   args[1].initialized = true;
    709   args[1].type = DT_INT32;
    710   args[1].shape = TensorShape({2});
    711 
    712   // Compiles the graph.
    713   XlaCompiler compiler(DefaultOptions());
    714 
    715   XlaCompiler::CompilationResult result;
    716   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
    717                                      std::move(graph), args, &result));
    718 
    719   // Tests that the generated computation works.
    720   std::unique_ptr<xla::Literal> param0_literal =
    721       xla::Literal::CreateR1<int32>({7, 42});
    722   std::unique_ptr<xla::Literal> param1_literal =
    723       xla::Literal::CreateR1<int32>({-3, 101});
    724   std::unique_ptr<xla::GlobalData> param0_data =
    725       client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
    726   std::unique_ptr<xla::GlobalData> param1_data =
    727       client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
    728 
    729   std::unique_ptr<xla::GlobalData> actual =
    730       client_
    731           ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
    732           .ConsumeValueOrDie();
    733   std::unique_ptr<xla::Literal> actual_literal =
    734       client_->Transfer(*actual).ConsumeValueOrDie();
    735 
    736   std::unique_ptr<xla::Literal> expected0 =
    737       xla::Literal::CreateR1<int32>({5, 144});
    738   std::unique_ptr<xla::Literal> expected1 =
    739       xla::Literal::CreateR1<int32>({4, 143});
    740   std::unique_ptr<xla::Literal> expected_literal =
    741       xla::Literal::MakeTuple({expected0.get(), expected1.get()});
    742   xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
    743 }
    744 
    745 // Tests a simple graph that reads and writes a variable, with a
    746 // variable_representation_shape_fn passed to the compiler that flattens all
    747 // variable tensors to vectors.
    748 TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
    749   Scope scope = Scope::NewRootScope().ExitOnError();
    750   auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
    751   auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
    752   auto write = ops::AssignAddVariableOp(scope, var, a);
    753   auto read = ops::ReadVariableOp(
    754       scope.WithControlDependencies(std::vector<Operation>{write}), var,
    755       DT_INT32);
    756   auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
    757   auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
    758   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    759   TF_ASSERT_OK(scope.ToGraph(graph.get()));
    760 
    761   // Builds a description of the arguments.
    762   std::vector<XlaCompiler::Argument> args(2);
    763   args[0].kind = XlaCompiler::Argument::kParameter;
    764   args[0].type = DT_INT32;
    765   args[0].shape = TensorShape({2, 2});
    766   args[1].kind = XlaCompiler::Argument::kResource;
    767   args[1].resource_kind = XlaResource::kVariable;
    768   args[1].initialized = true;
    769   args[1].type = DT_INT32;
    770   args[1].shape = TensorShape({2, 2});
    771 
    772   // Compiles the graph.
    773   XlaCompiler::Options options = DefaultOptions();
    774   options.variable_representation_shape_fn = [](const TensorShape& shape,
    775                                                 DataType type) {
    776     return TensorShape({shape.num_elements()});
    777   };
    778   XlaCompiler compiler(options);
    779 
    780   XlaCompiler::CompilationResult result;
    781   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
    782                                      std::move(graph), args, &result));
    783 
    784   // Tests that the generated computation works.
    785   std::unique_ptr<xla::Literal> param0_literal =
    786       xla::Literal::CreateR2<int32>({{4, 55}, {1, -3}});
    787   std::unique_ptr<xla::Literal> param1_literal =
    788       xla::Literal::CreateR1<int32>({22, 11, 33, 404});
    789   std::unique_ptr<xla::GlobalData> param0_data =
    790       client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
    791   std::unique_ptr<xla::GlobalData> param1_data =
    792       client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
    793 
    794   std::unique_ptr<xla::GlobalData> actual =
    795       client_
    796           ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
    797           .ConsumeValueOrDie();
    798   std::unique_ptr<xla::Literal> actual_literal =
    799       client_->Transfer(*actual).ConsumeValueOrDie();
    800 
    801   std::unique_ptr<xla::Literal> expected0 =
    802       xla::Literal::CreateR2<int32>({{27, 67}, {35, 402}});
    803   std::unique_ptr<xla::Literal> expected1 =
    804       xla::Literal::CreateR1<int32>({26, 66, 34, 401});
    805   std::unique_ptr<xla::Literal> expected_literal =
    806       xla::Literal::MakeTuple({expected0.get(), expected1.get()});
    807   xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
    808 }
    809 
    810 }  // namespace
    811 }  // namespace tensorflow
    812