Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/framework/function.h"
     17 #include <vector>
     18 #include "tensorflow/core/framework/function.pb.h"
     19 #include "tensorflow/core/framework/function_testlib.h"
     20 #include "tensorflow/core/framework/op.h"
     21 #include "tensorflow/core/framework/tensor_testutil.h"
     22 #include "tensorflow/core/kernels/ops_util.h"
     23 #include "tensorflow/core/lib/core/status_test_util.h"
     24 #include "tensorflow/core/lib/gtl/array_slice.h"
     25 #include "tensorflow/core/lib/strings/str_util.h"
     26 #include "tensorflow/core/lib/strings/strcat.h"
     27 #include "tensorflow/core/platform/test.h"
     28 #include "tensorflow/core/platform/types.h"
     29 
     30 namespace tensorflow {
     31 namespace {
     32 
     33 // A helper class to make AttrSlice from initializer lists
     34 class Attrs {
     35  public:
     36   Attrs(const std::initializer_list<  // NOLINT(runtime/explicit)
     37         std::pair<string, FunctionDefHelper::AttrValueWrapper>>
     38             attrs) {
     39     for (const auto& aval : attrs) {
     40       map_.insert({aval.first, aval.second.proto});
     41     }
     42   }
     43 
     44   operator AttrSlice() { return AttrSlice(&map_); }  // NOLINT(runtime/explicit)
     45 
     46  private:
     47   AttrValueMap map_;
     48 };
     49 
     50 typedef FunctionDefHelper FDH;
     51 
     52 Status GetOpSig(const string& op, const OpDef** sig) {
     53   return OpRegistry::Global()->LookUpOpDef(op, sig);
     54 }
     55 
// Test-only op with no inputs and a single output `y` of numeric type T. The
// function tests below use it as a source node (e.g. the "+ 1" term in
// SquarePlusOne). No kernel is registered for it in this file.
REGISTER_OP("One")
    .Output("y: T")
    .Attr("T: {float, double, int32, int64}")
    .Doc(R"doc(
Returns a tensor with a single element (1) of type T.

y: A scalar in type T.

)doc");
     65 
     66 TEST(TFunc, SquarePlusOne) {
     67   auto fdef = FDH::Create(
     68       // Name
     69       "SquarePlusOne",
     70       // Inputs
     71       {"x: T"},
     72       // Outputs
     73       {"y: T"},
     74       // Attrs
     75       {"T: {float, double, int32, int64}"},
     76       // Nodes
     77       {// a = Square<T>(x)
     78        {{"a"}, "Square", {"x"}, {{"T", "$T"}}},
     79        // o = One<T>()
     80        // NOTE: We can also have a Cast<Tin, Tout>(x) instead.
     81        {{"o"}, "One", {}, {{"T", "$T"}}},
     82        // y = Add<T>(a, o)
     83        {{"y"}, "Add", {"a:y", "o:y"}, {{"T", "$T"}}}},
     84       // Returns
     85       {{"y", "y:z:0"}});
     86 
     87   const char* e = R"P(
     88 SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
     89   a = Square[T=$T](x)
     90   o = One[T=$T]()
     91   y = Add[T=$T](a:y, o:y)
     92   return y = y:z:0
     93 }
     94 )P";
     95   EXPECT_EQ(DebugString(fdef), e);
     96 
     97   // Instantiate one with T=float
     98   InstantiationResult result;
     99   TF_ASSERT_OK(
    100       InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result));
    101   const char* e2 = R"P(
    102 (x:float) -> (y:float) {
    103   a = Square[T=float](x)
    104   o = One[T=float]()
    105   y = Add[T=float](a, o)
    106 }
    107 )P";
    108   EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
    109   EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
    110   EXPECT_EQ(DebugString(result.nodes), e2);
    111 }
    112 
    113 TEST(TFunc, ControlDep) {
    114   auto fdef = FDH::Create(
    115       // Name
    116       "ControlDep",
    117       // Inputs
    118       {"x: int32"},
    119       // Outputs
    120       {"y: int32"},
    121       // Attrs
    122       {},
    123       // Nodes
    124       {// a = Identity<int32>(x)
    125        {{"a"}, "Identity", {"x"}, {{"T", DT_INT32}}},
    126        // o = NoOp(^a)
    127        {{"o"}, "NoOp", {"^a"}, {}},
    128        // y = Identity<int32>(a, ^o)
    129        {{"y"}, "Identity", {"a:output:0", "^o"}, {{"T", DT_INT32}}}},
    130       // Returns
    131       {{"y", "y:output:0"}});
    132 
    133   const char* e = R"P(
    134 ControlDep(x:int32) -> (y:int32) {
    135   a = Identity[T=int32](x)
    136   o = NoOp() @ a
    137   y = Identity[T=int32](a:output:0) @ o
    138   return y = y:output:0
    139 }
    140 )P";
    141   EXPECT_EQ(DebugString(fdef), e);
    142 
    143   // Instantiate one with T=float
    144   InstantiationResult result;
    145   TF_ASSERT_OK(
    146       InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result));
    147   const char* e2 = R"P(
    148 (x:int32) -> (y:int32) {
    149   a = Identity[T=int32](x)
    150   o = NoOp() @ a
    151   y = Identity[T=int32](a) @ o
    152 }
    153 )P";
    154   EXPECT_EQ(result.arg_types, DataTypeVector({DT_INT32}));
    155   EXPECT_EQ(result.ret_types, DataTypeVector({DT_INT32}));
    156   EXPECT_EQ(DebugString(result.nodes), e2);
    157 }
    158 
// Test-only op whose type attr T has a default (DT_FLOAT), so a node may omit
// T entirely; exercised by the MissingTypeAttr test.
REGISTER_OP("HasDefaultType")
    .Output("out: T")
    .Attr("T: {float, double, int32, int64} = DT_FLOAT");
    162 
    163 // This verifies that a function using an op before a type attr (with
    164 // a default) is added, still works.  This is important for backwards
    165 // compatibility.
    166 TEST(TFunc, MissingTypeAttr) {
    167   auto fdef = FDH::Create(
    168       // Name
    169       "BackCompat",
    170       // Args
    171       {},
    172       // Return values
    173       {"y: float"},
    174       // Attrs
    175       {},
    176       // Nodes
    177       {// y = HasDefaultType(x), T missing, defaults to float
    178        {{"a"}, "HasDefaultType", {}, {}}},
    179       // Returns
    180       {{"y", "a:out:0"}});
    181 
    182   const char* e = R"P(
    183 BackCompat() -> (y:float) {
    184   a = HasDefaultType()
    185   return y = a:out:0
    186 }
    187 )P";
    188   EXPECT_EQ(DebugString(fdef), e);
    189 
    190   InstantiationResult result;
    191   TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
    192   // Should get T=float from Op's default.
    193   const char* e2 = R"P(
    194 () -> (a:float) {
    195   a = HasDefaultType[T=float]()
    196 }
    197 )P";
    198   EXPECT_EQ(result.arg_types, DataTypeVector());
    199   EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
    200   EXPECT_EQ(DebugString(result.nodes), e2);
    201 }
    202 
    203 TEST(TFunc, NTimesT) {
    204   auto fdef = FDH::Create(
    205       // Name
    206       "NTimesT",
    207       // Inputs
    208       {"x: float", "y: float"},
    209       // Outputs
    210       {"z: float"},
    211       // Attrs
    212       {},
    213       // Nodes
    214       {// a = AddN<N=2>(x, y)
    215        {{"a"}, "AddN", {"x", "y"}, {{"T", DT_FLOAT}, {"N", 2}}}},
    216       // Returns
    217       {{"z", "a:sum:0"}});
    218 
    219   const char* e = R"P(
    220 NTimesT(x:float, y:float) -> (z:float) {
    221   a = AddN[N=2, T=float](x, y)
    222   return z = a:sum:0
    223 }
    224 )P";
    225   EXPECT_EQ(DebugString(fdef), e);
    226 
    227   InstantiationResult result;
    228   TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
    229   const char* e2 = R"P(
    230 (x:float, y:float) -> (a:float) {
    231   a = AddN[N=2, T=float](x, y)
    232 }
    233 )P";
    234   EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT, DT_FLOAT}));
    235   EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
    236   EXPECT_EQ(DebugString(result.nodes), e2);
    237 }
    238 
