1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/function.h" 17 #include <vector> 18 #include "tensorflow/core/framework/function.pb.h" 19 #include "tensorflow/core/framework/function_testlib.h" 20 #include "tensorflow/core/framework/op.h" 21 #include "tensorflow/core/framework/tensor_testutil.h" 22 #include "tensorflow/core/kernels/ops_util.h" 23 #include "tensorflow/core/lib/core/status_test_util.h" 24 #include "tensorflow/core/lib/gtl/array_slice.h" 25 #include "tensorflow/core/lib/strings/str_util.h" 26 #include "tensorflow/core/lib/strings/strcat.h" 27 #include "tensorflow/core/platform/test.h" 28 #include "tensorflow/core/platform/types.h" 29 30 namespace tensorflow { 31 namespace { 32 33 // A helper class to make AttrSlice from initializer lists 34 class Attrs { 35 public: 36 Attrs(const std::initializer_list< // NOLINT(runtime/explicit) 37 std::pair<string, FunctionDefHelper::AttrValueWrapper>> 38 attrs) { 39 for (const auto& aval : attrs) { 40 map_.insert({aval.first, aval.second.proto}); 41 } 42 } 43 44 operator AttrSlice() { return AttrSlice(&map_); } // NOLINT(runtime/explicit) 45 46 private: 47 AttrValueMap map_; 48 }; 49 50 typedef FunctionDefHelper FDH; 51 52 Status GetOpSig(const string& op, const OpDef** sig) { 53 return OpRegistry::Global()->LookUpOpDef(op, sig); 54 } 55 
// Test-only op: returns a scalar 1 of type T.
REGISTER_OP("One")
    .Output("y: T")
    .Attr("T: {float, double, int32, int64}")
    .Doc(R"doc(
Returns a tensor with a single element (1) of type T.

y: A scalar in type T.

)doc");

// Builds a type-polymorphic function y = Square(x) + One() with FDH::Create,
// checks its DebugString, then instantiates it with T=float and checks the
// resolved arg/ret types and the instantiated node list.
TEST(TFunc, SquarePlusOne) {
  auto fdef = FDH::Create(
      // Name
      "SquarePlusOne",
      // Inputs
      {"x: T"},
      // Outputs
      {"y: T"},
      // Attrs
      {"T: {float, double, int32, int64}"},
      // Nodes
      {// a = Square<T>(x)
       {{"a"}, "Square", {"x"}, {{"T", "$T"}}},
       // o = One<T>()
       // NOTE: We can also have a Cast<Tin, Tout>(x) instead.
       {{"o"}, "One", {}, {{"T", "$T"}}},
       // y = Add<T>(a, o)
       {{"y"}, "Add", {"a:y", "o:y"}, {{"T", "$T"}}}},
      // Returns
      {{"y", "y:z:0"}});

  const char* e = R"P(
SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
  a = Square[T=$T](x)
  o = One[T=$T]()
  y = Add[T=$T](a:y, o:y)
  return y = y:z:0
}
)P";
  EXPECT_EQ(DebugString(fdef), e);

  // Instantiate one with T=float
  InstantiationResult result;
  TF_ASSERT_OK(
      InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result));
  // After instantiation, $T is replaced by float and "node:output:index"
  // references are printed in the short "node" form.
  const char* e2 = R"P(
(x:float) -> (y:float) {
  a = Square[T=float](x)
  o = One[T=float]()
  y = Add[T=float](a, o)
}
)P";
  EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
  EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
  EXPECT_EQ(DebugString(result.nodes), e2);
}

// Checks that "^name" inputs are kept as control dependencies (printed as
// "@ name") both in the FunctionDef and in the instantiated nodes.
TEST(TFunc, ControlDep) {
  auto fdef = FDH::Create(
      // Name
      "ControlDep",
      // Inputs
      {"x: int32"},
      // Outputs
      {"y: int32"},
      // Attrs
      {},
      // Nodes
      {// a = Identity<int32>(x)
       {{"a"}, "Identity", {"x"}, {{"T", DT_INT32}}},
       // o = NoOp(^a)
       {{"o"}, "NoOp", {"^a"}, {}},
       // y = Identity<int32>(a, ^o)
       {{"y"}, "Identity", {"a:output:0", "^o"}, {{"T", DT_INT32}}}},
      // Returns
      {{"y", "y:output:0"}});

  const char* e = R"P(
ControlDep(x:int32) -> (y:int32) {
  a = Identity[T=int32](x)
  o = NoOp() @ a
  y = Identity[T=int32](a:output:0) @ o
  return y = y:output:0
}
)P";
  EXPECT_EQ(DebugString(fdef), e);

  // Instantiate. ControlDep declares no attrs and every node uses a literal
  // int32, so the T=float passed here is not referenced by anything (extra
  // attrs are not rejected yet -- see the disabled Too_Many_Attrs test).
  InstantiationResult result;
  TF_ASSERT_OK(
      InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result));
  const char* e2 = R"P(
(x:int32) -> (y:int32) {
  a = Identity[T=int32](x)
  o = NoOp() @ a
  y = Identity[T=int32](a) @ o
}
)P";
  EXPECT_EQ(result.arg_types, DataTypeVector({DT_INT32}));
  EXPECT_EQ(result.ret_types, DataTypeVector({DT_INT32}));
  EXPECT_EQ(DebugString(result.nodes), e2);
}

// Test-only op whose type attr T has a default (float).
REGISTER_OP("HasDefaultType")
    .Output("out: T")
    .Attr("T: {float, double, int32, int64} = DT_FLOAT");

// This verifies that a function using an op before a type attr (with
// a default) is added, still works. This is important for backwards
// compatibility.
TEST(TFunc, MissingTypeAttr) {
  auto fdef = FDH::Create(
      // Name
      "BackCompat",
      // Args
      {},
      // Return values
      {"y: float"},
      // Attrs
      {},
      // Nodes
      {// a = HasDefaultType(); T is omitted and should default to float
       {{"a"}, "HasDefaultType", {}, {}}},
      // Returns
      {{"y", "a:out:0"}});

  const char* e = R"P(
BackCompat() -> (y:float) {
  a = HasDefaultType()
  return y = a:out:0
}
)P";
  EXPECT_EQ(DebugString(fdef), e);

  InstantiationResult result;
  TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
  // Should get T=float from Op's default.
  const char* e2 = R"P(
() -> (a:float) {
  a = HasDefaultType[T=float]()
}
)P";
  EXPECT_EQ(result.arg_types, DataTypeVector());
  EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
  EXPECT_EQ(DebugString(result.nodes), e2);
}

