1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/node_def_builder.h" 17 18 #include <memory> 19 #include <vector> 20 #include "tensorflow/core/framework/fake_input.h" 21 #include "tensorflow/core/framework/node_def_util.h" 22 #include "tensorflow/core/framework/op_def_builder.h" 23 #include "tensorflow/core/framework/op_def_util.h" 24 #include "tensorflow/core/lib/core/status_test_util.h" 25 #include "tensorflow/core/platform/protobuf.h" 26 #include "tensorflow/core/platform/test.h" 27 28 namespace tensorflow { 29 namespace { 30 31 class NodeDefBuilderTest : public ::testing::Test { 32 protected: 33 // Specify an OpDef via an OpDefBuilder. 34 void Op(const OpDefBuilder& op_def_builder) { 35 OpRegistrationData op_reg_data; 36 TF_EXPECT_OK(op_def_builder.Finalize(&op_reg_data)); 37 op_def_ = op_reg_data.op_def; 38 } 39 40 // Resets builder_ with a new NodeDefBuilder using the Op from the last call 41 // to Op() above. 42 NodeDefBuilder& Builder() { 43 EXPECT_FALSE(op_def_.name().empty()) << "Must call Op() before Builder()"; 44 builder_.reset(new NodeDefBuilder("n", &op_def_)); 45 return *builder_; 46 } 47 48 // Calls Finalize() and verifies it returns success and the result matches 49 // expectations. 50 void ExpectSuccess(const NodeDefBuilder& builder, 51 DataTypeSlice expected_in_types, 52 DataTypeSlice expected_out_types, StringPiece proto) { 53 NodeDef node_def; 54 Status status = builder.Finalize(&node_def); 55 TF_EXPECT_OK(status); 56 if (!status.ok()) return; 57 NodeDef expected; 58 protobuf::TextFormat::ParseFromString(strings::StrCat("name: 'n' ", proto), 59 &expected); 60 EXPECT_EQ(node_def.DebugString(), expected.DebugString()); 61 62 DataTypeVector in_types, out_types; 63 status = 64 InOutTypesForNode(node_def, builder.op_def(), &in_types, &out_types); 65 TF_EXPECT_OK(status); 66 if (!status.ok()) return; 67 EXPECT_EQ(DataTypeSliceString(expected_in_types), 68 DataTypeVectorString(in_types)); 69 EXPECT_EQ(DataTypeSliceString(expected_out_types), 70 DataTypeVectorString(out_types)); 71 72 status = ValidateNodeDef(node_def, op_def_); 73 TF_EXPECT_OK(status); 74 } 75 76 // Calls Finalize() and verifies it returns an error. 77 // Each message must appear as a substring of the error. 78 void ExpectFailures(const NodeDefBuilder& builder, 79 const std::vector<string>& messages) { 80 NodeDef node_def; 81 Status status = builder.Finalize(&node_def); 82 EXPECT_FALSE(status.ok()) << SummarizeNodeDef(node_def); 83 if (status.ok()) return; 84 for (const string& message : messages) { 85 EXPECT_TRUE(StringPiece(status.error_message()).contains(message)) 86 << status << ", " << message; 87 } 88 } 89 90 // Calls Finalize() and verifies it returns an error. 91 // Message must appear as a substring of the error. 92 void ExpectFailure(const NodeDefBuilder& builder, const string& message) { 93 ExpectFailures(builder, {message}); 94 } 95 96 // Like ExpectFailure(), except that the error can come from 97 // ValidateNodeDef(). 98 void ExpectInvalid(const NodeDefBuilder& builder, const string& message) { 99 NodeDef node_def; 100 Status status = builder.Finalize(&node_def); 101 if (status.ok()) { 102 status = ValidateNodeDef(node_def, op_def_); 103 } 104 EXPECT_FALSE(status.ok()) << SummarizeNodeDef(node_def); 105 if (status.ok()) return; 106 EXPECT_TRUE(StringPiece(status.error_message()).contains(message)) 107 << "Actual error: " << status.error_message() 108 << "\nDoes not contain: " << message; 109 } 110 111 OpDef op_def_; 112 std::unique_ptr<NodeDefBuilder> builder_; 113 }; 114 115 TEST_F(NodeDefBuilderTest, Simple) { 116 Op(OpDefBuilder("Simple").Input("a: int32").Output("out: float")); 117 118 ExpectSuccess(Builder().Input("x", 0, DT_INT32), {DT_INT32}, {DT_FLOAT}, 119 R"proto( op: "Simple" input: "x" )proto"); 120 121 // Port != 0 122 ExpectSuccess(Builder().Input("y", 2, DT_INT32), {DT_INT32}, {DT_FLOAT}, 123 R"proto( op: "Simple" input: "y:2" )proto"); 124 125 // FakeInput 126 ExpectSuccess(Builder().Input(FakeInput()), {DT_INT32}, {DT_FLOAT}, R"proto( 127 op: "Simple" input: "a" )proto"); 128 129 ExpectSuccess(Builder().Input(FakeInput(DT_INT32)), {DT_INT32}, {DT_FLOAT}, 130 R"proto( op: "Simple" input: "a" )proto"); 131 132 // Ref input 133 ExpectSuccess(Builder().Input(FakeInput(DT_INT32_REF)), {DT_INT32}, 134 {DT_FLOAT}, R"proto( op: "Simple" input: "a" )proto"); 135 136 // ControlInput 137 ExpectSuccess( 138 Builder().ControlInput("x").Input(FakeInput()).ControlInput("y"), 139 {DT_INT32}, {DT_FLOAT}, R"proto( 140 op: "Simple" input: ["a", "^x", "^y"] )proto"); 141 142 // Device 143 ExpectSuccess(Builder().Input(FakeInput()).Device("ddd"), {DT_INT32}, 144 {DT_FLOAT}, R"proto( 145 op: "Simple" input: "a" device: "ddd" )proto"); 146 147 // Extra input 148 ExpectFailure(Builder().Input("x", 0, DT_INT32).Input("y", 0, DT_INT32), 149 "More Input() calls than the 1 input_args while building " 150 "NodeDef 'n' using Op<name=Simple; signature=a:int32 -> " 151 "out:float>"); 152 153 // Missing input 154 ExpectFailure(Builder(), "0 inputs specified of 1 inputs in Op while"); 155 156 { // Finalize() twice. 157 NodeDefBuilder& builder = Builder(); 158 // First call to Finalize() 159 TF_EXPECT_OK(builder.Input(FakeInput()).Finalize(nullptr)); 160 // ExpectSuccess() also calls Finalize(). 161 ExpectSuccess(builder, {DT_INT32}, {DT_FLOAT}, R"proto( 162 op: "Simple" input: "a" )proto"); 163 } 164 165 { // Input() after Finalize() 166 NodeDefBuilder& builder = Builder(); 167 // Calling Finalize() before enough inputs -> error. 168 ExpectFailure(builder, "0 inputs specified of 1 inputs in Op while"); 169 builder.Input(FakeInput()); 170 // Calling Finalize() with enough inputs -> success 171 ExpectSuccess(builder, {DT_INT32}, {DT_FLOAT}, R"proto( 172 op: "Simple" input: "a" )proto"); 173 // Calling Finalize() with too many inputs -> error. 174 builder.Input(FakeInput(DT_INT32)); 175 ExpectFailure(builder, "More Input() calls than the 1 input_args while"); 176 } 177 178 // Wrong input type 179 ExpectFailure(Builder().Input("x", 0, DT_FLOAT), 180 "Input 'a' passed float expected int32 "); 181 182 ExpectFailure(Builder().Input("x", 0, DT_FLOAT_REF), 183 "Input 'a' passed float_ref expected int32 "); 184 185 // List input 186 ExpectFailure(Builder().Input(FakeInput(3, DT_FLOAT)), 187 "List provided to input 'a' when single Tensor expected while"); 188 189 ExpectFailure(Builder().Input(FakeInput(3)), 190 "List provided to input 'a' when single Tensor expected while"); 191 192 // Bad ControlInput 193 ExpectInvalid(Builder().Input(FakeInput()).ControlInput("z:2"), 194 "Control input '^z:2' must not have ':' in NodeDef:"); 195 196 // Bad input name 197 ExpectFailure(Builder().Input("", 0, DT_INT32), 198 "Empty input node name while"); 199 200 ExpectFailure(Builder().Input("^x", 0, DT_INT32), 201 "Non-control input starting with ^: ^x while"); 202 } 203 204 TEST_F(NodeDefBuilderTest, OpDoesNotExist) { 205 NodeDefBuilder builder("n", "Op Does Not Exist"); 206 builder.Input(FakeInput()) 207 .Input(FakeInput(12)) 208 .ControlInput("y") 209 .Attr("foo", 12) 210 .Device("device"); 211 ExpectFailures(builder, {"Op type not registered 'Op Does Not Exist'", 212 "while building NodeDef 'n'"}); 213 } 214 215 TEST_F(NodeDefBuilderTest, Polymorphic) { 216 Op(OpDefBuilder("Polymorphic") 217 .Input("v: T") 218 .Output("out: T") 219 .Attr("T: type")); 220 221 ExpectSuccess(Builder().Input(FakeInput(DT_INT32)), {DT_INT32}, {DT_INT32}, 222 R"proto( 223 op: "Polymorphic" input: "a" 224 attr { key: "T" value { type: DT_INT32 } } )proto"); 225 226 ExpectSuccess(Builder().Input(FakeInput(DT_FLOAT)), {DT_FLOAT}, {DT_FLOAT}, 227 R"proto( 228 op: "Polymorphic" input: "a" 229 attr { key: "T" value { type: DT_FLOAT } } )proto"); 230 231 // Redundant Attr() 232 ExpectSuccess(Builder().Input(FakeInput(DT_BOOL)).Attr("T", DT_BOOL), 233 {DT_BOOL}, {DT_BOOL}, R"proto( 234 op: "Polymorphic" input: "a" 235 attr { key: "T" value { type: DT_BOOL } } )proto"); 236 237 // Conficting Attr() 238 ExpectFailure(Builder().Input(FakeInput(DT_BOOL)).Attr("T", DT_STRING), 239 "Inconsistent values for attr 'T' DT_BOOL vs. DT_STRING while"); 240 241 ExpectFailure(Builder().Attr("T", DT_STRING).Input(FakeInput(DT_BOOL)), 242 "Inconsistent values for attr 'T' DT_STRING vs. DT_BOOL while"); 243 244 ExpectFailure(Builder().Attr("T", 12).Input(FakeInput(DT_BOOL)), 245 "Inconsistent values for attr 'T' 12 vs. DT_BOOL while"); 246 } 247 248 TEST_F(NodeDefBuilderTest, PolymorphicOut) { 249 Op(OpDefBuilder("PolymorphicOut").Output("out: T").Attr("T: type")); 250 251 ExpectSuccess(Builder().Attr("T", DT_INT32), {}, {DT_INT32}, R"proto( 252 op: "PolymorphicOut" 253 attr { key: "T" value { type: DT_INT32 } } )proto"); 254 255 ExpectSuccess(Builder().Attr("T", DT_FLOAT), {}, {DT_FLOAT}, R"proto( 256 op: "PolymorphicOut" 257 attr { key: "T" value { type: DT_FLOAT } } )proto"); 258 259 // Redundant attr 260 ExpectSuccess(Builder().Attr("T", DT_FLOAT).Attr("T", DT_FLOAT), {}, 261 {DT_FLOAT}, R"proto( 262 op: "PolymorphicOut" 263 attr { key: "T" value { type: DT_FLOAT } } )proto"); 264 265 // Conflicting attr 266 ExpectFailure(Builder().Attr("T", DT_BOOL).Attr("T", DT_FLOAT), 267 "Inconsistent values for attr 'T' DT_BOOL vs. DT_FLOAT while"); 268 269 // Missing attr 270 ExpectInvalid(Builder(), "NodeDef missing attr 'T' from"); 271 272 // Attr has the wrong type 273 ExpectInvalid( 274 Builder().Attr("T", {DT_INT32, DT_BOOL}), 275 "AttrValue had value with type 'list(type)' when 'type' expected"); 276 277 ExpectInvalid(Builder().Attr("T", 12), 278 "AttrValue had value with type 'int' when 'type' expected"); 279 } 280 281 TEST_F(NodeDefBuilderTest, PolymorphicDefaultOut) { 282 Op(OpDefBuilder("PolymorphicDefaultOut") 283 .Output("out: T") 284 .Attr("T: type = DT_STRING")); 285 286 ExpectSuccess(Builder(), {}, {DT_STRING}, R"proto( 287 op: "PolymorphicDefaultOut" 288 attr { key: "T" value { type: DT_STRING } } )proto"); 289 290 ExpectSuccess(Builder().Attr("T", DT_BOOL), {}, {DT_BOOL}, R"proto( 291 op: "PolymorphicDefaultOut" 292 attr { key: "T" value { type: DT_BOOL } } )proto"); 293 } 294 295 TEST_F(NodeDefBuilderTest, Binary) { 296 Op(OpDefBuilder("Binary").Input("a: T").Input("b: T").Output("out: T").Attr( 297 "T: type")); 298 299 ExpectSuccess(Builder().Input(FakeInput(DT_INT32)).Input(FakeInput(DT_INT32)), 300 {DT_INT32, DT_INT32}, {DT_INT32}, R"proto( 301 op: "Binary" input: "a" input: "b" 302 attr { key: "T" value { type: DT_INT32 } } )proto"); 303 304 ExpectSuccess(Builder().Input(FakeInput(DT_STRING)).Input(FakeInput()), 305 {DT_STRING, DT_STRING}, {DT_STRING}, R"proto( 306 op: "Binary" input: "a" input: "b" 307 attr { key: "T" value { type: DT_STRING } } )proto"); 308 309 // Type mismatch 310 ExpectFailure(Builder().Input(FakeInput(DT_BOOL)).Input(FakeInput(DT_STRING)), 311 "Inconsistent values for attr 'T' DT_BOOL vs. DT_STRING while"); 312 } 313 314 TEST_F(NodeDefBuilderTest, Restrict) { 315 Op(OpDefBuilder("Restrict") 316 .Input("a: T") 317 .Output("out: T") 318 .Attr("T: {string, bool}")); 319 ExpectSuccess(Builder().Input(FakeInput(DT_STRING)), {DT_STRING}, {DT_STRING}, 320 R"proto( 321 op: "Restrict" input: "a" 322 attr { key: "T" value { type: DT_STRING } } )proto"); 323 324 ExpectInvalid(Builder().Input(FakeInput(DT_INT32)), 325 "Value for attr 'T' of int32 is not in the list of allowed " 326 "values: string, bool"); 327 } 328 329 TEST_F(NodeDefBuilderTest, TypeList) { 330 Op(OpDefBuilder("TypeList").Input("a: T").Attr("T: list(type)")); 331 332 ExpectSuccess(Builder().Input(FakeInput({DT_STRING, DT_INT32})), 333 {DT_STRING, DT_INT32}, {}, R"proto( 334 op: "TypeList" input: ["a", "a:1"] 335 attr { key: "T" value { list { type: [DT_STRING, DT_INT32] } } } 336 )proto"); 337 338 ExpectSuccess(Builder().Input(FakeInput(3, DT_BOOL)), 339 {DT_BOOL, DT_BOOL, DT_BOOL}, {}, R"proto( 340 op: "TypeList" input: ["a", "a:1", "a:2"] 341 attr { key: "T" value { list { type: [DT_BOOL, DT_BOOL, DT_BOOL] } } } 342 )proto"); 343 344 ExpectInvalid(Builder().Input(FakeInput(0)), 345 "Length for attr 'T' of 0 must be at least minimum 1"); 346 347 ExpectInvalid(Builder().Input(FakeInput({})), 348 "Length for attr 'T' of 0 must be at least minimum 1"); 349 350 ExpectInvalid(Builder().Input(FakeInput(DT_BOOL)), 351 "Single tensor passed to 'a', expected list while"); 352 353 ExpectFailures(Builder().Input(FakeInput()), 354 {"2 errors while building NodeDef", 355 "Could not infer list of types for input 'a': " 356 "No attr named 'T' in NodeDef:", 357 "0 inputs specified of 1 inputs in Op"}); 358 } 359 360 TEST_F(NodeDefBuilderTest, TypeListNoMin) { 361 Op(OpDefBuilder("TypeListNoMin").Input("a: T").Attr("T: list(type) >= 0")); 362 363 ExpectSuccess(Builder().Input(FakeInput(0)), {}, {}, R"proto( 364 op: "TypeListNoMin" attr { key: "T" value { list { } } } )proto"); 365 366 ExpectSuccess(Builder().Input(FakeInput(DataTypeVector())), {}, {}, R"proto( 367 op: "TypeListNoMin" attr { key: "T" value { list { } } } )proto"); 368 369 ExpectSuccess(Builder().Input(FakeInput({})), {}, {}, R"proto( 370 op: "TypeListNoMin" attr { key: "T" value { list { } } } )proto"); 371 372 ExpectSuccess(Builder().Input(FakeInput({DT_BOOL})), {DT_BOOL}, {}, R"proto( 373 op: "TypeListNoMin" input: "a" 374 attr { key: "T" value { list { type: DT_BOOL } } } )proto"); 375 } 376 377 TEST_F(NodeDefBuilderTest, TypeListTwice) { 378 Op(OpDefBuilder("TypeListTwice") 379 .Input("a: T") 380 .Input("b: T") 381 .Attr("T: list(type) >= 0")); 382 383 ExpectSuccess(Builder() 384 .Input(FakeInput({DT_INT32, DT_BOOL})) 385 .Input(FakeInput({DT_INT32, DT_BOOL})), 386 {DT_INT32, DT_BOOL, DT_INT32, DT_BOOL}, {}, R"proto( 387 op: "TypeListTwice" input: ["a", "a:1", "b", "b:1"] 388 attr { key: "T" value { list { type: [DT_INT32, DT_BOOL] } } } )proto"); 389 390 ExpectSuccess( 391 Builder().Input(FakeInput({DT_INT32, DT_BOOL})).Input(FakeInput()), 392 {DT_INT32, DT_BOOL, DT_INT32, DT_BOOL}, {}, R"proto( 393 op: "TypeListTwice" input: ["a", "a:1", "b", "b:1"] 394 attr { key: "T" value { list { type: [DT_INT32, DT_BOOL] } } } )proto"); 395 396 ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput(0)), {}, {}, 397 R"proto( 398 op: "TypeListTwice" attr { key: "T" value { list { } } } )proto"); 399 400 ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput()), {}, {}, 401 R"proto( 402 op: "TypeListTwice" attr { key: "T" value { list { } } } )proto"); 403 404 ExpectFailure(Builder() 405 .Input(FakeInput({DT_INT32, DT_BOOL})) 406 .Input(FakeInput({DT_INT32, DT_STRING})), 407 "Inconsistent values for attr 'T' [DT_INT32, DT_BOOL] vs. " 408 "[DT_INT32, DT_STRING] while"); 409 } 410 411 TEST_F(NodeDefBuilderTest, OutTypeList) { 412 Op(OpDefBuilder("OutTypeList").Output("out: T").Attr("T: list(type) >= 0")); 413 414 ExpectSuccess(Builder().Attr("T", {DT_FLOAT}), {}, {DT_FLOAT}, R"proto( 415 op: "OutTypeList" 416 attr { key: "T" value { list { type: DT_FLOAT } } } )proto"); 417 418 ExpectSuccess(Builder().Attr("T", {DT_STRING, DT_BOOL}), {}, 419 {DT_STRING, DT_BOOL}, R"proto( 420 op: "OutTypeList" 421 attr { key: "T" value { list { type: [DT_STRING, DT_BOOL] } } } )proto"); 422 423 ExpectSuccess(Builder().Attr("T", DataTypeVector()), {}, {}, R"proto( 424 op: "OutTypeList" 425 attr { key: "T" value { list { } } } )proto"); 426 427 ExpectInvalid( 428 Builder().Attr("T", DT_FLOAT), 429 "AttrValue had value with type 'type' when 'list(type)' expected"); 430 } 431 432 TEST_F(NodeDefBuilderTest, TypeListRestrict) { 433 Op(OpDefBuilder("TypeListRestrict") 434 .Input("a: T") 435 .Attr("T: list({string, bool}) >= 0")); 436 437 ExpectSuccess(Builder().Input(FakeInput({DT_STRING, DT_BOOL})), 438 {DT_STRING, DT_BOOL}, {}, R"proto( 439 op: "TypeListRestrict" input: ["a", "a:1"] 440 attr { key: "T" value { list { type: [DT_STRING, DT_BOOL] } } } )proto"); 441 442 ExpectInvalid(Builder().Input(FakeInput({DT_STRING, DT_INT32})), 443 "Value for attr 'T' of int32 is not in the list of allowed " 444 "values: string, bool"); 445 } 446 447 TEST_F(NodeDefBuilderTest, OutTypeListRestrict) { 448 Op(OpDefBuilder("OutTypeListRestrict") 449 .Output("out: t") 450 .Attr("t: list({string, bool}) >= 0")); 451 452 ExpectSuccess(Builder().Attr("t", {DT_BOOL, DT_STRING}), {}, 453 {DT_BOOL, DT_STRING}, R"proto( 454 op: "OutTypeListRestrict" 455 attr { key: "t" value { list { type: [DT_BOOL, DT_STRING] } } } )proto"); 456 457 ExpectInvalid(Builder().Attr("t", {DT_STRING, DT_INT32}), 458 "Value for attr 't' of int32 is not in the list of allowed " 459 "values: string, bool"); 460 } 461 462 TEST_F(NodeDefBuilderTest, Attr) { 463 Op(OpDefBuilder("Attr").Attr("a: int")); 464 465 ExpectSuccess(Builder().Attr("a", 12), {}, {}, R"proto( 466 op: "Attr" attr { key: "a" value { i: 12 } } )proto"); 467 468 // Attr has wrong type 469 ExpectInvalid(Builder().Attr("a", "bad"), 470 "AttrValue had value with type 'string' when 'int' expected"); 471 472 ExpectInvalid( 473 Builder().Attr("a", {12}), 474 "AttrValue had value with type 'list(int)' when 'int' expected"); 475 476 // Missing attr 477 ExpectInvalid(Builder(), "NodeDef missing attr 'a' from Op<"); 478 479 // Wrong attr 480 ExpectInvalid(Builder().Attr("b", 12), 481 "NodeDef mentions attr 'b' not in Op<"); 482 483 // Extra attr 484 ExpectInvalid(Builder().Attr("a", 12).Attr("extra", 12), 485 "NodeDef mentions attr 'extra' not in Op<"); 486 } 487 488 TEST_F(NodeDefBuilderTest, AttrFloat) { 489 Op(OpDefBuilder("AttrFloat").Attr("a: float")); 490 491 ExpectSuccess(Builder().Attr("a", 1.2f /* float */), {}, {}, R"proto( 492 op: "AttrFloat" attr { key: "a" value { f: 1.2 } } 493 )proto"); 494 495 ExpectSuccess(Builder().Attr("a", 1.2 /* double */), {}, {}, R"proto( 496 op: "AttrFloat" attr { key: "a" value { f: 1.2 } } 497 )proto"); 498 499 // Won't automatically cast int to float 500 ExpectInvalid(Builder().Attr("a", 12), 501 "AttrValue had value with type 'int' when 'float' expected"); 502 } 503 504 TEST_F(NodeDefBuilderTest, AttrBoolList) { 505 Op(OpDefBuilder("AttrBoolList").Attr("a: list(bool)")); 506 507 ExpectSuccess(Builder().Attr("a", {true, false, true}), {}, {}, R"proto( 508 op: "AttrBoolList" 509 attr { key: "a" value { list { b: [true, false, true] } } } 510 )proto"); 511 512 ExpectSuccess(Builder().Attr("a", std::vector<bool>()), {}, {}, R"proto( 513 op: "AttrBoolList" attr { key: "a" value { list { } } } 514 )proto"); 515 516 // Won't cast int -> bool. 517 ExpectInvalid(Builder().Attr("a", {0}), 518 "AttrValue had value with type 'list(int)' when 'list(bool)' " 519 "expected"); 520 } 521 522 TEST_F(NodeDefBuilderTest, AttrMin) { 523 Op(OpDefBuilder("AttrMin").Attr("a: int >= 5")); 524 525 ExpectSuccess(Builder().Attr("a", 12), {}, {}, R"proto( 526 op: "AttrMin" attr { key: "a" value { i: 12 } } )proto"); 527 528 ExpectInvalid(Builder().Attr("a", 2), 529 "Value for attr 'a' of 2 must be at least minimum 5"); 530 } 531 532 TEST_F(NodeDefBuilderTest, AttrListMin) { 533 Op(OpDefBuilder("AttrListMin").Attr("a: list(int) >= 2")); 534 535 ExpectSuccess(Builder().Attr("a", {1, 2}), {}, {}, R"proto( 536 op: "AttrListMin" 537 attr { key: "a" value { list { i: [1, 2] } } } )proto"); 538 539 ExpectInvalid(Builder().Attr("a", {17}), 540 "Length for attr 'a' of 1 must be at least minimum 2"); 541 } 542 543 TEST_F(NodeDefBuilderTest, AttrEnum) { 544 Op(OpDefBuilder("AttrEnum").Attr("a: {'apples', 'oranges'}")); 545 546 ExpectSuccess(Builder().Attr("a", "oranges"), {}, {}, R"proto( 547 op: "AttrEnum" 548 attr { key: "a" value { s: "oranges" } } )proto"); 549 550 ExpectInvalid( 551 Builder().Attr("a", "invalid"), 552 "Value for attr 'a' of \"invalid\" is not in the list of allowed values: " 553 "\"apples\", \"oranges\""); 554 } 555 556 TEST_F(NodeDefBuilderTest, AttrEnumList) { 557 Op(OpDefBuilder("AttrEnumList").Attr("a: list({'apples', 'oranges'})")); 558 559 ExpectSuccess(Builder().Attr("a", {"oranges", "apples"}), {}, {}, R"proto( 560 op: "AttrEnumList" 561 attr { key: "a" value { list { s: ["oranges", "apples"] } } } )proto"); 562 563 ExpectInvalid( 564 Builder().Attr("a", {"apples", "invalid", "oranges"}), 565 "Value for attr 'a' of \"invalid\" is not in the list of allowed values: " 566 "\"apples\", \"oranges\""); 567 } 568 569 TEST_F(NodeDefBuilderTest, AttrShape) { 570 Op(OpDefBuilder("AttrShape").Attr("a: shape")); 571 572 ExpectSuccess(Builder().Attr("a", TensorShape({5})), {}, {}, R"proto( 573 op: "AttrShape" 574 attr { key: "a" value { shape { dim { size: 5 } } } } )proto"); 575 576 ExpectSuccess(Builder().Attr("a", TensorShape({4, 3, 2})), {}, {}, R"proto( 577 op: "AttrShape" 578 attr { key: "a" value { shape { 579 dim { size: 4 } dim { size: 3 } dim { size: 2 } } } } )proto"); 580 581 ExpectSuccess(Builder().Attr("a", TensorShape({3, 2})), {}, {}, 582 R"proto( 583 op: "AttrShape" 584 attr { key: "a" value { shape { 585 dim { size: 3 } dim { size: 2 } } } } )proto"); 586 587 ExpectSuccess(Builder().Attr("a", TensorShape()), {}, {}, R"proto( 588 op: "AttrShape" 589 attr { key: "a" value { shape { } } } )proto"); 590 } 591 592 TEST_F(NodeDefBuilderTest, AttrDefault) { 593 Op(OpDefBuilder("AttrDefault").Attr("a: string = 'banana'")); 594 595 ExpectSuccess(Builder(), {}, {}, R"proto( 596 op: "AttrDefault" 597 attr { key: "a" value { s: "banana" } } )proto"); 598 599 ExpectSuccess(Builder().Attr("a", "kiwi"), {}, {}, R"proto( 600 op: "AttrDefault" 601 attr { key: "a" value { s: "kiwi" } } )proto"); 602 } 603 604 TEST_F(NodeDefBuilderTest, AttrManyDefault) { 605 Op(OpDefBuilder("AttrManyDefault") 606 .Attr("a: string = 'banana'") 607 .Attr("b: string = 'kiwi'")); 608 609 ExpectSuccess(Builder(), {}, {}, R"proto( 610 op: "AttrManyDefault" 611 attr { key: "a" value { s: "banana" } } 612 attr { key: "b" value { s: "kiwi" } } )proto"); 613 614 Op(OpDefBuilder("AttrManyDefaultWithMandatory") 615 .Attr("a: string = 'banana'") 616 .Attr("b: string = 'kiwi'") 617 .Attr("c: string")); 618 619 ExpectSuccess(Builder().Attr("c", "strawberry"), {}, {}, R"proto( 620 op: "AttrManyDefaultWithMandatory" 621 attr { key: "c" value { s: "strawberry" } } 622 attr { key: "a" value { s: "banana" } } 623 attr { key: "b" value { s: "kiwi" } } )proto"); 624 625 Op(OpDefBuilder("AttrManyDefaultAndInferred") 626 .Input("input: T") 627 .Attr("T: {float, double}") 628 .Attr("a: string") 629 .Attr("b: list(string) >= 1") 630 .Attr("c: bool = true") 631 .Attr("d: float = 0.3") 632 .Attr("e: string") 633 .Attr("f: float = 0.25")); 634 635 ExpectSuccess(Builder() 636 .Input(FakeInput(DT_FLOAT)) 637 .Attr("a", "foo") 638 .Attr("e", "foo") 639 .Attr("b", std::vector<string>({"bar", "baz"})) 640 .Attr("f", 1.0f), 641 {DT_FLOAT}, {}, R"proto( 642 op: "AttrManyDefaultAndInferred" 643 input: "a" 644 attr { key: "T" value { type: DT_FLOAT } } 645 attr { key: "a" value { s: "foo" } } 646 attr { key: "e" value { s: "foo" } } 647 attr { key: "b" value { list { s: "bar" s: "baz" } } } 648 attr { key: "f" value { f: 1.0 } } 649 attr { key: "c" value { b: true } } 650 attr { key: "d" value { f: 0.3 } } )proto"); 651 } 652 653 TEST_F(NodeDefBuilderTest, AttrListDefault) { 654 Op(OpDefBuilder("AttrListDefault").Attr("a: list(int) = [5, 15]")); 655 656 ExpectSuccess(Builder(), {}, {}, R"proto( 657 op: "AttrListDefault" 658 attr { key: "a" value { list { i: [5, 15] } } } )proto"); 659 660 ExpectSuccess(Builder().Attr("a", {3}), {}, {}, R"proto( 661 op: "AttrListDefault" 662 attr { key: "a" value { list { i: 3 } } } )proto"); 663 664 ExpectSuccess(Builder().Attr("a", std::vector<int>()), {}, {}, R"proto( 665 op: "AttrListDefault" 666 attr { key: "a" value { list { } } } )proto"); 667 } 668 669 TEST_F(NodeDefBuilderTest, AttrEmptyListDefault) { 670 Op(OpDefBuilder("AttrEmptyListDefault").Attr("a: list(int) = []")); 671 672 ExpectSuccess(Builder(), {}, {}, R"proto( 673 op: "AttrEmptyListDefault" 674 attr { key: "a" value { list { } } } )proto"); 675 676 ExpectSuccess(Builder().Attr("a", {3}), {}, {}, R"proto( 677 op: "AttrEmptyListDefault" 678 attr { key: "a" value { list { i: 3 } } } )proto"); 679 680 ExpectSuccess(Builder().Attr("a", std::vector<int>()), {}, {}, R"proto( 681 op: "AttrEmptyListDefault" 682 attr { key: "a" value { list { } } } )proto"); 683 } 684 685 TEST_F(NodeDefBuilderTest, NIntsIn) { 686 Op(OpDefBuilder("NIntsIn").Input("a: N*int32").Attr("N: int >= 2")); 687 688 ExpectSuccess(Builder().Input(FakeInput(2)), {DT_INT32, DT_INT32}, {}, 689 R"proto( 690 op: "NIntsIn" input: ["a", "a:1"] 691 attr { key: "N" value { i: 2 } } )proto"); 692 693 ExpectSuccess(Builder().Input(FakeInput(5, DT_INT32)), 694 {DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32}, {}, R"proto( 695 op: "NIntsIn" 696 input: ["a", "a:1", "a:2", "a:3", "a:4"] 697 attr { key: "N" value { i: 5 } } )proto"); 698 699 ExpectFailures(Builder().Input(FakeInput(2, DT_STRING)), 700 {"2 errors while building NodeDef", 701 "Input 'a' passed string expected int32"}); 702 703 ExpectInvalid(Builder().Input(FakeInput(1)), 704 "Value for attr 'N' of 1 must be at least minimum 2"); 705 706 ExpectFailures( 707 Builder().Input(FakeInput(DT_INT32)), 708 {"2 errors while building NodeDef", 709 "Could not infer length of input 'a': No attr named 'N' in NodeDef:", 710 "0 inputs specified of 1 inputs in Op"}); 711 712 ExpectFailure(Builder().Input({{"in", 0, DT_INT32}, {"in", 1, DT_STRING}}), 713 "Input 'a' passed string expected int32 while"); 714 715 ExpectFailures( 716 Builder().Input(FakeInput()), 717 {"2 errors while building NodeDef", 718 "Could not infer length of input 'a': No attr named 'N' in NodeDef:", 719 "0 inputs specified of 1 inputs in Op"}); 720 } 721 722 TEST_F(NodeDefBuilderTest, NPolymorphicIn) { 723 Op(OpDefBuilder("NPolymorphicIn") 724 .Input("a: N*T") 725 .Attr("T: type") 726 .Attr("N: int >= 2")); 727 728 ExpectSuccess(Builder().Input(FakeInput(2, DT_INT32)), {DT_INT32, DT_INT32}, 729 {}, R"proto( 730 op: "NPolymorphicIn" input: ["a", "a:1"] 731 attr { key: "N" value { i: 2 } } 732 attr { key: "T" value { type: DT_INT32 } } )proto"); 733 734 ExpectSuccess(Builder().Input(FakeInput(3, DT_STRING)), 735 {DT_STRING, DT_STRING, DT_STRING}, {}, R"proto( 736 op: "NPolymorphicIn" 737 input: ["a", "a:1", "a:2"] 738 attr { key: "N" value { i: 3 } } 739 attr { key: "T" value { type: DT_STRING } } )proto"); 740 741 ExpectFailures( 742 Builder().Input(FakeInput(2)), 743 {"2 errors while building NodeDef", 744 "Could not infer type for input 'a': No attr named 'T' in NodeDef:", 745 "0 inputs specified of 1 inputs in Op"}); 746 747 ExpectFailure(Builder().Input(FakeInput({DT_INT32, DT_STRING})), 748 "Input 'a' passed string expected int32 while"); 749 750 ExpectFailure(Builder().Input({{"in", 0, DT_INT32}, {"in", 1, DT_STRING}}), 751 "Input 'a' passed string expected int32 while"); 752 753 ExpectInvalid(Builder().Input(FakeInput(1, DT_INT32)), 754 "Value for attr 'N' of 1 must be at least minimum 2"); 755 756 ExpectFailure(Builder().Input("in", 0, DT_INT32), 757 "Single tensor passed to 'a', expected list while"); 758 } 759 760 TEST_F(NodeDefBuilderTest, NPolymorphicRestrictIn) { 761 Op(OpDefBuilder("NPolymorphicRestrictIn") 762 .Input("a: N*T") 763 .Attr("T: {string, bool}") 764 .Attr("N: int >= 2")); 765 766 ExpectSuccess(Builder().Input(FakeInput(2, DT_BOOL)), {DT_BOOL, DT_BOOL}, {}, 767 R"proto( 768 op: "NPolymorphicRestrictIn" input: ["a", "a:1"] 769 attr { key: "N" value { i: 2 } } 770 attr { key: "T" value { type: DT_BOOL } } )proto"); 771 772 ExpectSuccess(Builder().Input(FakeInput(3, DT_STRING)), 773 {DT_STRING, DT_STRING, DT_STRING}, {}, R"proto( 774 op: "NPolymorphicRestrictIn" 775 input: ["a", "a:1", "a:2"] 776 attr { key: "N" value { i: 3 } } 777 attr { key: "T" value { type: DT_STRING } } )proto"); 778 779 ExpectInvalid(Builder().Input(FakeInput(2, DT_INT32)), 780 "Value for attr 'T' of int32 is not in the list of allowed " 781 "values: string, bool"); 782 } 783 784 TEST_F(NodeDefBuilderTest, NInTwice) { 785 Op(OpDefBuilder("NInTwice") 786 .Input("a: N*int32") 787 .Input("b: N*string") 788 .Attr("N: int >= 0")); 789 790 ExpectSuccess(Builder().Input(FakeInput(2)).Input(FakeInput(2)), 791 {DT_INT32, DT_INT32, DT_STRING, DT_STRING}, {}, R"proto( 792 op: "NInTwice" 793 input: ["a", "a:1", "b", "b:1"] 794 attr { key: "N" value { i: 2 } } )proto"); 795 796 ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput()), {}, {}, 797 R"proto( 798 op: "NInTwice" attr { key: "N" value { i: 0 } } )proto"); 799 800 ExpectFailure(Builder().Input(FakeInput(3)).Input(FakeInput(1)), 801 "Inconsistent values for attr 'N' 3 vs. 1 while"); 802 } 803 804 TEST_F(NodeDefBuilderTest, NInPolymorphicTwice) { 805 Op(OpDefBuilder("NInPolymorphicTwice") 806 .Input("a: N*T") 807 .Input("b: N*T") 808 .Attr("T: type") 809 .Attr("N: int >= 0")); 810 811 ExpectSuccess(Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput()), 812 {DT_INT32, DT_INT32, DT_INT32, DT_INT32}, {}, R"proto( 813 op: "NInPolymorphicTwice" 814 input: ["a", "a:1", "b", "b:1"] 815 attr { key: "N" value { i: 2 } } 816 attr { key: "T" value { type: DT_INT32 } } )proto"); 817 818 ExpectFailure( 819 Builder().Input(FakeInput(3, DT_INT32)).Input(FakeInput(1, DT_INT32)), 820 "Inconsistent values for attr 'N' 3 vs. 1 while"); 821 822 ExpectFailure(Builder().Input(FakeInput(3, DT_INT32)).Input(FakeInput(1)), 823 "Inconsistent values for attr 'N' 3 vs. 1 while"); 824 825 ExpectFailure( 826 Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(2, DT_STRING)), 827 "Inconsistent values for attr 'T' DT_INT32 vs. DT_STRING while"); 828 829 ExpectFailure( 830 Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(DT_STRING)), 831 "Inconsistent values for attr 'T' DT_INT32 vs. DT_STRING while"); 832 } 833 834 TEST_F(NodeDefBuilderTest, NInTwoTypeVariables) { 835 Op(OpDefBuilder("NInTwoTypeVariables") 836 .Input("a: N*S") 837 .Input("b: N*T") 838 .Attr("S: type") 839 .Attr("T: type") 840 .Attr("N: int >= 0")); 841 842 ExpectSuccess( 843 Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(2, DT_BOOL)), 844 {DT_INT32, DT_INT32, DT_BOOL, DT_BOOL}, {}, R"proto( 845 op: "NInTwoTypeVariables" 846 input: ["a", "a:1", "b", "b:1"] 847 attr { key: "N" value { i: 2 } } 848 attr { key: "S" value { type: DT_INT32 } } 849 attr { key: "T" value { type: DT_BOOL } } )proto"); 850 851 ExpectSuccess( 852 Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(DT_BOOL)), 853 {DT_INT32, DT_INT32, DT_BOOL, DT_BOOL}, {}, R"proto( 854 op: "NInTwoTypeVariables" 855 input: ["a", "a:1", "b", "b:1"] 856 attr { key: "N" value { i: 2 } } 857 attr { key: "S" value { type: DT_INT32 } } 858 attr { key: "T" value { type: DT_BOOL } } )proto"); 859 860 ExpectFailure( 861 Builder().Input(FakeInput(3, DT_INT32)).Input(FakeInput(1, DT_STRING)), 862 "Inconsistent values for attr 'N' 3 vs. 1 while"); 863 } 864 865 TEST_F(NodeDefBuilderTest, InPolymorphicTwice) { 866 Op(OpDefBuilder("InPolymorphicTwice") 867 .Input("a: N*T") 868 .Input("b: M*T") 869 .Attr("T: type") 870 .Attr("N: int >= 0") 871 .Attr("M: int >= 0")); 872 873 ExpectSuccess( 874 Builder().Input(FakeInput(1, DT_INT32)).Input(FakeInput(3, DT_INT32)), 875 {DT_INT32, DT_INT32, DT_INT32, DT_INT32}, {}, R"proto( 876 op: "InPolymorphicTwice" 877 input: ["a", "b", "b:1", "b:2"] 878 attr { key: "N" value { i: 1 } } 879 attr { key: "T" value { type: DT_INT32 } } 880 attr { key: "M" value { i: 3 } } )proto"); 881 882 ExpectSuccess(Builder().Input(FakeInput(1, DT_BOOL)).Input(FakeInput(0)), 883 {DT_BOOL}, {}, R"proto( 884 op: "InPolymorphicTwice" input: "a" 885 attr { key: "N" value { i: 1 } } 886 attr { key: "T" value { type: DT_BOOL } } 887 attr { key: "M" value { i: 0 } } )proto"); 888 889 ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput(1, DT_BOOL)), 890 {DT_BOOL}, {}, R"proto( 891 op: "InPolymorphicTwice" input: "b" 892 attr { key: "N" value { i: 0 } } 893 attr { key: "M" value { i: 1 } } 894 attr { key: "T" value { type: DT_BOOL } } )proto"); 895 896 ExpectFailure( 897 Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(2, DT_STRING)), 898 "Inconsistent values for attr 'T' DT_INT32 vs. DT_STRING while"); 899 } 900 901 TEST_F(NodeDefBuilderTest, NIntsOut) { 902 Op(OpDefBuilder("NIntsOut").Output("a: N*int32").Attr("N: int >= 2")); 903 904 ExpectSuccess(Builder().Attr("N", 2), {}, {DT_INT32, DT_INT32}, R"proto( 905 op: "NIntsOut" 906 attr { key: "N" value { i: 2 } } )proto"); 907 908 ExpectSuccess(Builder().Attr("N", 3), {}, {DT_INT32, DT_INT32, DT_INT32}, 909 R"proto( 910 op: "NIntsOut" 911 attr { key: "N" value { i: 3 } } )proto"); 912 913 ExpectInvalid(Builder().Attr("N", 1), 914 "Value for attr 'N' of 1 must be at least minimum 2"); 915 916 ExpectInvalid( 917 Builder().Attr("N", {3}), 918 "AttrValue had value with type 'list(int)' when 'int' expected"); 919 920 ExpectInvalid(Builder(), "NodeDef missing attr 'N' from"); 921 } 922 923 TEST_F(NodeDefBuilderTest, NIntsOutDefault) { 924 Op(OpDefBuilder("NIntsOutDefault") 925 .Output("a: N*int32") 926 .Attr("N: int >= 2 = 3")); 927 928 ExpectSuccess(Builder(), {}, {DT_INT32, DT_INT32, DT_INT32}, R"proto( 929 op: "NIntsOutDefault" 930 attr { key: "N" value { i: 3 } } )proto"); 931 932 ExpectSuccess(Builder().Attr("N", 2), {}, {DT_INT32, DT_INT32}, R"proto( 933 op: "NIntsOutDefault" 934 attr { key: "N" value { i: 2 } } )proto"); 935 } 936 937 TEST_F(NodeDefBuilderTest, NPolymorphicOut) { 938 Op(OpDefBuilder("NPolymorphicOut") 939 .Output("a: N*T") 940 .Attr("T: type") 941 .Attr("N: int >= 2")); 942 943 ExpectSuccess(Builder().Attr("T", DT_INT32).Attr("N", 2), {}, 944 {DT_INT32, DT_INT32}, R"proto( 945 op: "NPolymorphicOut" 946 attr { key: "T" value { type: DT_INT32 } } 947 attr { key: "N" value { i: 2 } } )proto"); 948 949 ExpectSuccess(Builder().Attr("N", 3).Attr("T", DT_STRING), {}, 950 {DT_STRING, DT_STRING, DT_STRING}, R"proto( 951 op: "NPolymorphicOut" 952 attr { key: "N" value { i: 3 } } 953 attr { key: "T" value { type: DT_STRING } } )proto"); 954 955 ExpectInvalid(Builder().Attr("N", 1).Attr("T", DT_STRING), 956 "Value for attr 'N' of 1 must be at least minimum 2"); 957 958 ExpectInvalid( 959 Builder().Attr("N", 3).Attr("T", {DT_STRING}), 960 "AttrValue had value with type 'list(type)' when 'type' expected"); 961 } 962 963 TEST_F(NodeDefBuilderTest, NPolymorphicOutDefault) { 964 Op(OpDefBuilder("NPolymorphicOutDefault") 965 .Output("a: N*T") 966 .Attr("T: type = DT_BOOL") 967 .Attr("N: int >= 2 = 2")); 968 969 ExpectSuccess(Builder(), {}, {DT_BOOL, DT_BOOL}, R"proto( 970 op: "NPolymorphicOutDefault" 971 attr { key: "T" value { type: DT_BOOL } } 972 attr { key: "N" value { i: 2 } } )proto"); 973 974 ExpectSuccess(Builder().Attr("N", 3), {}, {DT_BOOL, DT_BOOL, DT_BOOL}, 975 R"proto( 976 op: "NPolymorphicOutDefault" 977 attr { key: "N" value { i: 3 } } 978 attr { key: "T" value { type: DT_BOOL } } )proto"); 979 980 ExpectSuccess(Builder().Attr("T", DT_INT32), {}, {DT_INT32, DT_INT32}, 981 R"proto( 982 op: "NPolymorphicOutDefault" 983 attr { key: "T" value { type: DT_INT32 } } 984 attr { key: "N" value { i: 2 } } )proto"); 985 986 ExpectSuccess(Builder().Attr("N", 3).Attr("T", DT_INT32), {}, 987 {DT_INT32, DT_INT32, DT_INT32}, R"proto( 988 op: "NPolymorphicOutDefault" 989 attr { key: "N" value { i: 3 } } 990 attr { key: "T" value { type: DT_INT32 } } )proto"); 991 } 992 993 TEST_F(NodeDefBuilderTest, NPolymorphicRestrictOut) { 994 Op(OpDefBuilder("NPolymorphicRestrictOut") 995 .Output("a: N*T") 996 .Attr("T: {string, bool}") 997 .Attr("N: int >= 2")); 998 999 ExpectSuccess(Builder().Attr("N", 3).Attr("T", DT_BOOL), {}, 1000 {DT_BOOL, DT_BOOL, DT_BOOL}, R"proto( 1001 op: "NPolymorphicRestrictOut" 1002 attr { key: "N" value { i: 3 } } 1003 attr { key: "T" value { type: DT_BOOL } } )proto"); 1004 1005 ExpectInvalid(Builder().Attr("N", 3).Attr("T", DT_INT32), 1006 "Value for attr 'T' of int32 is not in the list of allowed " 1007 "values: string, bool"); 1008 } 1009 1010 TEST_F(NodeDefBuilderTest, RefIn) { 1011 Op(OpDefBuilder("RefIn").Input("a: Ref(int32)")); 1012 1013 ExpectSuccess(Builder().Input(FakeInput(DT_INT32_REF)), {DT_INT32_REF}, {}, 1014 R"proto( 1015 op: "RefIn" input: "a" )proto"); 1016 1017 ExpectFailure(Builder().Input(FakeInput(DT_BOOL_REF)), 1018 "Input 'a' passed bool_ref expected int32_ref while"); 1019 1020 ExpectFailure(Builder().Input(FakeInput(DT_INT32)), 1021 "Input 'a' passed int32 expected int32_ref while"); 1022 } 1023 1024 TEST_F(NodeDefBuilderTest, PolymorphicRefIn) { 1025 Op(OpDefBuilder("PolymorphicRefIn").Input("a: Ref(T)").Attr("T: type")); 1026 1027 ExpectSuccess(Builder().Input(FakeInput(DT_BOOL_REF)), {DT_BOOL_REF}, {}, 1028 R"proto( 1029 op: "PolymorphicRefIn" input: "a" 1030 attr { key: "T" value { type: DT_BOOL } } )proto"); 1031 1032 ExpectFailure(Builder().Input(FakeInput(DT_BOOL)), 1033 "Input 'a' passed bool expected ref type while"); 1034 } 1035 1036 TEST_F(NodeDefBuilderTest, RefOut) { 1037 Op(OpDefBuilder("RefOut").Output("a: Ref(string)")); 1038 1039 ExpectSuccess(Builder(), {}, {DT_STRING_REF}, R"proto( 1040 op: "RefOut" )proto"); 1041 } 1042 1043 TEST_F(NodeDefBuilderTest, PolymorphicRefOut) { 1044 Op(OpDefBuilder("PolymorphicRefOut").Output("a: Ref(t)").Attr("t: type")); 1045 1046 ExpectSuccess(Builder().Attr("t", DT_BOOL), {}, {DT_BOOL_REF}, R"proto( 1047 op: "PolymorphicRefOut" 1048 attr { key: "t" value { type: DT_BOOL } } )proto"); 1049 } 1050 1051 TEST_F(NodeDefBuilderTest, SpecifyDevice) { 1052 Op(OpDefBuilder("SpecifyDevice")); 1053 1054 ExpectSuccess(Builder().Device("ADevice"), {}, {}, R"proto( 1055 op: "SpecifyDevice" device: "ADevice" )proto"); 1056 } 1057 1058 } // namespace 1059 } // namespace tensorflow 1060