// NOTE: This is the simplest Map op: it applies a function f:T->U to each of
// its N inputs. The `func` attr is not declared in the signature (see the
// commented-out line below), yet the AddSquared test still sets `func` on a
// Map node. NOTE(review): that appears to rely on instantiation tolerating
// undeclared attrs -- confirm before making extra attrs an error.
REGISTER_OP("Map")
    .Input("x: N * T")
    .Output("y: N * U")
    .Attr("T: type")
    .Attr("U: type")
    .Attr("N: int >= 1")
    // .Attr("func: func_name_with_attr")
    .Doc(R"doc(
Applies the 'func' on every input. I.e.,

y[i] = func<...>(x[i])

x: N tensors, each of type T;
y: N tensors, each of type U;

)doc");
    256 
    257 TEST(TFunc, AddSquared) {
    258   auto fdef = FDH::Create(
    259       // Name
    260       "AddSquared",
    261       // Args
    262       {"x: N*T"},
    263       // Return values
    264       {"y: T"},
    265       // Attrs
    266       {"N:int", "T:{float, double, int32, int64}"},
    267       // Nodes
    268       {// a = Map<func=Square<$T>,T=$T,U=$T,N=$N>(x)
    269        {{"a"},
    270         "Map",
    271         {"x"},
    272         {{"func", FDH::FunctionRef("Square", {{"T", "$T"}})},
    273          {"T", "$T"},
    274          {"U", "$T"},
    275          {"N", "$N"}}},
    276        // y = AddN<N=$N,T=$T>(a)
    277        {{"y"}, "AddN", {"a:y"}, {{"N", "$N"}, {"T", "$T"}}}},
    278       {{"y", "y:sum"}});
    279 
    280   const char* e = R"P(
    281 AddSquared[N:int, T:{float, double, int32, int64}](x:N*T) -> (y:T) {
    282   a = Map[N=$N, T=$T, U=$T, func=Square[T=$T]](x)
    283   y = AddN[N=$N, T=$T](a:y)
    284   return y = y:sum
    285 }
    286 )P";
    287   EXPECT_EQ(DebugString(fdef), e);
    288 
    289   // Instantiate one with T=float
    290   InstantiationResult result;
    291   TF_ASSERT_OK(InstantiateFunction(fdef, Attrs({{"N", 3}, {"T", DT_FLOAT}}),
    292                                    GetOpSig, &result));
    293   const char* e2 = R"P(
    294 (x_0:float, x_1:float, x_2:float) -> (y:float) {
    295   a = Map[N=3, T=float, U=float, func=Square[T=float]](x_0, x_1, x_2)
    296   y = AddN[N=3, T=float](a, a:1, a:2)
    297 }
    298 )P";
    299   EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT, DT_FLOAT, DT_FLOAT}));
    300   EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
    301   EXPECT_EQ(DebugString(result.nodes), e2);
    302 }
    303 
    304 TEST(TFunc, ControlDeps) {
    305   auto fdef = FDH::Define(
    306       // Name
    307       "ControlDeps",
    308       // Args
    309       {"x: float"},
    310       // Return values
    311       {},
    312       // Attrs
    313       {},
    314       // Nodes
    315       {
    316           {{"a"}, "One", {}, {{"T", DT_FLOAT}}, {"x"}},
    317           {{"u"}, "NoOp", {}, {}, {"a"}},
    318           {{"b"}, "One", {}, {{"T", DT_FLOAT}}, {"u"}},
    319           {{"v"}, "NoOp", {}, {}, {"b"}},
    320           {{"c"}, "One", {}, {{"T", DT_FLOAT}}, {"a", "v"}},
    321       });
    322   const char* e = R"P(
    323 ControlDeps(x:float) -> () {
    324   a = One[T=float]() @ x
    325   u = NoOp() @ a
    326   b = One[T=float]() @ u
    327   v = NoOp() @ b
    328   c = One[T=float]() @ a, v
    329 }
    330 )P";
    331   EXPECT_EQ(DebugString(fdef), e);
    332 
    333   InstantiationResult result;
    334   TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
    335   const char* e2 = R"P(
    336 (x:float) -> () {
    337   a = One[T=float]() @ x
    338   u = NoOp() @ a
    339   b = One[T=float]() @ u
    340   v = NoOp() @ b
    341   c = One[T=float]() @ a, v
    342 }
    343 )P";
    344   EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
    345   EXPECT_EQ(result.ret_types, DataTypeVector({}));
    346   EXPECT_EQ(DebugString(result.nodes), e2);
    347 }
    348 
    349 TEST(TFunc, XTimesTwo) {
    350   auto expect = R"P(
    351 XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) {
    352   two = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
    353   scale = Cast[DstT=$T, SrcT=int64](two:output:0)
    354   y = Mul[T=$T](x, scale:y:0)
    355   return y = y:z:0
    356 }
    357 )P";
    358   EXPECT_EQ(expect, DebugString(test::function::XTimesTwo()));
    359 }
    360 
    361 TEST(TFunc, WXPlusB) {
    362   auto expect = R"P(
    363 WXPlusB[T:{float, double}](w:T, x:T, b:T) -> (y:T) {
    364   mm = MatMul[T=$T, _kernel="eigen", transpose_a=false, transpose_b=false](w, x)
    365   y = Add[T=$T](mm:product:0, b)
    366   return y = y:z:0
    367 }
    368 )P";
    369   EXPECT_EQ(expect, DebugString(test::function::WXPlusB()));
    370 }
    371 
    372 TEST(TFunc, Body_TypeList) {
    373   const Tensor kZero = test::AsScalar<int32>(0);
    374   auto fdef = FDH::Create(
    375       // Name
    376       "Test",
    377       // Args
    378       {"i:float"},
    379       // Return values
    380       {"o:float"},
    381       // Attrs
    382       {},
    383       // Nodes
    384       {{{"zero"}, "Const", {}, {{"value", kZero}, {"dtype", DT_INT32}}},
    385        {{"s"},
    386         "Split",
    387         {"zero:output:0", "i"},
    388         {{"num_split", 4}, {"T", DT_FLOAT}}},
    389        {{"l"}, "Mul", {"s:output:0", "s:output:1"}, {{"T", DT_FLOAT}}},
    390        {{"r"}, "Mul", {"s:output:2", "s:output:3"}, {{"T", DT_FLOAT}}},
    391        {{"x"},
    392         "_ListToArray",
    393         {"l:z", "r:z"},
    394         {{"N", 2},
    395          {"T", DT_FLOAT},
    396          {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}},
    397        {{"o"}, "AddN", {"x:output"}, {{"N", 2}, {"T", DT_FLOAT}}}},
    398       {{"o", "o:sum:0"}});
    399 
    400   const char* e = R"P(
    401 Test(i:float) -> (o:float) {
    402   zero = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
    403   s = Split[T=float, num_split=4](zero:output:0, i)
    404   l = Mul[T=float](s:output:0, s:output:1)
    405   r = Mul[T=float](s:output:2, s:output:3)
    406   x = _ListToArray[N=2, T=float, Tin={float, float}](l:z, r:z)
    407   o = AddN[N=2, T=float](x:output)
    408   return o = o:sum:0
    409 }
    410 )P";
    411   EXPECT_EQ(DebugString(fdef), e);
    412 
    413   InstantiationResult result;
    414   TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
    415   const char* e2 = R"P(
    416 (i:float) -> (o:float) {
    417   zero = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
    418   s = Split[T=float, num_split=4](zero, i)
    419   l = Mul[T=float](s, s:1)
    420   r = Mul[T=float](s:2, s:3)
    421   x = _ListToArray[N=2, T=float, Tin={float, float}](l, r)
    422   o = AddN[N=2, T=float](x, x:1)
    423 }
    424 )P";
    425   EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
    426   EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
    427   EXPECT_EQ(DebugString(result.nodes), e2);
    428 }
    429 