// AddN over two explicit float args with literal attrs. Attrs are printed in
// name order (N before T) regardless of the order they were given in.
// NOTE(review): the instantiated header prints the return as "a:float", i.e.
// apparently under the producing node's name rather than "z".
TEST(TFunc, NTimesT) {
  auto fdef = FDH::Create(
      // Name
      "NTimesT",
      // Inputs
      {"x: float", "y: float"},
      // Outputs
      {"z: float"},
      // Attrs
      {},
      // Nodes
      {// a = AddN<N=2>(x, y)
       {{"a"}, "AddN", {"x", "y"}, {{"T", DT_FLOAT}, {"N", 2}}}},
      // Returns
      {{"z", "a:sum:0"}});

  const char* e = R"P(
NTimesT(x:float, y:float) -> (z:float) {
  a = AddN[N=2, T=float](x, y)
  return z = a:sum:0
}
)P";
  EXPECT_EQ(DebugString(fdef), e);

  InstantiationResult result;
  TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
  const char* e2 = R"P(
(x:float, y:float) -> (a:float) {
  a = AddN[N=2, T=float](x, y)
}
)P";
  EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT, DT_FLOAT}));
  EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
  EXPECT_EQ(DebugString(result.nodes), e2);
}

// NOTE: This is the simplest Map op. It takes a f:T->U.
// The "func" attr declaration is commented out below, so the op itself does
// not declare the "func" attr that AddSquared sets on its Map node.
REGISTER_OP("Map")
    .Input("x: N * T")
    .Output("y: N * U")
    .Attr("T: type")
    .Attr("U: type")
    .Attr("N: int >= 1")
    // .Attr("func: func_name_with_attr")
    .Doc(R"doc(
Applies the 'func' on every input. I.e.,

y[i] = func<...>(x[i])

x: N tensors, each of type T;
y: N tensors, each of type U;

)doc");

// A list-typed arg (x: N*T) feeding list-typed ops. After instantiation with
// N=3 the arg is expanded to x_0, x_1, x_2, and the list output of `a` is
// referenced as a, a:1, a:2.
TEST(TFunc, AddSquared) {
  auto fdef = FDH::Create(
      // Name
      "AddSquared",
      // Args
      {"x: N*T"},
      // Return values
      {"y: T"},
      // Attrs
      {"N:int", "T:{float, double, int32, int64}"},
      // Nodes
      {// a = Map<func=Square<$T>,T=$T,U=$T,N=$N>(x)
       {{"a"},
        "Map",
        {"x"},
        {{"func", FDH::FunctionRef("Square", {{"T", "$T"}})},
         {"T", "$T"},
         {"U", "$T"},
         {"N", "$N"}}},
       // y = AddN<N=$N,T=$T>(a)
       {{"y"}, "AddN", {"a:y"}, {{"N", "$N"}, {"T", "$T"}}}},
      {{"y", "y:sum"}});

  const char* e = R"P(
AddSquared[N:int, T:{float, double, int32, int64}](x:N*T) -> (y:T) {
  a = Map[N=$N, T=$T, U=$T, func=Square[T=$T]](x)
  y = AddN[N=$N, T=$T](a:y)
  return y = y:sum
}
)P";
  EXPECT_EQ(DebugString(fdef), e);

  // Instantiate one with N=3, T=float
  InstantiationResult result;
  TF_ASSERT_OK(InstantiateFunction(fdef, Attrs({{"N", 3}, {"T", DT_FLOAT}}),
                                   GetOpSig, &result));
  const char* e2 = R"P(
(x_0:float, x_1:float, x_2:float) -> (y:float) {
  a = Map[N=3, T=float, U=float, func=Square[T=float]](x_0, x_1, x_2)
  y = AddN[N=3, T=float](a, a:1, a:2)
}
)P";
  EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT, DT_FLOAT, DT_FLOAT}));
  EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
  EXPECT_EQ(DebugString(result.nodes), e2);
}

// Uses FDH::Define's trailing dependency list to add control inputs,
// including one on a function argument (a @ x) and a node with several
// dependencies (c @ a, v). The function has no return values.
TEST(TFunc, ControlDeps) {
  auto fdef = FDH::Define(
      // Name
      "ControlDeps",
      // Args
      {"x: float"},
      // Return values
      {},
      // Attrs
      {},
      // Nodes
      {
          {{"a"}, "One", {}, {{"T", DT_FLOAT}}, {"x"}},
          {{"u"}, "NoOp", {}, {}, {"a"}},
          {{"b"}, "One", {}, {{"T", DT_FLOAT}}, {"u"}},
          {{"v"}, "NoOp", {}, {}, {"b"}},
          {{"c"}, "One", {}, {{"T", DT_FLOAT}}, {"a", "v"}},
      });
  const char* e = R"P(
ControlDeps(x:float) -> () {
  a = One[T=float]() @ x
  u = NoOp() @ a
  b = One[T=float]() @ u
  v = NoOp() @ b
  c = One[T=float]() @ a, v
}
)P";
  EXPECT_EQ(DebugString(fdef), e);

  InstantiationResult result;
  TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
  const char* e2 = R"P(
(x:float) -> () {
  a = One[T=float]() @ x
  u = NoOp() @ a
  b = One[T=float]() @ u
  v = NoOp() @ b
  c = One[T=float]() @ a, v
}
)P";
  EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
  EXPECT_EQ(result.ret_types, DataTypeVector({}));
  EXPECT_EQ(DebugString(result.nodes), e2);
}

// Pins the DebugString of the shared XTimesTwo test function from
// function_testlib.
TEST(TFunc, XTimesTwo) {
  auto expect = R"P(
XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) {
  two = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
  scale = Cast[DstT=$T, SrcT=int64](two:output:0)
  y = Mul[T=$T](x, scale:y:0)
  return y = y:z:0
}
)P";
  EXPECT_EQ(expect, DebugString(test::function::XTimesTwo()));
}

// Pins the DebugString of the shared WXPlusB test function from
// function_testlib.
TEST(TFunc, WXPlusB) {
  auto expect = R"P(
WXPlusB[T:{float, double}](w:T, x:T, b:T) -> (y:T) {
  mm = MatMul[T=$T, _kernel="eigen", transpose_a=false, transpose_b=false](w, x)
  y = Add[T=$T](mm:product:0, b)
  return y = y:z:0
}
)P";
  EXPECT_EQ(expect, DebugString(test::function::WXPlusB()));
}

// Mixes a multi-output op (Split), _ListToArray, and a list-input op (AddN).
// After instantiation, output references such as s:output:1 and x:output are
// printed as s:1 and as the expanded list x, x:1.
TEST(TFunc, Body_TypeList) {
  const Tensor kZero = test::AsScalar<int32>(0);
  auto fdef = FDH::Create(
      // Name
      "Test",
      // Args
      {"i:float"},
      // Return values
      {"o:float"},
      // Attrs
      {},
      // Nodes
      {{{"zero"}, "Const", {}, {{"value", kZero}, {"dtype", DT_INT32}}},
       {{"s"},
        "Split",
        {"zero:output:0", "i"},
        {{"num_split", 4}, {"T", DT_FLOAT}}},
       {{"l"}, "Mul", {"s:output:0", "s:output:1"}, {{"T", DT_FLOAT}}},
       {{"r"}, "Mul", {"s:output:2", "s:output:3"}, {{"T", DT_FLOAT}}},
       {{"x"},
        "_ListToArray",
        {"l:z", "r:z"},
        {{"N", 2},
         {"T", DT_FLOAT},
         {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}},
       {{"o"}, "AddN", {"x:output"}, {{"N", 2}, {"T", DT_FLOAT}}}},
      {{"o", "o:sum:0"}});

  const char* e = R"P(
Test(i:float) -> (o:float) {
  zero = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
  s = Split[T=float, num_split=4](zero:output:0, i)
  l = Mul[T=float](s:output:0, s:output:1)
  r = Mul[T=float](s:output:2, s:output:3)
  x = _ListToArray[N=2, T=float, Tin={float, float}](l:z, r:z)
  o = AddN[N=2, T=float](x:output)
  return o = o:sum:0
}
)P";
  EXPECT_EQ(DebugString(fdef), e);

  InstantiationResult result;
  TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
  const char* e2 = R"P(
(i:float) -> (o:float) {
  zero = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
  s = Split[T=float, num_split=4](zero, i)
  l = Mul[T=float](s, s:1)
  r = Mul[T=float](s:2, s:3)
  x = _ListToArray[N=2, T=float, Tin={float, float}](l, r)
  o = AddN[N=2, T=float](x, x:1)
}
)P";
  EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
  EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
  EXPECT_EQ(DebugString(result.nodes), e2);
}

// Test-only op with list(type) inputs/outputs and function-valued attrs.
REGISTER_OP("Cond")
    .Input("input: Tin")
    .Output("output: out_types")
    .Attr("Tin: list(type)")
    .Attr("out_types: list(type)")
    .Attr("cond: func")
    .Attr("then_branch: func")
    .Attr("else_branch: func")
    .Doc(R"doc(
output = Cond(input) ? then_branch(input) : else_branch(input)

cond: A function takes 'input' and returns a scalar.
then_branch: A function takes 'input' and returns 'output'.
else_branch: A function takes 'input' and returns 'output'.
)doc");

// Cond nodes with list(type) Tin/out_types chained together. FDH::Define
// expands the bare input "y" to y:output:0, and after instantiation it is
// printed as plain y.
TEST(TFunc, Body_Array_List_Converter) {
  auto fdef = FDH::Define(
      // Name
      "MySelect",
      // Args
      {"x:float"},
      // Return values
      {"z:float"},
      // Attrs
      {},
      // Nodes
      {
          {{"y"},
           "Cond",
           {"x"},
           {{"Tin", DataTypeSlice{DT_FLOAT}},
            {"out_types", DataTypeSlice{DT_FLOAT}},
            {"cond", FDH::FunctionRef("MyCond")},
            {"then_branch", FDH::FunctionRef("MyThen")},
            {"else_branch", FDH::FunctionRef("MyElse")}}},
          {{"z"},
           "Cond",
           {"y", "y"},
           {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
            {"out_types", DataTypeSlice{DT_FLOAT}},
            {"cond", FDH::FunctionRef("MyCond2")},
            {"then_branch", FDH::FunctionRef("MyThen2")},
            {"else_branch", FDH::FunctionRef("MyElse2")}}},
      });

  const char* e = R"P(
MySelect(x:float) -> (z:float) {
  y = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](x)
  z = Cond[Tin={float, float}, cond=MyCond2, else_branch=MyElse2, out_types={float}, then_branch=MyThen2](y:output:0, y:output:0)
  return z = z:output:0
}
)P";
  EXPECT_EQ(DebugString(fdef), e);

  InstantiationResult result;
  TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
  const char* e2 = R"P(
(x:float) -> (z:float) {
  y = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](x)
  z = Cond[Tin={float, float}, cond=MyCond2, else_branch=MyElse2, out_types={float}, then_branch=MyThen2](y, y)
}
)P";
  EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
  EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
  EXPECT_EQ(DebugString(result.nodes), e2);
}

// Expects the string form of `s` to contain `substr`. On mismatch, the
// failure message shows both the full status and the expected substring.
// Note: it does not separately assert !s.ok(); it relies only on the text.
static void HasError(const Status& s, const string& substr) {
  EXPECT_TRUE(StringPiece(s.ToString()).contains(substr))
      << ">>" << s << "<<, expected substring >>" << substr << "<<";
}

// The function declares attr T, but only an unrelated attr U is supplied.
TEST(InstantiateErrors, Not_Sufficient_Attrs) {
  auto fdef =
      FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {});
  InstantiationResult result;
  HasError(
      InstantiateFunction(fdef, Attrs({{"U", DT_FLOAT}}), GetOpSig, &result),
      "Attr T is not found from ");
}

#if 0  // TODO(josh11b): Enable this test once having an extra attr is an error.
TEST(InstantiateErrors, Too_Many_Attrs) {
  auto fdef =
      FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {});
  InstantiationResult result;
  HasError(InstantiateFunction(fdef, Attrs({{"T", DT_INT32}, {"U", DT_FLOAT}}),
                               GetOpSig, &result),
           "Attr U is not found in ");
}
#endif

// A placeholder ("$bad") is not accepted as the concrete value of attr T.
TEST(InstantiateErrors, AttrValue_Value_Placeholder) {
  auto fdef =
      FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {});
  InstantiationResult result;
  HasError(
      InstantiateFunction(fdef, Attrs({{"T", "$bad"}}), GetOpSig, &result),
      "AttrValue had value with unexpected type 'placeholder'\n\tfor attr 'T'");
}

// A node attr refers to placeholder $unknown, which no supplied attr binds.
TEST(InstantiateErrors, Unbounded_Attr) {
  auto fdef = FDH::Define("test", {}, {}, {"T:{float, double, int32, int64}"},
                          {
                              {{"a"}, "One", {}, {{"T", "$unknown"}}, {"x"}},
                          });
  InstantiationResult result;
  HasError(
      InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result),
      "Failed to bind all placeholders");
}

// Two function args share the name x.
TEST(InstantiateErrors, DupArgs) {
  auto fdef = FDH::Define("test", {"x:float", "x:float"}, {}, {}, {});
  InstantiationResult result;
  HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
           "Duplicated arg name");
}

// Two nodes share the name y.
// NOTE(review): the expected text says "ret name" rather than "node name";
// this pins the message as currently emitted by InstantiateFunction.
TEST(InstantiateErrors, Dup_Node_Names) {
  auto fdef = FDH::Define("test", {"x:float"}, {}, {},
                          {
                              {{"y"}, "One", {}, {{"T", DT_FLOAT}}},
                              {{"y"}, "One", {}, {{"T", DT_FLOAT}}},
                          });
  InstantiationResult result;
  HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
           "Duplicated ret name");
}