// Test-only op with list-typed input (`Tin`) and output (`out_types`) plus
// three func-valued attrs. Used to exercise type-list handling during
// instantiation (Body_Array_List_Converter and the TypeList_* error tests).
REGISTER_OP("Cond")
    .Input("input: Tin")
    .Output("output: out_types")
    .Attr("Tin: list(type)")
    .Attr("out_types: list(type)")
    .Attr("cond: func")
    .Attr("then_branch: func")
    .Attr("else_branch: func")
    .Doc(R"doc(
output = Cond(input) ? then_branch(input) : else_branch(input)

cond: A function takes 'input' and returns a scalar.
then_branch: A function takes 'input' and returns 'output'.
else_branch: A function takes 'input' and returns 'output'.
)doc");
    445 
    446 TEST(TFunc, Body_Array_List_Converter) {
    447   auto fdef = FDH::Define(
    448       // Name
    449       "MySelect",
    450       // Args
    451       {"x:float"},
    452       // Return values
    453       {"z:float"},
    454       // Attrs
    455       {},
    456       // Nodes
    457       {
    458           {{"y"},
    459            "Cond",
    460            {"x"},
    461            {{"Tin", DataTypeSlice{DT_FLOAT}},
    462             {"out_types", DataTypeSlice{DT_FLOAT}},
    463             {"cond", FDH::FunctionRef("MyCond")},
    464             {"then_branch", FDH::FunctionRef("MyThen")},
    465             {"else_branch", FDH::FunctionRef("MyElse")}}},
    466           {{"z"},
    467            "Cond",
    468            {"y", "y"},
    469            {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
    470             {"out_types", DataTypeSlice{DT_FLOAT}},
    471             {"cond", FDH::FunctionRef("MyCond2")},
    472             {"then_branch", FDH::FunctionRef("MyThen2")},
    473             {"else_branch", FDH::FunctionRef("MyElse2")}}},
    474       });
    475 
    476   const char* e = R"P(
    477 MySelect(x:float) -> (z:float) {
    478   y = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](x)
    479   z = Cond[Tin={float, float}, cond=MyCond2, else_branch=MyElse2, out_types={float}, then_branch=MyThen2](y:output:0, y:output:0)
    480   return z = z:output:0
    481 }
    482 )P";
    483   EXPECT_EQ(DebugString(fdef), e);
    484 
    485   InstantiationResult result;
    486   TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
    487   const char* e2 = R"P(
    488 (x:float) -> (z:float) {
    489   y = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](x)
    490   z = Cond[Tin={float, float}, cond=MyCond2, else_branch=MyElse2, out_types={float}, then_branch=MyThen2](y, y)
    491 }
    492 )P";
    493   EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
    494   EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
    495   EXPECT_EQ(DebugString(result.nodes), e2);
    496 }
    497 
    498 static void HasError(const Status& s, const string& substr) {
    499   EXPECT_TRUE(StringPiece(s.ToString()).contains(substr))
    500       << ">>" << s << "<<, expected substring >>" << substr << "<<";
    501 }
    502 
    503 TEST(InstantiateErrors, Not_Sufficient_Attrs) {
    504   auto fdef =
    505       FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {});
    506   InstantiationResult result;
    507   HasError(
    508       InstantiateFunction(fdef, Attrs({{"U", DT_FLOAT}}), GetOpSig, &result),
    509       "Attr T is not found from ");
    510 }
    511 