// Node input z names neither a function arg nor a node.
TEST(InstantiateErrors, Node_Arg_Notfound) {
  auto fdef = FDH::Create("test", {"x:float"}, {}, {},
                          {
                              {{"y"}, "Add", {"x", "z"}, {{"T", DT_FLOAT}}},
                          },
                          {});
  InstantiationResult result;
  HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
           "input z is not found");
}

// Add is instantiated with T=int32 but is fed the float arg x.
TEST(InstantiateErrors, Node_Arg_TypeMismatch) {
  auto fdef = FDH::Define("test", {"x:float"}, {}, {},
                          {
                              {{"y"}, "Add", {"x", "x"}, {{"T", DT_INT32}}},
                          });
  InstantiationResult result;
  HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
           "input x[0] expected type int32 != float, the type of x[0]");
}

// The node has a control dependency on z, which does not exist.
TEST(InstantiateErrors, Node_Arg_ControlMissing) {
  auto fdef =
      FDH::Define("test", {"x:float"}, {}, {},
                  {
                      {{"y"}, "Add", {"x", "x"}, {{"T", DT_FLOAT}}, {"z"}},
                  });
  InstantiationResult result;
  HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
           "input[2] == '^z', is not found.");
}

// The signature declares output y, but the ret map is empty.
TEST(InstantiateErrors, FuncRet_Missing) {
  auto fdef = FDH::Create("test", {}, {"y: float"}, {},
                          {
                              {{"x"}, "One", {}, {{"T", DT_FLOAT}}},
                          },
                          {});
  InstantiationResult result;
  HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
           "Return y missing");
}

// Output y is mapped to z, which is not a node output in this function.
TEST(InstantiateErrors, FuncRet_NotFound) {
  auto fdef = FDH::Create("test", {}, {"y: float"}, {},
                          {
                              {{"x"}, "One", {}, {{"T", DT_FLOAT}}},
                          },
                          {{"y", "z"}});
  InstantiationResult result;
  HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
           "Return y -> z is not found");
}

// The ret map only has key z, so the declared output y is reported missing.
TEST(InstantiateErrors, FuncRet_NameMismatch) {
  auto fdef = FDH::Create("test", {}, {"y: float"}, {},
                          {
                              {{"x"}, "One", {}, {{"T", DT_FLOAT}}},
                          },
                          {{"z", "x:y:0"}});
  InstantiationResult result;
  HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
           "Return y missing");
}

// TODO(josh11b): Make this an error.
627 // TEST(InstantiateErrors, FuncRet_Extra) { 628 // auto fdef = FDH::Create("test", {}, {"y: float"}, {}, 629 // { 630 // {{"x"}, "One", {}, {{"T", DT_FLOAT}}}, 631 // }, 632 // {{"y", "x:y:0"}, {"z", "x:y:0"}}); 633 // InstantiationResult result; 634 // HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), 635 // "ret is not found"); 636 // } 637 638 TEST(InstantiateErrors, FuncRet_TypeMismatch) { 639 auto fdef = FDH::Define("test", {}, {"y: float"}, {}, 640 { 641 {{"y"}, "One", {}, {{"T", DT_DOUBLE}}}, 642 }); 643 InstantiationResult result; 644 HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), 645 "Invalid ret types y : float vs. double\n\tIn function output y"); 646 } 647 648 TEST(InstantiateErrors, TypeList_Missing_Retval_Attr) { 649 auto fdef = FDH::Create( 650 // Name 651 "MySelect", 652 // Args 653 {"x: float"}, 654 // Return values 655 {"y: float"}, 656 // Attrs 657 {}, 658 // Nodes 659 { 660 {{"y"}, 661 "Cond", 662 {"x", "x"}, 663 {{"tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}, 664 {"cond", FDH::FunctionRef("MyCond2")}, 665 {"then_branch", FDH::FunctionRef("MyThen2")}, 666 {"else_branch", FDH::FunctionRef("MyElse2")}}}, 667 }, 668 {{"y", "y:output"}}); 669 InstantiationResult result; 670 HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), 671 "type attr not found: out_types"); 672 } 673 674 TEST(InstantiateErrors, TypeList_Num_Retval_Mismatch) { 675 auto fdef = FDH::Create( 676 // Name 677 "MySelect", 678 // Args 679 {"x: float"}, 680 // Return values 681 {"y: float"}, 682 // Attrs 683 {}, 684 // Nodes 685 { 686 {{"y"}, 687 "Cond", 688 {"x", "x"}, 689 {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}, 690 {"out_types", DataTypeSlice{DT_FLOAT, DT_FLOAT}}, 691 {"cond", FDH::FunctionRef("MyCond2")}, 692 {"then_branch", FDH::FunctionRef("MyThen2")}, 693 {"else_branch", FDH::FunctionRef("MyElse2")}}}, 694 }, 695 {{"y", "y:output"}}); 696 InstantiationResult result; 697 HasError(InstantiateFunction(fdef, AttrSlice(), 
GetOpSig, &result), 698 "Invalid ret types"); 699 } 700 701 TEST(InstantiateErrors, TypeList_Missing_Arg) { 702 auto fdef = FDH::Create( 703 // Name 704 "MySelect", 705 // Args 706 {"x: float"}, 707 // Return values 708 {"y: float"}, 709 // Attrs 710 {}, 711 // Nodes 712 { 713 {{"y"}, 714 "Cond", 715 {"x", "unknown"}, 716 {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}, 717 {"out_types", DataTypeSlice{DT_FLOAT}}, 718 {"cond", FDH::FunctionRef("MyCond2")}, 719 {"then_branch", FDH::FunctionRef("MyThen2")}, 720 {"else_branch", FDH::FunctionRef("MyElse2")}}}, 721 }, 722 {{"y", "y:output"}}); 723 InstantiationResult result; 724 HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), 725 "input unknown is not found"); 726 } 727 728 TEST(InstantiateErrors, TooManyInputs) { 729 auto fdef = FDH::Create( 730 // Name 731 "TooManyInputs", 732 // Inputs 733 {"x: float", "y: float"}, 734 // Outputs 735 {"z: float"}, 736 // Attrs 737 {}, 738 // Nodes 739 {// a = AddN<N=2>(x, y, x) 740 {{"a"}, "AddN", {"x", "y", "x"}, {{"T", DT_FLOAT}, {"N", 2}}}}, 741 // Returns 742 {{"z", "a:sum:0"}}); 743 744 InstantiationResult result; 745 HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), 746 "Expected input[2] == 'x' to be a control input."); 747 } 748 749 TEST(InstantiateErrors, TooFewInputs) { 750 auto fdef = FDH::Create( 751 // Name 752 "TooFewInputs", 753 // Inputs 754 {"x: float", "y: float"}, 755 // Outputs 756 {"z: float"}, 757 // Attrs 758 {}, 759 // Nodes 760 {// a = AddN<N=3>(x, y) 761 {{"a"}, "AddN", {"x", "y"}, {{"T", DT_FLOAT}, {"N", 3}}}}, 762 // Returns 763 {{"z", "a:sum:0"}}); 764 765 InstantiationResult result; 766 HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), 767 "Attempt to access beyond input size: 2 >= 2"); 768 } 769 770 TEST(InstantiateErrors, TooManyInputsFromArray1) { 771 auto fdef = FDH::Create( 772 // Name 773 "TooManyInputsFromArray", 774 // Inputs 775 {"x: float", "y: float"}, 776 // Outputs 777 {"z: float"}, 778 // 
Attrs 779 {}, 780 // Nodes 781 {// a = _ListToArray(x,y) 782 {{"a"}, 783 "_ListToArray", 784 {"x", "y"}, 785 {{"N", 2}, 786 {"T", DT_FLOAT}, 787 {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}}, 788 // b = AddN<N=2>(a, y) 789 {{"b"}, "AddN", {"a:output", "y"}, {{"T", DT_FLOAT}, {"N", 2}}}}, 790 // Returns 791 {{"z", "a:sum:0"}}); 792 793 InstantiationResult result; 794 HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), 795 "Expected input[1] == 'y' to be a control input."); 796 } 797 798 TEST(InstantiateErrors, TooManyInputsFromArray2) { 799 auto fdef = FDH::Create( 800 // Name 801 "TooManyInputsFromArray", 802 // Inputs 803 {"x: float", "y: float"}, 804 // Outputs 805 {"z: float"}, 806 // Attrs 807 {}, 808 // Nodes 809 {// a = _ListToArray(x,y) 810 {{"a"}, 811 "_ListToArray", 812 {"x", "y"}, 813 {{"N", 2}, 814 {"T", DT_FLOAT}, 815 {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}}, 816 // b = AddN<N=2>(x, a) 817 {{"b"}, "AddN", {"x", "a:output"}, {{"T", DT_FLOAT}, {"N", 2}}}}, 818 // Returns 819 {{"z", "a:sum:0"}}); 820 821 InstantiationResult result; 822 HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), 823 "Input a:output too long for inputs"); 824 } 825 826 TEST(InstantiateErrors, TypeMismatch) { 827 auto fdef = FDH::Create( 828 // Name 829 "TypeMismatch", 830 // Inputs 831 {"x: float", "y: int32"}, 832 // Outputs 833 {"z: float"}, 834 // Attrs 835 {}, 836 // Nodes 837 {// a = AddN<N=2>(x, y) 838 {{"a"}, "AddN", {"x", "y"}, {{"T", DT_FLOAT}, {"N", 3}}}}, 839 // Returns 840 {{"z", "a:sum:0"}}); 841 842 InstantiationResult result; 843 HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), 844 "input inputs[1] expected type float != int32, the type of y[0]"); 845 } 846 847 TEST(FunctionCallFrame, Void_Void) { 848 FunctionCallFrame frame({}, {}); 849 TF_EXPECT_OK(frame.SetArgs({})); 850 auto a = test::AsTensor<float>({100}); 851 HasError(frame.SetArgs({a}), "Invalid argument"); 852 Tensor v; 853 HasError(frame.GetArg(0, &v), 
"Invalid argument"); 854 HasError(frame.SetRetval(0, v), "Invalid argument"); 855 std::vector<Tensor> rets; 856 TF_EXPECT_OK(frame.GetRetvals(&rets)); 857 EXPECT_EQ(rets.size(), 0); 858 } 859 860 TEST(FunctionCallFrame, Float_Float_Float) { 861 FunctionCallFrame frame({DT_FLOAT, DT_FLOAT}, {DT_FLOAT}); 862 HasError(frame.SetArgs({}), "Invalid argument: Expects 2 arguments"); 863 auto a = test::AsTensor<float>({100}); 864 auto b = test::AsTensor<float>({200}); 865 auto c = test::AsTensor<int64>({300}); 866 HasError(frame.SetArgs({a, c}), 867 "Invalid argument: Expects arg[1] to be float"); 868 TF_EXPECT_OK(frame.SetArgs({a, b})); 869 870 Tensor v; 871 HasError(frame.GetArg(-1, &v), "Invalid argument"); 872 HasError(frame.GetArg(2, &v), "Invalid argument"); 873 TF_EXPECT_OK(frame.GetArg(0, &v)); 874 test::ExpectTensorEqual<float>(a, v); 875 TF_EXPECT_OK(frame.GetArg(1, &v)); 876 test::ExpectTensorEqual<float>(b, v); 877 878 v = test::AsTensor<float>({-100}); 879 HasError(frame.SetRetval(-1, v), "Invalid argument"); 880 HasError(frame.SetRetval(1, v), "Invalid argument"); 881 HasError(frame.SetRetval(0, test::AsTensor<int64>({-100})), 882 "Invalid argument: Expects ret[0] to be float"); 883 884 std::vector<Tensor> rets; 885 HasError(frame.GetRetvals(&rets), "does not have value"); 886 TF_EXPECT_OK(frame.SetRetval(0, v)); 887 HasError(frame.SetRetval(0, v), "has already been set"); 888 889 TF_EXPECT_OK(frame.GetRetvals(&rets)); 890 EXPECT_EQ(rets.size(), 1); 891 test::ExpectTensorEqual<float>(rets[0], v); 892 } 893 894 TEST(Canonicalize, Basic) { 895 EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_FLOAT}, 896 {"transpose_a", false}, 897 {"transpose_b", false}})), 898 "MatMul[T=float,transpose_a=false,transpose_b=false]"); 899 EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_FLOAT}, 900 {"transpose_b", false}, 901 {"transpose_a", false}})), 902 "MatMul[T=float,transpose_a=false,transpose_b=false]"); 903 EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_DOUBLE}, 904 
{"transpose_b", true}, 905 {"transpose_a", false}})), 906 "MatMul[T=double,transpose_a=false,transpose_b=true]"); 907 } 908 909 TEST(FunctionLibraryDefinitionTest, Find) { 910 FunctionDefLibrary proto; 911 *proto.add_function() = test::function::XTimesTwo(); 912 FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto); 913 914 EXPECT_EQ(lib_def.Find("XTimes16"), nullptr); 915 916 auto expect = R"P( 917 XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) { 918 two = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]() 919 scale = Cast[DstT=$T, SrcT=int64](two:output:0) 920 y = Mul[T=$T](x, scale:y:0) 921 return y = y:z:0 922 } 923 )P"; 924 auto found = lib_def.Find("XTimesTwo"); 925 ASSERT_NE(found, nullptr); 926 EXPECT_EQ(expect, DebugString(*found)); 927 } 928 929 TEST(FunctionLibraryDefinitionTest, LookUp) { 930 FunctionDefLibrary proto; 931 *proto.add_function() = test::function::XTimesTwo(); 932 FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto); 933 934 const OpDef* op_def; 935 EXPECT_TRUE(!lib_def.LookUpOpDef("XTimes16", &op_def).ok()); 936 937 TF_EXPECT_OK(lib_def.LookUpOpDef("XTimesTwo", &op_def)); 938 ASSERT_NE(op_def, nullptr); 939 EXPECT_EQ(op_def->DebugString(), 940 test::function::XTimesTwo().signature().DebugString()); 941 942 const OpRegistrationData* op_reg_data; 943 TF_EXPECT_OK(lib_def.LookUp("XTimesTwo", &op_reg_data)); 944 ASSERT_NE(op_reg_data, nullptr); 945 // Shape inference function is initialized to UnknownShape. 946 ASSERT_NE(op_reg_data->shape_inference_fn, nullptr); 947 } 948 949 TEST(FunctionLibraryDefinitionTest, AddFunctionDef) { 950 // Add one function to the proto lib before constructing 'lib_def'. 951 FunctionDefLibrary proto; 952 *proto.add_function() = test::function::XTimesTwo(); 953 FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto); 954 955 // Add a new function def to the library. 
956 TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::WXPlusB())); 957 958 // Test lookup of first function. 959 const OpDef* first; 960 TF_EXPECT_OK(lib_def.LookUpOpDef("XTimesTwo", &first)); 961 ASSERT_NE(first, nullptr); 962 EXPECT_EQ(first->DebugString(), 963 test::function::XTimesTwo().signature().DebugString()); 964 965 // Test lookup of second function. 966 const OpDef* second; 967 TF_EXPECT_OK(lib_def.LookUpOpDef("WXPlusB", &second)); 968 ASSERT_NE(second, nullptr); 969 EXPECT_EQ(second->DebugString(), 970 test::function::WXPlusB().signature().DebugString()); 971 972 // Can't add function with same name as existing op 973 FunctionDef fdef = test::function::XTimesTwo(); 974 fdef.mutable_signature()->set_name("Add"); 975 Status s = lib_def.AddFunctionDef(fdef); 976 EXPECT_FALSE(s.ok()); 977 EXPECT_EQ(s.error_message(), 978 "Cannot add function 'Add' because an op with the same name " 979 "already exists."); 980 981 // Already-added functions don't produce error 982 TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::XTimesTwo())); 983 TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::WXPlusB())); 984 } 985 986 TEST(FunctionLibraryDefinitionTest, AddGradientDef) { 987 // AddGradientDef() doesn't check that functions referenced exist (yet?) 
988 FunctionLibraryDefinition lib_def(OpRegistry::Global(), FunctionDefLibrary()); 989 990 // Test adding a gradient (XTimesFour isn't a valid grad function for 991 // XTimesTwo but that's ok for now) 992 GradientDef grad; 993 grad.set_function_name(test::function::XTimesTwo().signature().name()); 994 grad.set_gradient_func(test::function::XTimesFour().signature().name()); 995 TF_EXPECT_OK(lib_def.AddGradientDef(grad)); 996 997 // Already-added gradients don't produce error 998 TF_EXPECT_OK(lib_def.AddGradientDef(grad)); 999 1000 // Test that adding a duplicate gradient fails 1001 grad.set_gradient_func(test::function::XTimes16().signature().name()); 1002 Status s = lib_def.AddGradientDef(grad); 1003 EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT); 1004 EXPECT_EQ(s.error_message(), 1005 "Cannot assign gradient function 'XTimes16' to 'XTimesTwo' because " 1006 "it already has gradient function 'XTimesFour'"); 1007 } 1008 1009 TEST(FunctionLibraryDefinitionTest, AddLibrary) { 1010 // Create lib def with single function 1011 FunctionDefLibrary proto; 1012 *proto.add_function() = test::function::XTimesTwo(); 1013 FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto); 1014 1015 // Add gradient 1016 GradientDef grad; 1017 grad.set_function_name(test::function::XTimesTwo().signature().name()); 1018 grad.set_gradient_func(test::function::XTimesFour().signature().name()); 1019 TF_EXPECT_OK(lib_def.AddGradientDef(grad)); 1020 1021 // Error if you try to add conflicting function 1022 proto.Clear(); 1023 FunctionDef fdef = test::function::XTimesFour(); 1024 fdef.mutable_signature()->set_name( 1025 test::function::XTimesTwo().signature().name()); 1026 *proto.add_function() = fdef; 1027 FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto); 1028 Status s = lib_def.AddLibrary(lib_def2); 1029 EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT); 1030 EXPECT_EQ(s.error_message(), 1031 "Cannot add function 'XTimesTwo' because a different function with " 1032 "the 
same name already exists."); 1033 1034 // Error if you try to add conflicting gradient 1035 proto.Clear(); 1036 grad.set_gradient_func(test::function::XTimes16().signature().name()); 1037 *proto.add_gradient() = grad; 1038 FunctionLibraryDefinition lib_def3(OpRegistry::Global(), proto); 1039 s = lib_def.AddLibrary(lib_def3); 1040 EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT); 1041 EXPECT_EQ(s.error_message(), 1042 "Cannot assign gradient function 'XTimes16' to 'XTimesTwo' because " 1043 "it already has gradient function 'XTimesFour'"); 1044 1045 // No conflicting functions or gradients OK 1046 proto.Clear(); 1047 *proto.add_function() = test::function::XTimesFour(); 1048 grad.set_function_name(test::function::XTimes16().signature().name()); 1049 *proto.add_gradient() = grad; 1050 FunctionLibraryDefinition lib_def4(OpRegistry::Global(), proto); 1051 TF_EXPECT_OK(lib_def.AddLibrary(lib_def4)); 1052 1053 // OK to add the same functions and gradients twice 1054 TF_EXPECT_OK(lib_def.AddLibrary(lib_def)); 1055 } 1056 1057 GradientDef MakeGradDef(const string& f, const string& g) { 1058 GradientDef grad; 1059 grad.set_function_name(f); 1060 grad.set_gradient_func(g); 1061 return grad; 1062 } 1063 1064 TEST(FunctionLibraryDefinitionTest, AddLibrary_Atomic) { 1065 // Create lib def containing two functions with equal names 1066 FunctionDefLibrary proto; 1067 const string x2_name = test::function::XTimesTwo().signature().name(); 1068 const string x4_name = test::function::XTimesFour().signature().name(); 1069 *proto.add_function() = test::function::XTimesTwo(); 1070 FunctionDef fdef = test::function::XTimesFour(); 1071 fdef.mutable_signature()->set_name(x2_name); 1072 *proto.add_function() = fdef; 1073 FunctionLibraryDefinition lib_def(OpRegistry::Global(), FunctionDefLibrary()); 1074 1075 // Try adding the two functions to lib_def 1076 Status s = lib_def.AddLibrary(proto); 1077 EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code()); 1078 EXPECT_EQ( 1079 "Cannot add 
function 'XTimesTwo' because a different function with " 1080 "the same name already exists.", 1081 s.error_message()); 1082 1083 // Verify that none of the functions are added 1084 EXPECT_TRUE(lib_def.Find(x2_name) == nullptr); 1085 1086 // Fix the name in proto but add two gradient names for it 1087 proto.mutable_function(1)->mutable_signature()->set_name(x4_name); 1088 *proto.add_gradient() = MakeGradDef(x2_name, x4_name); 1089 *proto.add_gradient() = MakeGradDef(x2_name, "SecondGradName"); 1090 1091 // Try adding the library and check that nothing was added 1092 s = lib_def.AddLibrary(proto); 1093 EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code()); 1094 EXPECT_EQ(s.error_message(), 1095 "Cannot assign gradient function 'SecondGradName' to 'XTimesTwo' " 1096 "because it already has gradient function 'XTimesFour'"); 1097 EXPECT_TRUE(lib_def.Find(x2_name) == nullptr); 1098 EXPECT_EQ(0, lib_def.ToProto().function_size()); 1099 EXPECT_EQ(0, lib_def.ToProto().gradient_size()); 1100 } 1101 1102 TEST(FunctionLibraryDefinitionTest, AddLibraryDefinition_Atomic_FuncConflict) { 1103 const string x2_name = test::function::XTimesTwo().signature().name(); 1104 const string x4_name = test::function::XTimesFour().signature().name(); 1105 const string wx_name = test::function::WXPlusB().signature().name(); 1106 1107 // Create FunctionLibraryDefinition with 1108 // (func = XTimesTwo, grad = XTimesFour) 1109 FunctionDefLibrary proto; 1110 *proto.add_function() = test::function::XTimesTwo(); 1111 *proto.add_gradient() = MakeGradDef(x2_name, x4_name); 1112 FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto); 1113 EXPECT_EQ(1, lib_def.ToProto().function_size()); 1114 EXPECT_EQ(1, lib_def.ToProto().gradient_size()); 1115 1116 // Create FunctionLibraryDefinition with (func = WXPlusB, grad = XTimesTwo) 1117 // and function (name = XTimesTwo, body = XTimeFour) 1118 FunctionDefLibrary proto2; 1119 *proto2.add_function() = test::function::WXPlusB(); 1120 *proto2.add_gradient() = 
MakeGradDef(wx_name, x2_name); 1121 *proto2.add_function() = test::function::XTimesFour(); 1122 proto2.mutable_function(1)->mutable_signature()->set_name(x2_name); 1123 FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto2); 1124 1125 // Verify that adding lib_def2 will fail because of function conflict 1126 // and WXPlusB is not added. 1127 Status s = lib_def.AddLibrary(lib_def2); 1128 EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code()); 1129 EXPECT_EQ( 1130 "Cannot add function 'XTimesTwo' because a different function " 1131 "with the same name already exists.", 1132 s.error_message()); 1133 EXPECT_TRUE(lib_def.Find(wx_name) == nullptr); 1134 EXPECT_EQ(1, lib_def.ToProto().function_size()); 1135 EXPECT_EQ(1, lib_def.ToProto().gradient_size()); 1136 } 1137 1138 TEST(FunctionLibraryDefinitionTest, AddLibraryDefinition_Atomic_GradConflict) { 1139 const string x2_name = test::function::XTimesTwo().signature().name(); 1140 const string x4_name = test::function::XTimesFour().signature().name(); 1141 const string wx_name = test::function::WXPlusB().signature().name(); 1142 1143 // Create FunctionLibraryDefinition with 1144 // (func = XTimesTwo, grad = XTimesFour) 1145 FunctionDefLibrary proto; 1146 *proto.add_function() = test::function::XTimesTwo(); 1147 *proto.add_gradient() = MakeGradDef(x2_name, x4_name); 1148 FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto); 1149 EXPECT_EQ(1, lib_def.ToProto().function_size()); 1150 EXPECT_EQ(1, lib_def.ToProto().gradient_size()); 1151 1152 // Create FunctionLibraryDefinition with (func = WXPlusB, grad = XTimesTwo) 1153 // and (func = XTimesTwo, grad = WXPlusB) 1154 FunctionDefLibrary proto2; 1155 *proto2.add_function() = test::function::WXPlusB(); 1156 *proto2.add_gradient() = MakeGradDef(wx_name, x2_name); 1157 *proto2.add_function() = test::function::XTimesTwo(); 1158 *proto2.add_gradient() = MakeGradDef(x2_name, wx_name); 1159 FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto2); 1160 1161 // 
Verify that adding lib_def2 will fail because of gradient conflict 1162 // and WXPlusB is not added. 1163 Status s = lib_def.AddLibrary(lib_def2); 1164 EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code()); 1165 EXPECT_EQ( 1166 "Cannot assign gradient function 'WXPlusB' to 'XTimesTwo'" 1167 " because it already has gradient function 'XTimesFour'", 1168 s.error_message()); 1169 EXPECT_TRUE(lib_def.Find(wx_name) == nullptr); 1170 EXPECT_EQ(1, lib_def.ToProto().function_size()); 1171 EXPECT_EQ(1, lib_def.ToProto().gradient_size()); 1172 } 1173 1174 TEST(FunctionLibraryDefinitionTest, ToProto) { 1175 FunctionDefLibrary proto1; 1176 *proto1.add_function() = test::function::XTimesTwo(); 1177 *proto1.add_function() = test::function::WXPlusB(); 1178 FunctionLibraryDefinition lib_def1(OpRegistry::Global(), proto1); 1179 1180 // Call 'ToProto' and make sure both protos have the same function lib size. 1181 FunctionDefLibrary proto2 = lib_def1.ToProto(); 1182 EXPECT_EQ(proto1.function_size(), proto2.function_size()); 1183 1184 // Initialize 'lib_def2' with proto returned by 'ToProto' call. 1185 FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto2); 1186 1187 // Test that the first function exists in both libraries. 1188 const OpDef *f1, *f2, *f3, *f4; 1189 TF_EXPECT_OK(lib_def1.LookUpOpDef("XTimesTwo", &f1)); 1190 TF_EXPECT_OK(lib_def2.LookUpOpDef("XTimesTwo", &f2)); 1191 EXPECT_EQ(f1->DebugString(), f2->DebugString()); 1192 1193 // Test that the second function exists in both libraries. 1194 TF_EXPECT_OK(lib_def1.LookUpOpDef("WXPlusB", &f3)); 1195 TF_EXPECT_OK(lib_def2.LookUpOpDef("WXPlusB", &f4)); 1196 EXPECT_EQ(f3->DebugString(), f4->DebugString()); 1197 } 1198 1199 TEST(FunctionLibraryDefinitionTest, GetAttr_FuncNoAttr) { 1200 FunctionDefLibrary proto; 1201 *proto.add_function() = test::function::XTimesTwo(); 1202 FunctionLibraryDefinition lib(OpRegistry::Global(), proto); 1203 1204 NodeDef ndef; 1205 bool annotation; 1206 1207 // Not a function. 
1208 ndef.set_op("Matmul"); 1209 EXPECT_FALSE(lib.GetAttr(ndef, "annotation", &annotation).ok()); 1210 1211 // A function. No attr defined. 1212 ndef.set_op("XTimesTwo"); 1213 EXPECT_FALSE(lib.GetAttr(ndef, "annotation", &annotation).ok()); 1214 1215 // ndef defines the attr. But we don't care. 1216 AddNodeAttr("annotation", true, &ndef); 1217 EXPECT_FALSE(lib.GetAttr(ndef, "annotation", &annotation).ok()); 1218 } 1219 1220 template <typename T> 1221 void SetAttrValue(FunctionDef* fdef, const string& attr, const T& value) { 1222 AttrValue attr_value; 1223 SetAttrValue(value, &attr_value); 1224 fdef->mutable_attr()->insert({attr, attr_value}); 1225 } 1226 1227 TEST(FunctionLibraryDefinitionTest, GetAttr_FuncWithAttr) { 1228 FunctionDefLibrary proto; 1229 auto fdef = proto.add_function(); 1230 *fdef = test::function::XTimesTwo(); 1231 SetAttrValue(fdef, "annotation", true); 1232 SetAttrValue(fdef, "options", "some string data"); 1233 FunctionLibraryDefinition lib(OpRegistry::Global(), proto); 1234 1235 NodeDef ndef; 1236 bool annotation; 1237 1238 // A function. No attr defined in ndef. 
1239 ndef.set_op("XTimesTwo"); 1240 TF_EXPECT_OK(lib.GetAttr(ndef, "annotation", &annotation)); 1241 EXPECT_EQ(annotation, true); 1242 1243 string str; 1244 TF_EXPECT_OK(lib.GetAttr(ndef, "options", &str)); 1245 EXPECT_EQ(str, "some string data"); 1246 } 1247 1248 TEST(FunctionLibraryDefinitionTest, GetAttr_Gradient) { 1249 FunctionDefLibrary proto; 1250 auto fdef = proto.add_function(); 1251 *fdef = test::function::XTimesTwo(); 1252 SetAttrValue(fdef, "annotation", true); 1253 *fdef = test::function::WXPlusB(); 1254 SetAttrValue(fdef, "annotation", false); 1255 auto func_grad = proto.add_gradient(); 1256 func_grad->set_function_name("XTimesTwo"); 1257 func_grad->set_gradient_func("WXPlusB"); 1258 FunctionLibraryDefinition lib(OpRegistry::Global(), proto); 1259 1260 NodeDef ndef; 1261 ndef.set_op(FunctionLibraryDefinition::kGradientOp); 1262 1263 bool annotation; 1264 EXPECT_FALSE(lib.GetAttr(ndef, "annotation", &annotation).ok()); 1265 1266 NameAttrList nal; 1267 nal.set_name("XTimesTwo"); 1268 AddNodeAttr(FunctionLibraryDefinition::kFuncAttr, nal, &ndef); 1269 TF_EXPECT_OK(lib.GetAttr(ndef, "annotation", &annotation)); 1270 EXPECT_EQ(annotation, false); // XTimesTwo's gradient is WXPlusB. 1271 1272 nal.set_name("WXPlusB"); 1273 ndef.clear_attr(); 1274 AddNodeAttr(FunctionLibraryDefinition::kFuncAttr, nal, &ndef); 1275 TF_EXPECT_OK(lib.GetAttr(ndef, "annotation", &annotation)); 1276 EXPECT_EQ(annotation, false); // WXPlusB has no custom gradient. 
1277 } 1278 1279 // TODO(skyewm): this could be more thorough 1280 TEST(FunctionDefsEqualTest, TestFunctionDefsEqual) { 1281 // Equal functions 1282 const FunctionDef fdef1 = test::function::XTimesTwo(); 1283 FunctionDef fdef2 = test::function::XTimesTwo(); 1284 uint64 hash1 = FunctionDefHash(fdef1); 1285 EXPECT_TRUE(FunctionDefsEqual(fdef1, fdef2)); 1286 EXPECT_EQ(hash1, FunctionDefHash(fdef2)); 1287 1288 // Different functions 1289 fdef2 = test::function::XTimesFour(); 1290 EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2)); 1291 EXPECT_NE(hash1, FunctionDefHash(fdef2)); 1292 1293 // Different signatures 1294 fdef2 = test::function::XTimesTwo(); 1295 fdef2.mutable_signature()->mutable_input_arg(0)->set_name("foo"); 1296 EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2)); 1297 EXPECT_NE(hash1, FunctionDefHash(fdef2)); 1298 1299 // Descriptions must be equal 1300 fdef2 = test::function::XTimesTwo(); 1301 fdef2.mutable_signature()->mutable_input_arg(0)->set_description("foo"); 1302 EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2)); 1303 EXPECT_NE(hash1, FunctionDefHash(fdef2)); 1304 1305 // Different NodeDefs 1306 fdef2 = test::function::XTimesTwo(); 1307 NodeDef* ndef = fdef2.add_node_def(); 1308 *ndef = fdef2.node_def(0); 1309 ndef->set_name("new_name"); 1310 EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2)); 1311 EXPECT_NE(hash1, FunctionDefHash(fdef2)); 1312 1313 // Different return values 1314 fdef2 = test::function::XTimesTwo(); 1315 (*fdef2.mutable_ret())["y"] = "y:z:1"; // originally is "y:z:0" 1316 EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2)); 1317 EXPECT_NE(hash1, FunctionDefHash(fdef2)); 1318 1319 // Different attributes 1320 fdef2 = test::function::XTimesTwo(); 1321 SetAttrValue(&fdef2, "ExtraAttr", true); 1322 EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2)); 1323 EXPECT_NE(hash1, FunctionDefHash(fdef2)); 1324 1325 // Multiple equivalent attributes; the two functions should be equal. 
1326 fdef2 = test::function::XTimesTwo(); 1327 FunctionDef fdef3 = test::function::XTimesTwo(); 1328 SetAttrValue(&fdef2, "Foo", true); 1329 SetAttrValue(&fdef3, "Foo", true); 1330 SetAttrValue(&fdef2, "Bar", 123); 1331 SetAttrValue(&fdef3, "Bar", 123); 1332 SetAttrValue(&fdef2, "Baz", "abc"); 1333 SetAttrValue(&fdef3, "Baz", "abc"); 1334 EXPECT_TRUE(FunctionDefsEqual(fdef2, fdef3)); 1335 EXPECT_EQ(FunctionDefHash(fdef2), FunctionDefHash(fdef3)); 1336 } 1337 1338 } // end namespace 1339 } // end namespace tensorflow 1340