// Disabled until supplying an attr the function does not declare (here U) is
// treated as an error; see the TODO on the #if line.
#if 0  // TODO(josh11b): Enable this test once having an extra attr is an error.
TEST(InstantiateErrors, Too_Many_Attrs) {
  auto fdef =
      FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {});
  InstantiationResult result;
  HasError(InstantiateFunction(fdef, Attrs({{"T", DT_INT32}, {"U", DT_FLOAT}}),
                               GetOpSig, &result),
           "Attr U is not found in ");
}
#endif
    522 
    523 TEST(InstantiateErrors, AttrValue_Value_Placeholder) {
    524   auto fdef =
    525       FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {});
    526   InstantiationResult result;
    527   HasError(
    528       InstantiateFunction(fdef, Attrs({{"T", "$bad"}}), GetOpSig, &result),
    529       "AttrValue had value with unexpected type 'placeholder'\n\tfor attr 'T'");
    530 }
    531 
    532 TEST(InstantiateErrors, Unbounded_Attr) {
    533   auto fdef = FDH::Define("test", {}, {}, {"T:{float, double, int32, int64}"},
    534                           {
    535                               {{"a"}, "One", {}, {{"T", "$unknown"}}, {"x"}},
    536                           });
    537   InstantiationResult result;
    538   HasError(
    539       InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result),
    540       "Failed to bind all placeholders");
    541 }
    542 
    543 TEST(InstantiateErrors, DupArgs) {
    544   auto fdef = FDH::Define("test", {"x:float", "x:float"}, {}, {}, {});
    545   InstantiationResult result;
    546   HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
    547            "Duplicated arg name");
    548 }
    549 
    550 TEST(InstantiateErrors, Dup_Node_Names) {
    551   auto fdef = FDH::Define("test", {"x:float"}, {}, {},
    552                           {
    553                               {{"y"}, "One", {}, {{"T", DT_FLOAT}}},
    554                               {{"y"}, "One", {}, {{"T", DT_FLOAT}}},
    555                           });
    556   InstantiationResult result;
    557   HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
    558            "Duplicated ret name");
    559 }
    560 
    561 TEST(InstantiateErrors, Node_Arg_Notfound) {
    562   auto fdef = FDH::Create("test", {"x:float"}, {}, {},
    563                           {
    564                               {{"y"}, "Add", {"x", "z"}, {{"T", DT_FLOAT}}},
    565                           },
    566                           {});
    567   InstantiationResult result;
    568   HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
    569            "input z is not found");
    570 }
    571 
    572 TEST(InstantiateErrors, Node_Arg_TypeMismatch) {
    573   auto fdef = FDH::Define("test", {"x:float"}, {}, {},
    574                           {
    575                               {{"y"}, "Add", {"x", "x"}, {{"T", DT_INT32}}},
    576                           });
    577   InstantiationResult result;
    578   HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
    579            "input x[0] expected type int32 != float, the type of x[0]");
    580 }
    581 
    582 TEST(InstantiateErrors, Node_Arg_ControlMissing) {
    583   auto fdef =
    584       FDH::Define("test", {"x:float"}, {}, {},
    585                   {
    586                       {{"y"}, "Add", {"x", "x"}, {{"T", DT_FLOAT}}, {"z"}},
    587                   });
    588   InstantiationResult result;
    589   HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
    590            "input[2] == '^z', is not found.");
    591 }
    592 
    593 TEST(InstantiateErrors, FuncRet_Missing) {
    594   auto fdef = FDH::Create("test", {}, {"y: float"}, {},
    595                           {
    596                               {{"x"}, "One", {}, {{"T", DT_FLOAT}}},
    597                           },
    598                           {});
    599   InstantiationResult result;
    600   HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
    601            "Return y missing");
    602 }
    603 
    604 TEST(InstantiateErrors, FuncRet_NotFound) {
    605   auto fdef = FDH::Create("test", {}, {"y: float"}, {},
    606                           {
    607                               {{"x"}, "One", {}, {{"T", DT_FLOAT}}},
    608                           },
    609                           {{"y", "z"}});
    610   InstantiationResult result;
    611   HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
    612            "Return y -> z is not found");
    613 }
    614 
    615 TEST(InstantiateErrors, FuncRet_NameMismatch) {
    616   auto fdef = FDH::Create("test", {}, {"y: float"}, {},
    617                           {
    618                               {{"x"}, "One", {}, {{"T", DT_FLOAT}}},
    619                           },
    620                           {{"z", "x:y:0"}});
    621   InstantiationResult result;
    622   HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
    623            "Return y missing");
    624 }
    625 
    626 // TODO(josh11b): Make this an error.
    627 // TEST(InstantiateErrors, FuncRet_Extra) {
    628 //   auto fdef = FDH::Create("test", {}, {"y: float"}, {},
    629 //                           {
    630 //                               {{"x"}, "One", {}, {{"T", DT_FLOAT}}},
    631 //                           },
    632 //                           {{"y", "x:y:0"}, {"z", "x:y:0"}});
    633 //   InstantiationResult result;
    634 //   HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
    635 //            "ret is not found");
    636 // }
    637 
    638 TEST(InstantiateErrors, FuncRet_TypeMismatch) {
    639   auto fdef = FDH::Define("test", {}, {"y: float"}, {},
    640                           {
    641                               {{"y"}, "One", {}, {{"T", DT_DOUBLE}}},
    642                           });
    643   InstantiationResult result;
    644   HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
    645            "Invalid ret types y : float vs. double\n\tIn function output y");
    646 }
    647 
    648 TEST(InstantiateErrors, TypeList_Missing_Retval_Attr) {
    649   auto fdef = FDH::Create(
    650       // Name
    651       "MySelect",
    652       // Args
    653       {"x: float"},
    654       // Return values
    655       {"y: float"},
    656       // Attrs
    657       {},
    658       // Nodes
    659       {
    660           {{"y"},
    661            "Cond",
    662            {"x", "x"},
    663            {{"tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
    664             {"cond", FDH::FunctionRef("MyCond2")},
    665             {"then_branch", FDH::FunctionRef("MyThen2")},
    666             {"else_branch", FDH::FunctionRef("MyElse2")}}},
    667       },
    668       {{"y", "y:output"}});
    669   InstantiationResult result;
    670   HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
    671            "type attr not found: out_types");
    672 }
    673 
    674 TEST(InstantiateErrors, TypeList_Num_Retval_Mismatch) {
    675   auto fdef = FDH::Create(
    676       // Name
    677       "MySelect",
    678       // Args
    679       {"x: float"},
    680       // Return values
    681       {"y: float"},
    682       // Attrs
    683       {},
    684       // Nodes
    685       {
    686           {{"y"},
    687            "Cond",
    688            {"x", "x"},
    689            {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
    690             {"out_types", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
    691             {"cond", FDH::FunctionRef("MyCond2")},
    692             {"then_branch", FDH::FunctionRef("MyThen2")},
    693             {"else_branch", FDH::FunctionRef("MyElse2")}}},
    694       },
    695       {{"y", "y:output"}});
    696   InstantiationResult result;
    697   HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
    698            "Invalid ret types");
    699 }
    700 
    701 TEST(InstantiateErrors, TypeList_Missing_Arg) {
    702   auto fdef = FDH::Create(
    703       // Name
    704       "MySelect",
    705       // Args
    706       {"x: float"},
    707       // Return values
    708       {"y: float"},
    709       // Attrs
    710       {},
    711       // Nodes
    712       {
    713           {{"y"},
    714            "Cond",
    715            {"x", "unknown"},
    716            {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
    717             {"out_types", DataTypeSlice{DT_FLOAT}},
    718             {"cond", FDH::FunctionRef("MyCond2")},
    719             {"then_branch", FDH::FunctionRef("MyThen2")},
    720             {"else_branch", FDH::FunctionRef("MyElse2")}}},
    721       },
    722       {{"y", "y:output"}});
    723   InstantiationResult result;
    724   HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
    725            "input unknown is not found");
    726 }
    727 
    728 TEST(InstantiateErrors, TooManyInputs) {
    729   auto fdef = FDH::Create(
    730       // Name
    731       "TooManyInputs",
    732       // Inputs
    733       {"x: float", "y: float"},
    734       // Outputs
    735       {"z: float"},
    736       // Attrs
    737       {},
    738       // Nodes
    739       {// a = AddN<N=2>(x, y, x)
    740        {{"a"}, "AddN", {"x", "y", "x"}, {{"T", DT_FLOAT}, {"N", 2}}}},
    741       // Returns
    742       {{"z", "a:sum:0"}});
    743 
    744   InstantiationResult result;
    745   HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
    746            "Expected input[2] == 'x' to be a control input.");
    747 }
    748 
    749 TEST(InstantiateErrors, TooFewInputs) {
    750   auto fdef = FDH::Create(
    751       // Name
    752       "TooFewInputs",
    753       // Inputs
    754       {"x: float", "y: float"},
    755       // Outputs
    756       {"z: float"},
    757       // Attrs
    758       {},
    759       // Nodes
    760       {// a = AddN<N=3>(x, y)
    761        {{"a"}, "AddN", {"x", "y"}, {{"T", DT_FLOAT}, {"N", 3}}}},
    762       // Returns
    763       {{"z", "a:sum:0"}});
    764 
    765   InstantiationResult result;
    766   HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
    767            "Attempt to access beyond input size: 2 >= 2");
    768 }
    769 
    770 TEST(InstantiateErrors, TooManyInputsFromArray1) {
    771   auto fdef = FDH::Create(
    772       // Name
    773       "TooManyInputsFromArray",
    774       // Inputs
    775       {"x: float", "y: float"},
    776       // Outputs
    777       {"z: float"},
    778       // Attrs
    779       {},
    780       // Nodes
    781       {// a = _ListToArray(x,y)
    782        {{"a"},
    783         "_ListToArray",
    784         {"x", "y"},
    785         {{"N", 2},
    786          {"T", DT_FLOAT},
    787          {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}},
    788        // b = AddN<N=2>(a, y)
    789        {{"b"}, "AddN", {"a:output", "y"}, {{"T", DT_FLOAT}, {"N", 2}}}},
    790       // Returns
    791       {{"z", "a:sum:0"}});
    792 
    793   InstantiationResult result;
    794   HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
    795            "Expected input[1] == 'y' to be a control input.");
    796 }
    797 
    798 TEST(InstantiateErrors, TooManyInputsFromArray2) {
    799   auto fdef = FDH::Create(
    800       // Name
    801       "TooManyInputsFromArray",
    802       // Inputs
    803       {"x: float", "y: float"},
    804       // Outputs
    805       {"z: float"},
    806       // Attrs
    807       {},
    808       // Nodes
    809       {// a = _ListToArray(x,y)
    810        {{"a"},
    811         "_ListToArray",
    812         {"x", "y"},
    813         {{"N", 2},
    814          {"T", DT_FLOAT},
    815          {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}},
    816        // b = AddN<N=2>(x, a)
    817        {{"b"}, "AddN", {"x", "a:output"}, {{"T", DT_FLOAT}, {"N", 2}}}},
    818       // Returns
    819       {{"z", "a:sum:0"}});
    820 
    821   InstantiationResult result;
    822   HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
    823            "Input a:output too long for inputs");
    824 }
    825 
    826 TEST(InstantiateErrors, TypeMismatch) {
    827   auto fdef = FDH::Create(
    828       // Name
    829       "TypeMismatch",
    830       // Inputs
    831       {"x: float", "y: int32"},
    832       // Outputs
    833       {"z: float"},
    834       // Attrs
    835       {},
    836       // Nodes
    837       {// a = AddN<N=2>(x, y)
    838        {{"a"}, "AddN", {"x", "y"}, {{"T", DT_FLOAT}, {"N", 3}}}},
    839       // Returns
    840       {{"z", "a:sum:0"}});
    841 
    842   InstantiationResult result;
    843   HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
    844            "input inputs[1] expected type float != int32, the type of y[0]");
    845 }
    846 
    847 TEST(FunctionCallFrame, Void_Void) {
    848   FunctionCallFrame frame({}, {});
    849   TF_EXPECT_OK(frame.SetArgs({}));
    850   auto a = test::AsTensor<float>({100});
    851   HasError(frame.SetArgs({a}), "Invalid argument");
    852   Tensor v;
    853   HasError(frame.GetArg(0, &v), "Invalid argument");
    854   HasError(frame.SetRetval(0, v), "Invalid argument");
    855   std::vector<Tensor> rets;
    856   TF_EXPECT_OK(frame.GetRetvals(&rets));
    857   EXPECT_EQ(rets.size(), 0);
    858 }
    859 
    860 TEST(FunctionCallFrame, Float_Float_Float) {
    861   FunctionCallFrame frame({DT_FLOAT, DT_FLOAT}, {DT_FLOAT});
    862   HasError(frame.SetArgs({}), "Invalid argument: Expects 2 arguments");
    863   auto a = test::AsTensor<float>({100});
    864   auto b = test::AsTensor<float>({200});
    865   auto c = test::AsTensor<int64>({300});
    866   HasError(frame.SetArgs({a, c}),
    867            "Invalid argument: Expects arg[1] to be float");
    868   TF_EXPECT_OK(frame.SetArgs({a, b}));
    869 
    870   Tensor v;
    871   HasError(frame.GetArg(-1, &v), "Invalid argument");
    872   HasError(frame.GetArg(2, &v), "Invalid argument");
    873   TF_EXPECT_OK(frame.GetArg(0, &v));
    874   test::ExpectTensorEqual<float>(a, v);
    875   TF_EXPECT_OK(frame.GetArg(1, &v));
    876   test::ExpectTensorEqual<float>(b, v);
    877 
    878   v = test::AsTensor<float>({-100});
    879   HasError(frame.SetRetval(-1, v), "Invalid argument");
    880   HasError(frame.SetRetval(1, v), "Invalid argument");
    881   HasError(frame.SetRetval(0, test::AsTensor<int64>({-100})),
    882            "Invalid argument: Expects ret[0] to be float");
    883 
    884   std::vector<Tensor> rets;
    885   HasError(frame.GetRetvals(&rets), "does not have value");
    886   TF_EXPECT_OK(frame.SetRetval(0, v));
    887   HasError(frame.SetRetval(0, v), "has already been set");
    888 
    889   TF_EXPECT_OK(frame.GetRetvals(&rets));
    890   EXPECT_EQ(rets.size(), 1);
    891   test::ExpectTensorEqual<float>(rets[0], v);
    892 }
    893 
    894 TEST(Canonicalize, Basic) {
    895   EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_FLOAT},
    896                                           {"transpose_a", false},
    897                                           {"transpose_b", false}})),
    898             "MatMul[T=float,transpose_a=false,transpose_b=false]");
    899   EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_FLOAT},
    900                                           {"transpose_b", false},
    901                                           {"transpose_a", false}})),
    902             "MatMul[T=float,transpose_a=false,transpose_b=false]");
    903   EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_DOUBLE},
    904                                           {"transpose_b", true},
    905                                           {"transpose_a", false}})),
    906             "MatMul[T=double,transpose_a=false,transpose_b=true]");
    907 }
    908 
    909 TEST(FunctionLibraryDefinitionTest, Find) {
    910   FunctionDefLibrary proto;
    911   *proto.add_function() = test::function::XTimesTwo();
    912   FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto);
    913 
    914   EXPECT_EQ(lib_def.Find("XTimes16"), nullptr);
    915 
    916   auto expect = R"P(
    917 XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) {
    918   two = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
    919   scale = Cast[DstT=$T, SrcT=int64](two:output:0)
    920   y = Mul[T=$T](x, scale:y:0)
    921   return y = y:z:0
    922 }
    923 )P";
    924   auto found = lib_def.Find("XTimesTwo");
    925   ASSERT_NE(found, nullptr);
    926   EXPECT_EQ(expect, DebugString(*found));
    927 }
    928 
    929 TEST(FunctionLibraryDefinitionTest, LookUp) {
    930   FunctionDefLibrary proto;
    931   *proto.add_function() = test::function::XTimesTwo();
    932   FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto);
    933 
    934   const OpDef* op_def;
    935   EXPECT_TRUE(!lib_def.LookUpOpDef("XTimes16", &op_def).ok());
    936 
    937   TF_EXPECT_OK(lib_def.LookUpOpDef("XTimesTwo", &op_def));
    938   ASSERT_NE(op_def, nullptr);
    939   EXPECT_EQ(op_def->DebugString(),
    940             test::function::XTimesTwo().signature().DebugString());
    941 
    942   const OpRegistrationData* op_reg_data;
    943   TF_EXPECT_OK(lib_def.LookUp("XTimesTwo", &op_reg_data));
    944   ASSERT_NE(op_reg_data, nullptr);
    945   // Shape inference function is initialized to UnknownShape.
    946   ASSERT_NE(op_reg_data->shape_inference_fn, nullptr);
    947 }
    948 
    949 TEST(FunctionLibraryDefinitionTest, AddFunctionDef) {
    950   // Add one function to the proto lib before constructing 'lib_def'.
    951   FunctionDefLibrary proto;
    952   *proto.add_function() = test::function::XTimesTwo();
    953   FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto);
    954 
    955   // Add a new function def to the library.
    956   TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::WXPlusB()));
    957 
    958   // Test lookup of first function.
    959   const OpDef* first;
    960   TF_EXPECT_OK(lib_def.LookUpOpDef("XTimesTwo", &first));
    961   ASSERT_NE(first, nullptr);
    962   EXPECT_EQ(first->DebugString(),
    963             test::function::XTimesTwo().signature().DebugString());
    964 
    965   // Test lookup of second function.
    966   const OpDef* second;
    967   TF_EXPECT_OK(lib_def.LookUpOpDef("WXPlusB", &second));
    968   ASSERT_NE(second, nullptr);
    969   EXPECT_EQ(second->DebugString(),
    970             test::function::WXPlusB().signature().DebugString());
    971 
    972   // Can't add function with same name as existing op
    973   FunctionDef fdef = test::function::XTimesTwo();
    974   fdef.mutable_signature()->set_name("Add");
    975   Status s = lib_def.AddFunctionDef(fdef);
    976   EXPECT_FALSE(s.ok());
    977   EXPECT_EQ(s.error_message(),
    978             "Cannot add function 'Add' because an op with the same name "
    979             "already exists.");
    980 
    981   // Already-added functions don't produce error
    982   TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::XTimesTwo()));
    983   TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::WXPlusB()));
    984 }
    985 
    986 TEST(FunctionLibraryDefinitionTest, AddGradientDef) {
    987   // AddGradientDef() doesn't check that functions referenced exist (yet?)
    988   FunctionLibraryDefinition lib_def(OpRegistry::Global(), FunctionDefLibrary());
    989 
    990   // Test adding a gradient (XTimesFour isn't a valid grad function for
    991   // XTimesTwo but that's ok for now)
    992   GradientDef grad;
    993   grad.set_function_name(test::function::XTimesTwo().signature().name());
    994   grad.set_gradient_func(test::function::XTimesFour().signature().name());
    995   TF_EXPECT_OK(lib_def.AddGradientDef(grad));
    996 
    997   // Already-added gradients don't produce error
    998   TF_EXPECT_OK(lib_def.AddGradientDef(grad));
    999 
   1000   // Test that adding a duplicate gradient fails
   1001   grad.set_gradient_func(test::function::XTimes16().signature().name());
   1002   Status s = lib_def.AddGradientDef(grad);
   1003   EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT);
   1004   EXPECT_EQ(s.error_message(),
   1005             "Cannot assign gradient function 'XTimes16' to 'XTimesTwo' because "
   1006             "it already has gradient function 'XTimesFour'");
   1007 }
   1008 
   1009 TEST(FunctionLibraryDefinitionTest, AddLibrary) {
   1010   // Create lib def with single function
   1011   FunctionDefLibrary proto;
   1012   *proto.add_function() = test::function::XTimesTwo();
   1013   FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto);
   1014 
   1015   // Add gradient
   1016   GradientDef grad;
   1017   grad.set_function_name(test::function::XTimesTwo().signature().name());
   1018   grad.set_gradient_func(test::function::XTimesFour().signature().name());
   1019   TF_EXPECT_OK(lib_def.AddGradientDef(grad));
   1020 
   1021   // Error if you try to add conflicting function
   1022   proto.Clear();
   1023   FunctionDef fdef = test::function::XTimesFour();
   1024   fdef.mutable_signature()->set_name(
   1025       test::function::XTimesTwo().signature().name());
   1026   *proto.add_function() = fdef;
   1027   FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto);
   1028   Status s = lib_def.AddLibrary(lib_def2);
   1029   EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT);
   1030   EXPECT_EQ(s.error_message(),
   1031             "Cannot add function 'XTimesTwo' because a different function with "
   1032             "the same name already exists.");
   1033 
   1034   // Error if you try to add conflicting gradient
   1035   proto.Clear();
   1036   grad.set_gradient_func(test::function::XTimes16().signature().name());
   1037   *proto.add_gradient() = grad;
   1038   FunctionLibraryDefinition lib_def3(OpRegistry::Global(), proto);
   1039   s = lib_def.AddLibrary(lib_def3);
   1040   EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT);
   1041   EXPECT_EQ(s.error_message(),
   1042             "Cannot assign gradient function 'XTimes16' to 'XTimesTwo' because "
   1043             "it already has gradient function 'XTimesFour'");
   1044 
   1045   // No conflicting functions or gradients OK
   1046   proto.Clear();
   1047   *proto.add_function() = test::function::XTimesFour();
   1048   grad.set_function_name(test::function::XTimes16().signature().name());
   1049   *proto.add_gradient() = grad;
   1050   FunctionLibraryDefinition lib_def4(OpRegistry::Global(), proto);
   1051   TF_EXPECT_OK(lib_def.AddLibrary(lib_def4));
   1052 
   1053   // OK to add the same functions and gradients twice
   1054   TF_EXPECT_OK(lib_def.AddLibrary(lib_def));
   1055 }
   1056 
   1057 GradientDef MakeGradDef(const string& f, const string& g) {
   1058   GradientDef grad;
   1059   grad.set_function_name(f);
   1060   grad.set_gradient_func(g);
   1061   return grad;
   1062 }
   1063 
   1064 TEST(FunctionLibraryDefinitionTest, AddLibrary_Atomic) {
   1065   // Create lib def containing two functions with equal names
   1066   FunctionDefLibrary proto;
   1067   const string x2_name = test::function::XTimesTwo().signature().name();
   1068   const string x4_name = test::function::XTimesFour().signature().name();
   1069   *proto.add_function() = test::function::XTimesTwo();
   1070   FunctionDef fdef = test::function::XTimesFour();
   1071   fdef.mutable_signature()->set_name(x2_name);
   1072   *proto.add_function() = fdef;
   1073   FunctionLibraryDefinition lib_def(OpRegistry::Global(), FunctionDefLibrary());
   1074 
   1075   // Try adding the two functions to lib_def
   1076   Status s = lib_def.AddLibrary(proto);
   1077   EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code());
   1078   EXPECT_EQ(
   1079       "Cannot add function 'XTimesTwo' because a different function with "
   1080       "the same name already exists.",
   1081       s.error_message());
   1082 
   1083   // Verify that none of the functions are added
   1084   EXPECT_TRUE(lib_def.Find(x2_name) == nullptr);
   1085 
   1086   // Fix the name in proto but add two gradient names for it
   1087   proto.mutable_function(1)->mutable_signature()->set_name(x4_name);
   1088   *proto.add_gradient() = MakeGradDef(x2_name, x4_name);
   1089   *proto.add_gradient() = MakeGradDef(x2_name, "SecondGradName");
   1090 
   1091   // Try adding the library and check that nothing was added
   1092   s = lib_def.AddLibrary(proto);
   1093   EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code());
   1094   EXPECT_EQ(s.error_message(),
   1095             "Cannot assign gradient function 'SecondGradName' to 'XTimesTwo' "
   1096             "because it already has gradient function 'XTimesFour'");
   1097   EXPECT_TRUE(lib_def.Find(x2_name) == nullptr);
   1098   EXPECT_EQ(0, lib_def.ToProto().function_size());
   1099   EXPECT_EQ(0, lib_def.ToProto().gradient_size());
   1100 }
   1101 
   1102 TEST(FunctionLibraryDefinitionTest, AddLibraryDefinition_Atomic_FuncConflict) {
   1103   const string x2_name = test::function::XTimesTwo().signature().name();
   1104   const string x4_name = test::function::XTimesFour().signature().name();
   1105   const string wx_name = test::function::WXPlusB().signature().name();
   1106 
   1107   // Create FunctionLibraryDefinition with
   1108   // (func = XTimesTwo, grad = XTimesFour)
   1109   FunctionDefLibrary proto;
   1110   *proto.add_function() = test::function::XTimesTwo();
   1111   *proto.add_gradient() = MakeGradDef(x2_name, x4_name);
   1112   FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto);
   1113   EXPECT_EQ(1, lib_def.ToProto().function_size());
   1114   EXPECT_EQ(1, lib_def.ToProto().gradient_size());
   1115 
   1116   // Create FunctionLibraryDefinition with (func = WXPlusB, grad = XTimesTwo)
   1117   // and function (name = XTimesTwo, body = XTimeFour)
   1118   FunctionDefLibrary proto2;
   1119   *proto2.add_function() = test::function::WXPlusB();
   1120   *proto2.add_gradient() = MakeGradDef(wx_name, x2_name);
   1121   *proto2.add_function() = test::function::XTimesFour();
   1122   proto2.mutable_function(1)->mutable_signature()->set_name(x2_name);
   1123   FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto2);
   1124 
   1125   // Verify that adding lib_def2 will fail because of function conflict
   1126   // and WXPlusB is not added.
   1127   Status s = lib_def.AddLibrary(lib_def2);
   1128   EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code());
   1129   EXPECT_EQ(
   1130       "Cannot add function 'XTimesTwo' because a different function "
   1131       "with the same name already exists.",
   1132       s.error_message());
   1133   EXPECT_TRUE(lib_def.Find(wx_name) == nullptr);
   1134   EXPECT_EQ(1, lib_def.ToProto().function_size());
   1135   EXPECT_EQ(1, lib_def.ToProto().gradient_size());
   1136 }
   1137 
   1138 TEST(FunctionLibraryDefinitionTest, AddLibraryDefinition_Atomic_GradConflict) {
   1139   const string x2_name = test::function::XTimesTwo().signature().name();
   1140   const string x4_name = test::function::XTimesFour().signature().name();
   1141   const string wx_name = test::function::WXPlusB().signature().name();
   1142 
   1143   // Create FunctionLibraryDefinition with
   1144   // (func = XTimesTwo, grad = XTimesFour)
   1145   FunctionDefLibrary proto;
   1146   *proto.add_function() = test::function::XTimesTwo();
   1147   *proto.add_gradient() = MakeGradDef(x2_name, x4_name);
   1148   FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto);
   1149   EXPECT_EQ(1, lib_def.ToProto().function_size());
   1150   EXPECT_EQ(1, lib_def.ToProto().gradient_size());
   1151 
   1152   // Create FunctionLibraryDefinition with (func = WXPlusB, grad = XTimesTwo)
   1153   // and (func = XTimesTwo, grad = WXPlusB)
   1154   FunctionDefLibrary proto2;
   1155   *proto2.add_function() = test::function::WXPlusB();
   1156   *proto2.add_gradient() = MakeGradDef(wx_name, x2_name);
   1157   *proto2.add_function() = test::function::XTimesTwo();
   1158   *proto2.add_gradient() = MakeGradDef(x2_name, wx_name);
   1159   FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto2);
   1160 
   1161   // Verify that adding lib_def2 will fail because of gradient conflict
   1162   // and WXPlusB is not added.
   1163   Status s = lib_def.AddLibrary(lib_def2);
   1164   EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code());
   1165   EXPECT_EQ(
   1166       "Cannot assign gradient function 'WXPlusB' to 'XTimesTwo'"
   1167       " because it already has gradient function 'XTimesFour'",
   1168       s.error_message());
   1169   EXPECT_TRUE(lib_def.Find(wx_name) == nullptr);
   1170   EXPECT_EQ(1, lib_def.ToProto().function_size());
   1171   EXPECT_EQ(1, lib_def.ToProto().gradient_size());
   1172 }
   1173 
   1174 TEST(FunctionLibraryDefinitionTest, ToProto) {
   1175   FunctionDefLibrary proto1;
   1176   *proto1.add_function() = test::function::XTimesTwo();
   1177   *proto1.add_function() = test::function::WXPlusB();
   1178   FunctionLibraryDefinition lib_def1(OpRegistry::Global(), proto1);
   1179 
   1180   // Call 'ToProto' and make sure both protos have the same function lib size.
   1181   FunctionDefLibrary proto2 = lib_def1.ToProto();
   1182   EXPECT_EQ(proto1.function_size(), proto2.function_size());
   1183 
   1184   // Initialize 'lib_def2' with proto returned by 'ToProto' call.
   1185   FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto2);
   1186 
   1187   // Test that the first function exists in both libraries.
   1188   const OpDef *f1, *f2, *f3, *f4;
   1189   TF_EXPECT_OK(lib_def1.LookUpOpDef("XTimesTwo", &f1));
   1190   TF_EXPECT_OK(lib_def2.LookUpOpDef("XTimesTwo", &f2));
   1191   EXPECT_EQ(f1->DebugString(), f2->DebugString());
   1192 
   1193   // Test that the second function exists in both libraries.
   1194   TF_EXPECT_OK(lib_def1.LookUpOpDef("WXPlusB", &f3));
   1195   TF_EXPECT_OK(lib_def2.LookUpOpDef("WXPlusB", &f4));
   1196   EXPECT_EQ(f3->DebugString(), f4->DebugString());
   1197 }
   1198 
   1199 TEST(FunctionLibraryDefinitionTest, GetAttr_FuncNoAttr) {
   1200   FunctionDefLibrary proto;
   1201   *proto.add_function() = test::function::XTimesTwo();
   1202   FunctionLibraryDefinition lib(OpRegistry::Global(), proto);
   1203 
   1204   NodeDef ndef;
   1205   bool annotation;
   1206 
   1207   // Not a function.
   1208   ndef.set_op("Matmul");
   1209   EXPECT_FALSE(lib.GetAttr(ndef, "annotation", &annotation).ok());
   1210 
   1211   // A function. No attr defined.
   1212   ndef.set_op("XTimesTwo");
   1213   EXPECT_FALSE(lib.GetAttr(ndef, "annotation", &annotation).ok());
   1214 
   1215   // ndef defines the attr. But we don't care.
   1216   AddNodeAttr("annotation", true, &ndef);
   1217   EXPECT_FALSE(lib.GetAttr(ndef, "annotation", &annotation).ok());
   1218 }
   1219 
   1220 template <typename T>
   1221 void SetAttrValue(FunctionDef* fdef, const string& attr, const T& value) {
   1222   AttrValue attr_value;
   1223   SetAttrValue(value, &attr_value);
   1224   fdef->mutable_attr()->insert({attr, attr_value});
   1225 }
   1226 
   1227 TEST(FunctionLibraryDefinitionTest, GetAttr_FuncWithAttr) {
   1228   FunctionDefLibrary proto;
   1229   auto fdef = proto.add_function();
   1230   *fdef = test::function::XTimesTwo();
   1231   SetAttrValue(fdef, "annotation", true);
   1232   SetAttrValue(fdef, "options", "some string data");
   1233   FunctionLibraryDefinition lib(OpRegistry::Global(), proto);
   1234 
   1235   NodeDef ndef;
   1236   bool annotation;
   1237 
   1238   // A function. No attr defined in ndef.
   1239   ndef.set_op("XTimesTwo");
   1240   TF_EXPECT_OK(lib.GetAttr(ndef, "annotation", &annotation));
   1241   EXPECT_EQ(annotation, true);
   1242 
   1243   string str;
   1244   TF_EXPECT_OK(lib.GetAttr(ndef, "options", &str));
   1245   EXPECT_EQ(str, "some string data");
   1246 }
   1247 
   1248 TEST(FunctionLibraryDefinitionTest, GetAttr_Gradient) {
   1249   FunctionDefLibrary proto;
   1250   auto fdef = proto.add_function();
   1251   *fdef = test::function::XTimesTwo();
   1252   SetAttrValue(fdef, "annotation", true);
   1253   *fdef = test::function::WXPlusB();
   1254   SetAttrValue(fdef, "annotation", false);
   1255   auto func_grad = proto.add_gradient();
   1256   func_grad->set_function_name("XTimesTwo");
   1257   func_grad->set_gradient_func("WXPlusB");
   1258   FunctionLibraryDefinition lib(OpRegistry::Global(), proto);
   1259 
   1260   NodeDef ndef;
   1261   ndef.set_op(FunctionLibraryDefinition::kGradientOp);
   1262 
   1263   bool annotation;
   1264   EXPECT_FALSE(lib.GetAttr(ndef, "annotation", &annotation).ok());
   1265 
   1266   NameAttrList nal;
   1267   nal.set_name("XTimesTwo");
   1268   AddNodeAttr(FunctionLibraryDefinition::kFuncAttr, nal, &ndef);
   1269   TF_EXPECT_OK(lib.GetAttr(ndef, "annotation", &annotation));
   1270   EXPECT_EQ(annotation, false);  // XTimesTwo's gradient is WXPlusB.
   1271 
   1272   nal.set_name("WXPlusB");
   1273   ndef.clear_attr();
   1274   AddNodeAttr(FunctionLibraryDefinition::kFuncAttr, nal, &ndef);
   1275   TF_EXPECT_OK(lib.GetAttr(ndef, "annotation", &annotation));
   1276   EXPECT_EQ(annotation, false);  // WXPlusB has no custom gradient.
   1277 }
   1278 
   1279 // TODO(skyewm): this could be more thorough
   1280 TEST(FunctionDefsEqualTest, TestFunctionDefsEqual) {
   1281   // Equal functions
   1282   const FunctionDef fdef1 = test::function::XTimesTwo();
   1283   FunctionDef fdef2 = test::function::XTimesTwo();
   1284   uint64 hash1 = FunctionDefHash(fdef1);
   1285   EXPECT_TRUE(FunctionDefsEqual(fdef1, fdef2));
   1286   EXPECT_EQ(hash1, FunctionDefHash(fdef2));
   1287 
   1288   // Different functions
   1289   fdef2 = test::function::XTimesFour();
   1290   EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2));
   1291   EXPECT_NE(hash1, FunctionDefHash(fdef2));
   1292 
   1293   // Different signatures
   1294   fdef2 = test::function::XTimesTwo();
   1295   fdef2.mutable_signature()->mutable_input_arg(0)->set_name("foo");
   1296   EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2));
   1297   EXPECT_NE(hash1, FunctionDefHash(fdef2));
   1298 
   1299   // Descriptions must be equal
   1300   fdef2 = test::function::XTimesTwo();
   1301   fdef2.mutable_signature()->mutable_input_arg(0)->set_description("foo");
   1302   EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2));
   1303   EXPECT_NE(hash1, FunctionDefHash(fdef2));
   1304 
   1305   // Different NodeDefs
   1306   fdef2 = test::function::XTimesTwo();
   1307   NodeDef* ndef = fdef2.add_node_def();
   1308   *ndef = fdef2.node_def(0);
   1309   ndef->set_name("new_name");
   1310   EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2));
   1311   EXPECT_NE(hash1, FunctionDefHash(fdef2));
   1312 
   1313   // Different return values
   1314   fdef2 = test::function::XTimesTwo();
   1315   (*fdef2.mutable_ret())["y"] = "y:z:1";  // originally is "y:z:0"
   1316   EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2));
   1317   EXPECT_NE(hash1, FunctionDefHash(fdef2));
   1318 
   1319   // Different attributes
   1320   fdef2 = test::function::XTimesTwo();
   1321   SetAttrValue(&fdef2, "ExtraAttr", true);
   1322   EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2));
   1323   EXPECT_NE(hash1, FunctionDefHash(fdef2));
   1324 
   1325   // Multiple equivalent attributes; the two functions should be equal.
   1326   fdef2 = test::function::XTimesTwo();
   1327   FunctionDef fdef3 = test::function::XTimesTwo();
   1328   SetAttrValue(&fdef2, "Foo", true);
   1329   SetAttrValue(&fdef3, "Foo", true);
   1330   SetAttrValue(&fdef2, "Bar", 123);
   1331   SetAttrValue(&fdef3, "Bar", 123);
   1332   SetAttrValue(&fdef2, "Baz", "abc");
   1333   SetAttrValue(&fdef3, "Baz", "abc");
   1334   EXPECT_TRUE(FunctionDefsEqual(fdef2, fdef3));
   1335   EXPECT_EQ(FunctionDefHash(fdef2), FunctionDefHash(fdef3));
   1336 }
   1337 
   1338 }  // end namespace
   1339 }  // end namespace tensorflow
   1340