Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/framework/node_def_builder.h"
     17 
     18 #include <memory>
     19 #include <vector>
     20 #include "tensorflow/core/framework/fake_input.h"
     21 #include "tensorflow/core/framework/node_def_util.h"
     22 #include "tensorflow/core/framework/op_def_builder.h"
     23 #include "tensorflow/core/framework/op_def_util.h"
     24 #include "tensorflow/core/lib/core/status_test_util.h"
     25 #include "tensorflow/core/platform/protobuf.h"
     26 #include "tensorflow/core/platform/test.h"
     27 
     28 namespace tensorflow {
     29 namespace {
     30 
     31 class NodeDefBuilderTest : public ::testing::Test {
     32  protected:
     33   // Specify an OpDef via an OpDefBuilder.
     34   void Op(const OpDefBuilder& op_def_builder) {
     35     OpRegistrationData op_reg_data;
     36     TF_EXPECT_OK(op_def_builder.Finalize(&op_reg_data));
     37     op_def_ = op_reg_data.op_def;
     38   }
     39 
     40   // Resets builder_ with a new NodeDefBuilder using the Op from the last call
     41   // to Op() above.
     42   NodeDefBuilder& Builder() {
     43     EXPECT_FALSE(op_def_.name().empty()) << "Must call Op() before Builder()";
     44     builder_.reset(new NodeDefBuilder("n", &op_def_));
     45     return *builder_;
     46   }
     47 
     48   // Calls Finalize() and verifies it returns success and the result matches
     49   // expectations.
     50   void ExpectSuccess(const NodeDefBuilder& builder,
     51                      DataTypeSlice expected_in_types,
     52                      DataTypeSlice expected_out_types, StringPiece proto) {
     53     NodeDef node_def;
     54     Status status = builder.Finalize(&node_def);
     55     TF_EXPECT_OK(status);
     56     if (!status.ok()) return;
     57     NodeDef expected;
     58     protobuf::TextFormat::ParseFromString(strings::StrCat("name: 'n' ", proto),
     59                                           &expected);
     60     EXPECT_EQ(node_def.DebugString(), expected.DebugString());
     61 
     62     DataTypeVector in_types, out_types;
     63     status =
     64         InOutTypesForNode(node_def, builder.op_def(), &in_types, &out_types);
     65     TF_EXPECT_OK(status);
     66     if (!status.ok()) return;
     67     EXPECT_EQ(DataTypeSliceString(expected_in_types),
     68               DataTypeVectorString(in_types));
     69     EXPECT_EQ(DataTypeSliceString(expected_out_types),
     70               DataTypeVectorString(out_types));
     71 
     72     status = ValidateNodeDef(node_def, op_def_);
     73     TF_EXPECT_OK(status);
     74   }
     75 
     76   // Calls Finalize() and verifies it returns an error.
     77   // Each message must appear as a substring of the error.
     78   void ExpectFailures(const NodeDefBuilder& builder,
     79                       const std::vector<string>& messages) {
     80     NodeDef node_def;
     81     Status status = builder.Finalize(&node_def);
     82     EXPECT_FALSE(status.ok()) << SummarizeNodeDef(node_def);
     83     if (status.ok()) return;
     84     for (const string& message : messages) {
     85       EXPECT_TRUE(StringPiece(status.error_message()).contains(message))
     86           << status << ", " << message;
     87     }
     88   }
     89 
     90   // Calls Finalize() and verifies it returns an error.
     91   // Message must appear as a substring of the error.
     92   void ExpectFailure(const NodeDefBuilder& builder, const string& message) {
     93     ExpectFailures(builder, {message});
     94   }
     95 
     96   // Like ExpectFailure(), except that the error can come from
     97   // ValidateNodeDef().
     98   void ExpectInvalid(const NodeDefBuilder& builder, const string& message) {
     99     NodeDef node_def;
    100     Status status = builder.Finalize(&node_def);
    101     if (status.ok()) {
    102       status = ValidateNodeDef(node_def, op_def_);
    103     }
    104     EXPECT_FALSE(status.ok()) << SummarizeNodeDef(node_def);
    105     if (status.ok()) return;
    106     EXPECT_TRUE(StringPiece(status.error_message()).contains(message))
    107         << "Actual error: " << status.error_message()
    108         << "\nDoes not contain: " << message;
    109   }
    110 
    111   OpDef op_def_;
    112   std::unique_ptr<NodeDefBuilder> builder_;
    113 };
    114 
    115 TEST_F(NodeDefBuilderTest, Simple) {
    116   Op(OpDefBuilder("Simple").Input("a: int32").Output("out: float"));
    117 
    118   ExpectSuccess(Builder().Input("x", 0, DT_INT32), {DT_INT32}, {DT_FLOAT},
    119                 R"proto( op: "Simple" input: "x" )proto");
    120 
    121   // Port != 0
    122   ExpectSuccess(Builder().Input("y", 2, DT_INT32), {DT_INT32}, {DT_FLOAT},
    123                 R"proto( op: "Simple" input: "y:2" )proto");
    124 
    125   // FakeInput
    126   ExpectSuccess(Builder().Input(FakeInput()), {DT_INT32}, {DT_FLOAT}, R"proto(
    127       op: "Simple" input: "a" )proto");
    128 
    129   ExpectSuccess(Builder().Input(FakeInput(DT_INT32)), {DT_INT32}, {DT_FLOAT},
    130                 R"proto( op: "Simple" input: "a" )proto");
    131 
    132   // Ref input
    133   ExpectSuccess(Builder().Input(FakeInput(DT_INT32_REF)), {DT_INT32},
    134                 {DT_FLOAT}, R"proto( op: "Simple" input: "a" )proto");
    135 
    136   // ControlInput
    137   ExpectSuccess(
    138       Builder().ControlInput("x").Input(FakeInput()).ControlInput("y"),
    139       {DT_INT32}, {DT_FLOAT}, R"proto(
    140       op: "Simple" input: ["a", "^x", "^y"] )proto");
    141 
    142   // Device
    143   ExpectSuccess(Builder().Input(FakeInput()).Device("ddd"), {DT_INT32},
    144                 {DT_FLOAT}, R"proto(
    145       op: "Simple" input: "a" device: "ddd" )proto");
    146 
    147   // Extra input
    148   ExpectFailure(Builder().Input("x", 0, DT_INT32).Input("y", 0, DT_INT32),
    149                 "More Input() calls than the 1 input_args while building "
    150                 "NodeDef 'n' using Op<name=Simple; signature=a:int32 -> "
    151                 "out:float>");
    152 
    153   // Missing input
    154   ExpectFailure(Builder(), "0 inputs specified of 1 inputs in Op while");
    155 
    156   {  // Finalize() twice.
    157     NodeDefBuilder& builder = Builder();
    158     // First call to Finalize()
    159     TF_EXPECT_OK(builder.Input(FakeInput()).Finalize(nullptr));
    160     // ExpectSuccess() also calls Finalize().
    161     ExpectSuccess(builder, {DT_INT32}, {DT_FLOAT}, R"proto(
    162         op: "Simple" input: "a" )proto");
    163   }
    164 
    165   {  // Input() after Finalize()
    166     NodeDefBuilder& builder = Builder();
    167     // Calling Finalize() before enough inputs -> error.
    168     ExpectFailure(builder, "0 inputs specified of 1 inputs in Op while");
    169     builder.Input(FakeInput());
    170     // Calling Finalize() with enough inputs -> success
    171     ExpectSuccess(builder, {DT_INT32}, {DT_FLOAT}, R"proto(
    172         op: "Simple" input: "a" )proto");
    173     // Calling Finalize() with too many inputs -> error.
    174     builder.Input(FakeInput(DT_INT32));
    175     ExpectFailure(builder, "More Input() calls than the 1 input_args while");
    176   }
    177 
    178   // Wrong input type
    179   ExpectFailure(Builder().Input("x", 0, DT_FLOAT),
    180                 "Input 'a' passed float expected int32 ");
    181 
    182   ExpectFailure(Builder().Input("x", 0, DT_FLOAT_REF),
    183                 "Input 'a' passed float_ref expected int32 ");
    184 
    185   // List input
    186   ExpectFailure(Builder().Input(FakeInput(3, DT_FLOAT)),
    187                 "List provided to input 'a' when single Tensor expected while");
    188 
    189   ExpectFailure(Builder().Input(FakeInput(3)),
    190                 "List provided to input 'a' when single Tensor expected while");
    191 
    192   // Bad ControlInput
    193   ExpectInvalid(Builder().Input(FakeInput()).ControlInput("z:2"),
    194                 "Control input '^z:2' must not have ':' in NodeDef:");
    195 
    196   // Bad input name
    197   ExpectFailure(Builder().Input("", 0, DT_INT32),
    198                 "Empty input node name while");
    199 
    200   ExpectFailure(Builder().Input("^x", 0, DT_INT32),
    201                 "Non-control input starting with ^: ^x while");
    202 }
    203 
    204 TEST_F(NodeDefBuilderTest, OpDoesNotExist) {
    205   NodeDefBuilder builder("n", "Op Does Not Exist");
    206   builder.Input(FakeInput())
    207       .Input(FakeInput(12))
    208       .ControlInput("y")
    209       .Attr("foo", 12)
    210       .Device("device");
    211   ExpectFailures(builder, {"Op type not registered 'Op Does Not Exist'",
    212                            "while building NodeDef 'n'"});
    213 }
    214 
    215 TEST_F(NodeDefBuilderTest, Polymorphic) {
    216   Op(OpDefBuilder("Polymorphic")
    217          .Input("v: T")
    218          .Output("out: T")
    219          .Attr("T: type"));
    220 
    221   ExpectSuccess(Builder().Input(FakeInput(DT_INT32)), {DT_INT32}, {DT_INT32},
    222                 R"proto(
    223       op: "Polymorphic" input: "a"
    224       attr { key: "T" value { type: DT_INT32 } } )proto");
    225 
    226   ExpectSuccess(Builder().Input(FakeInput(DT_FLOAT)), {DT_FLOAT}, {DT_FLOAT},
    227                 R"proto(
    228       op: "Polymorphic" input: "a"
    229       attr { key: "T" value { type: DT_FLOAT } } )proto");
    230 
    231   // Redundant Attr()
    232   ExpectSuccess(Builder().Input(FakeInput(DT_BOOL)).Attr("T", DT_BOOL),
    233                 {DT_BOOL}, {DT_BOOL}, R"proto(
    234       op: "Polymorphic" input: "a"
    235       attr { key: "T" value { type: DT_BOOL } } )proto");
    236 
    237   // Conficting Attr()
    238   ExpectFailure(Builder().Input(FakeInput(DT_BOOL)).Attr("T", DT_STRING),
    239                 "Inconsistent values for attr 'T' DT_BOOL vs. DT_STRING while");
    240 
    241   ExpectFailure(Builder().Attr("T", DT_STRING).Input(FakeInput(DT_BOOL)),
    242                 "Inconsistent values for attr 'T' DT_STRING vs. DT_BOOL while");
    243 
    244   ExpectFailure(Builder().Attr("T", 12).Input(FakeInput(DT_BOOL)),
    245                 "Inconsistent values for attr 'T' 12 vs. DT_BOOL while");
    246 }
    247 
    248 TEST_F(NodeDefBuilderTest, PolymorphicOut) {
    249   Op(OpDefBuilder("PolymorphicOut").Output("out: T").Attr("T: type"));
    250 
    251   ExpectSuccess(Builder().Attr("T", DT_INT32), {}, {DT_INT32}, R"proto(
    252       op: "PolymorphicOut"
    253       attr { key: "T" value { type: DT_INT32 } } )proto");
    254 
    255   ExpectSuccess(Builder().Attr("T", DT_FLOAT), {}, {DT_FLOAT}, R"proto(
    256       op: "PolymorphicOut"
    257       attr { key: "T" value { type: DT_FLOAT } } )proto");
    258 
    259   // Redundant attr
    260   ExpectSuccess(Builder().Attr("T", DT_FLOAT).Attr("T", DT_FLOAT), {},
    261                 {DT_FLOAT}, R"proto(
    262       op: "PolymorphicOut"
    263       attr { key: "T" value { type: DT_FLOAT } } )proto");
    264 
    265   // Conflicting attr
    266   ExpectFailure(Builder().Attr("T", DT_BOOL).Attr("T", DT_FLOAT),
    267                 "Inconsistent values for attr 'T' DT_BOOL vs. DT_FLOAT while");
    268 
    269   // Missing attr
    270   ExpectInvalid(Builder(), "NodeDef missing attr 'T' from");
    271 
    272   // Attr has the wrong type
    273   ExpectInvalid(
    274       Builder().Attr("T", {DT_INT32, DT_BOOL}),
    275       "AttrValue had value with type 'list(type)' when 'type' expected");
    276 
    277   ExpectInvalid(Builder().Attr("T", 12),
    278                 "AttrValue had value with type 'int' when 'type' expected");
    279 }
    280 
    281 TEST_F(NodeDefBuilderTest, PolymorphicDefaultOut) {
    282   Op(OpDefBuilder("PolymorphicDefaultOut")
    283          .Output("out: T")
    284          .Attr("T: type = DT_STRING"));
    285 
    286   ExpectSuccess(Builder(), {}, {DT_STRING}, R"proto(
    287       op: "PolymorphicDefaultOut"
    288       attr { key: "T" value { type: DT_STRING } } )proto");
    289 
    290   ExpectSuccess(Builder().Attr("T", DT_BOOL), {}, {DT_BOOL}, R"proto(
    291       op: "PolymorphicDefaultOut"
    292       attr { key: "T" value { type: DT_BOOL } } )proto");
    293 }
    294 
    295 TEST_F(NodeDefBuilderTest, Binary) {
    296   Op(OpDefBuilder("Binary").Input("a: T").Input("b: T").Output("out: T").Attr(
    297       "T: type"));
    298 
    299   ExpectSuccess(Builder().Input(FakeInput(DT_INT32)).Input(FakeInput(DT_INT32)),
    300                 {DT_INT32, DT_INT32}, {DT_INT32}, R"proto(
    301       op: "Binary" input: "a" input: "b"
    302       attr { key: "T" value { type: DT_INT32 } } )proto");
    303 
    304   ExpectSuccess(Builder().Input(FakeInput(DT_STRING)).Input(FakeInput()),
    305                 {DT_STRING, DT_STRING}, {DT_STRING}, R"proto(
    306       op: "Binary" input: "a" input: "b"
    307       attr { key: "T" value { type: DT_STRING } } )proto");
    308 
    309   // Type mismatch
    310   ExpectFailure(Builder().Input(FakeInput(DT_BOOL)).Input(FakeInput(DT_STRING)),
    311                 "Inconsistent values for attr 'T' DT_BOOL vs. DT_STRING while");
    312 }
    313 
    314 TEST_F(NodeDefBuilderTest, Restrict) {
    315   Op(OpDefBuilder("Restrict")
    316          .Input("a: T")
    317          .Output("out: T")
    318          .Attr("T: {string, bool}"));
    319   ExpectSuccess(Builder().Input(FakeInput(DT_STRING)), {DT_STRING}, {DT_STRING},
    320                 R"proto(
    321       op: "Restrict" input: "a"
    322       attr { key: "T" value { type: DT_STRING } } )proto");
    323 
    324   ExpectInvalid(Builder().Input(FakeInput(DT_INT32)),
    325                 "Value for attr 'T' of int32 is not in the list of allowed "
    326                 "values: string, bool");
    327 }
    328 
    329 TEST_F(NodeDefBuilderTest, TypeList) {
    330   Op(OpDefBuilder("TypeList").Input("a: T").Attr("T: list(type)"));
    331 
    332   ExpectSuccess(Builder().Input(FakeInput({DT_STRING, DT_INT32})),
    333                 {DT_STRING, DT_INT32}, {}, R"proto(
    334       op: "TypeList" input: ["a", "a:1"]
    335       attr { key: "T" value { list { type: [DT_STRING, DT_INT32] } } }
    336       )proto");
    337 
    338   ExpectSuccess(Builder().Input(FakeInput(3, DT_BOOL)),
    339                 {DT_BOOL, DT_BOOL, DT_BOOL}, {}, R"proto(
    340       op: "TypeList" input: ["a", "a:1", "a:2"]
    341       attr { key: "T" value { list { type: [DT_BOOL, DT_BOOL, DT_BOOL] } } }
    342       )proto");
    343 
    344   ExpectInvalid(Builder().Input(FakeInput(0)),
    345                 "Length for attr 'T' of 0 must be at least minimum 1");
    346 
    347   ExpectInvalid(Builder().Input(FakeInput({})),
    348                 "Length for attr 'T' of 0 must be at least minimum 1");
    349 
    350   ExpectInvalid(Builder().Input(FakeInput(DT_BOOL)),
    351                 "Single tensor passed to 'a', expected list while");
    352 
    353   ExpectFailures(Builder().Input(FakeInput()),
    354                  {"2 errors while building NodeDef",
    355                   "Could not infer list of types for input 'a': "
    356                   "No attr named 'T' in NodeDef:",
    357                   "0 inputs specified of 1 inputs in Op"});
    358 }
    359 
    360 TEST_F(NodeDefBuilderTest, TypeListNoMin) {
    361   Op(OpDefBuilder("TypeListNoMin").Input("a: T").Attr("T: list(type) >= 0"));
    362 
    363   ExpectSuccess(Builder().Input(FakeInput(0)), {}, {}, R"proto(
    364       op: "TypeListNoMin" attr { key: "T" value { list { } } } )proto");
    365 
    366   ExpectSuccess(Builder().Input(FakeInput(DataTypeVector())), {}, {}, R"proto(
    367       op: "TypeListNoMin" attr { key: "T" value { list { } } } )proto");
    368 
    369   ExpectSuccess(Builder().Input(FakeInput({})), {}, {}, R"proto(
    370       op: "TypeListNoMin" attr { key: "T" value { list { } } } )proto");
    371 
    372   ExpectSuccess(Builder().Input(FakeInput({DT_BOOL})), {DT_BOOL}, {}, R"proto(
    373       op: "TypeListNoMin" input: "a"
    374       attr { key: "T" value { list { type: DT_BOOL } } } )proto");
    375 }
    376 
    377 TEST_F(NodeDefBuilderTest, TypeListTwice) {
    378   Op(OpDefBuilder("TypeListTwice")
    379          .Input("a: T")
    380          .Input("b: T")
    381          .Attr("T: list(type) >= 0"));
    382 
    383   ExpectSuccess(Builder()
    384                     .Input(FakeInput({DT_INT32, DT_BOOL}))
    385                     .Input(FakeInput({DT_INT32, DT_BOOL})),
    386                 {DT_INT32, DT_BOOL, DT_INT32, DT_BOOL}, {}, R"proto(
    387       op: "TypeListTwice" input: ["a", "a:1", "b", "b:1"]
    388       attr { key: "T" value { list { type: [DT_INT32, DT_BOOL] } } } )proto");
    389 
    390   ExpectSuccess(
    391       Builder().Input(FakeInput({DT_INT32, DT_BOOL})).Input(FakeInput()),
    392       {DT_INT32, DT_BOOL, DT_INT32, DT_BOOL}, {}, R"proto(
    393       op: "TypeListTwice" input: ["a", "a:1", "b", "b:1"]
    394       attr { key: "T" value { list { type: [DT_INT32, DT_BOOL] } } } )proto");
    395 
    396   ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput(0)), {}, {},
    397                 R"proto(
    398       op: "TypeListTwice" attr { key: "T" value { list { } } } )proto");
    399 
    400   ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput()), {}, {},
    401                 R"proto(
    402       op: "TypeListTwice" attr { key: "T" value { list { } } } )proto");
    403 
    404   ExpectFailure(Builder()
    405                     .Input(FakeInput({DT_INT32, DT_BOOL}))
    406                     .Input(FakeInput({DT_INT32, DT_STRING})),
    407                 "Inconsistent values for attr 'T' [DT_INT32, DT_BOOL] vs. "
    408                 "[DT_INT32, DT_STRING] while");
    409 }
    410 
    411 TEST_F(NodeDefBuilderTest, OutTypeList) {
    412   Op(OpDefBuilder("OutTypeList").Output("out: T").Attr("T: list(type) >= 0"));
    413 
    414   ExpectSuccess(Builder().Attr("T", {DT_FLOAT}), {}, {DT_FLOAT}, R"proto(
    415       op: "OutTypeList"
    416       attr { key: "T" value { list { type: DT_FLOAT } } } )proto");
    417 
    418   ExpectSuccess(Builder().Attr("T", {DT_STRING, DT_BOOL}), {},
    419                 {DT_STRING, DT_BOOL}, R"proto(
    420       op: "OutTypeList"
    421       attr { key: "T" value { list { type: [DT_STRING, DT_BOOL] } } } )proto");
    422 
    423   ExpectSuccess(Builder().Attr("T", DataTypeVector()), {}, {}, R"proto(
    424       op: "OutTypeList"
    425       attr { key: "T" value { list { } } } )proto");
    426 
    427   ExpectInvalid(
    428       Builder().Attr("T", DT_FLOAT),
    429       "AttrValue had value with type 'type' when 'list(type)' expected");
    430 }
    431 
    432 TEST_F(NodeDefBuilderTest, TypeListRestrict) {
    433   Op(OpDefBuilder("TypeListRestrict")
    434          .Input("a: T")
    435          .Attr("T: list({string, bool}) >= 0"));
    436 
    437   ExpectSuccess(Builder().Input(FakeInput({DT_STRING, DT_BOOL})),
    438                 {DT_STRING, DT_BOOL}, {}, R"proto(
    439       op: "TypeListRestrict" input: ["a", "a:1"]
    440       attr { key: "T" value { list { type: [DT_STRING, DT_BOOL] } } } )proto");
    441 
    442   ExpectInvalid(Builder().Input(FakeInput({DT_STRING, DT_INT32})),
    443                 "Value for attr 'T' of int32 is not in the list of allowed "
    444                 "values: string, bool");
    445 }
    446 
    447 TEST_F(NodeDefBuilderTest, OutTypeListRestrict) {
    448   Op(OpDefBuilder("OutTypeListRestrict")
    449          .Output("out: t")
    450          .Attr("t: list({string, bool}) >= 0"));
    451 
    452   ExpectSuccess(Builder().Attr("t", {DT_BOOL, DT_STRING}), {},
    453                 {DT_BOOL, DT_STRING}, R"proto(
    454       op: "OutTypeListRestrict"
    455       attr { key: "t" value { list { type: [DT_BOOL, DT_STRING] } } } )proto");
    456 
    457   ExpectInvalid(Builder().Attr("t", {DT_STRING, DT_INT32}),
    458                 "Value for attr 't' of int32 is not in the list of allowed "
    459                 "values: string, bool");
    460 }
    461 
    462 TEST_F(NodeDefBuilderTest, Attr) {
    463   Op(OpDefBuilder("Attr").Attr("a: int"));
    464 
    465   ExpectSuccess(Builder().Attr("a", 12), {}, {}, R"proto(
    466       op: "Attr" attr { key: "a" value { i: 12 } } )proto");
    467 
    468   // Attr has wrong type
    469   ExpectInvalid(Builder().Attr("a", "bad"),
    470                 "AttrValue had value with type 'string' when 'int' expected");
    471 
    472   ExpectInvalid(
    473       Builder().Attr("a", {12}),
    474       "AttrValue had value with type 'list(int)' when 'int' expected");
    475 
    476   // Missing attr
    477   ExpectInvalid(Builder(), "NodeDef missing attr 'a' from Op<");
    478 
    479   // Wrong attr
    480   ExpectInvalid(Builder().Attr("b", 12),
    481                 "NodeDef mentions attr 'b' not in Op<");
    482 
    483   // Extra attr
    484   ExpectInvalid(Builder().Attr("a", 12).Attr("extra", 12),
    485                 "NodeDef mentions attr 'extra' not in Op<");
    486 }
    487 
    488 TEST_F(NodeDefBuilderTest, AttrFloat) {
    489   Op(OpDefBuilder("AttrFloat").Attr("a: float"));
    490 
    491   ExpectSuccess(Builder().Attr("a", 1.2f /* float */), {}, {}, R"proto(
    492       op: "AttrFloat" attr { key: "a" value { f: 1.2 } }
    493       )proto");
    494 
    495   ExpectSuccess(Builder().Attr("a", 1.2 /* double */), {}, {}, R"proto(
    496       op: "AttrFloat" attr { key: "a" value { f: 1.2 } }
    497       )proto");
    498 
    499   // Won't automatically cast int to float
    500   ExpectInvalid(Builder().Attr("a", 12),
    501                 "AttrValue had value with type 'int' when 'float' expected");
    502 }
    503 
    504 TEST_F(NodeDefBuilderTest, AttrBoolList) {
    505   Op(OpDefBuilder("AttrBoolList").Attr("a: list(bool)"));
    506 
    507   ExpectSuccess(Builder().Attr("a", {true, false, true}), {}, {}, R"proto(
    508       op: "AttrBoolList"
    509       attr { key: "a" value { list { b: [true, false, true] } } }
    510       )proto");
    511 
    512   ExpectSuccess(Builder().Attr("a", std::vector<bool>()), {}, {}, R"proto(
    513       op: "AttrBoolList" attr { key: "a" value { list { } } }
    514       )proto");
    515 
    516   // Won't cast int -> bool.
    517   ExpectInvalid(Builder().Attr("a", {0}),
    518                 "AttrValue had value with type 'list(int)' when 'list(bool)' "
    519                 "expected");
    520 }
    521 
    522 TEST_F(NodeDefBuilderTest, AttrMin) {
    523   Op(OpDefBuilder("AttrMin").Attr("a: int >= 5"));
    524 
    525   ExpectSuccess(Builder().Attr("a", 12), {}, {}, R"proto(
    526       op: "AttrMin" attr { key: "a" value { i: 12 } } )proto");
    527 
    528   ExpectInvalid(Builder().Attr("a", 2),
    529                 "Value for attr 'a' of 2 must be at least minimum 5");
    530 }
    531 
    532 TEST_F(NodeDefBuilderTest, AttrListMin) {
    533   Op(OpDefBuilder("AttrListMin").Attr("a: list(int) >= 2"));
    534 
    535   ExpectSuccess(Builder().Attr("a", {1, 2}), {}, {}, R"proto(
    536       op: "AttrListMin"
    537       attr { key: "a" value { list { i: [1, 2] } } } )proto");
    538 
    539   ExpectInvalid(Builder().Attr("a", {17}),
    540                 "Length for attr 'a' of 1 must be at least minimum 2");
    541 }
    542 
    543 TEST_F(NodeDefBuilderTest, AttrEnum) {
    544   Op(OpDefBuilder("AttrEnum").Attr("a: {'apples', 'oranges'}"));
    545 
    546   ExpectSuccess(Builder().Attr("a", "oranges"), {}, {}, R"proto(
    547       op: "AttrEnum"
    548       attr { key: "a" value { s: "oranges" } } )proto");
    549 
    550   ExpectInvalid(
    551       Builder().Attr("a", "invalid"),
    552       "Value for attr 'a' of \"invalid\" is not in the list of allowed values: "
    553       "\"apples\", \"oranges\"");
    554 }
    555 
    556 TEST_F(NodeDefBuilderTest, AttrEnumList) {
    557   Op(OpDefBuilder("AttrEnumList").Attr("a: list({'apples', 'oranges'})"));
    558 
    559   ExpectSuccess(Builder().Attr("a", {"oranges", "apples"}), {}, {}, R"proto(
    560       op: "AttrEnumList"
    561       attr { key: "a" value { list { s: ["oranges", "apples"] } } } )proto");
    562 
    563   ExpectInvalid(
    564       Builder().Attr("a", {"apples", "invalid", "oranges"}),
    565       "Value for attr 'a' of \"invalid\" is not in the list of allowed values: "
    566       "\"apples\", \"oranges\"");
    567 }
    568 
    569 TEST_F(NodeDefBuilderTest, AttrShape) {
    570   Op(OpDefBuilder("AttrShape").Attr("a: shape"));
    571 
    572   ExpectSuccess(Builder().Attr("a", TensorShape({5})), {}, {}, R"proto(
    573       op: "AttrShape"
    574       attr { key: "a" value { shape { dim { size: 5 } } } } )proto");
    575 
    576   ExpectSuccess(Builder().Attr("a", TensorShape({4, 3, 2})), {}, {}, R"proto(
    577       op: "AttrShape"
    578       attr { key: "a" value { shape {
    579         dim { size: 4 } dim { size: 3 } dim { size: 2 } } } } )proto");
    580 
    581   ExpectSuccess(Builder().Attr("a", TensorShape({3, 2})), {}, {},
    582                 R"proto(
    583       op: "AttrShape"
    584       attr { key: "a" value { shape {
    585         dim { size: 3 } dim { size: 2 } } } } )proto");
    586 
    587   ExpectSuccess(Builder().Attr("a", TensorShape()), {}, {}, R"proto(
    588       op: "AttrShape"
    589       attr { key: "a" value { shape { } } } )proto");
    590 }
    591 
    592 TEST_F(NodeDefBuilderTest, AttrDefault) {
    593   Op(OpDefBuilder("AttrDefault").Attr("a: string = 'banana'"));
    594 
    595   ExpectSuccess(Builder(), {}, {}, R"proto(
    596       op: "AttrDefault"
    597       attr { key: "a" value { s: "banana" } } )proto");
    598 
    599   ExpectSuccess(Builder().Attr("a", "kiwi"), {}, {}, R"proto(
    600       op: "AttrDefault"
    601       attr { key: "a" value { s: "kiwi" } } )proto");
    602 }
    603 
    604 TEST_F(NodeDefBuilderTest, AttrManyDefault) {
    605   Op(OpDefBuilder("AttrManyDefault")
    606          .Attr("a: string = 'banana'")
    607          .Attr("b: string = 'kiwi'"));
    608 
    609   ExpectSuccess(Builder(), {}, {}, R"proto(
    610       op: "AttrManyDefault"
    611       attr { key: "a" value { s: "banana" } }
    612       attr { key: "b" value { s: "kiwi" } } )proto");
    613 
    614   Op(OpDefBuilder("AttrManyDefaultWithMandatory")
    615          .Attr("a: string = 'banana'")
    616          .Attr("b: string = 'kiwi'")
    617          .Attr("c: string"));
    618 
    619   ExpectSuccess(Builder().Attr("c", "strawberry"), {}, {}, R"proto(
    620       op: "AttrManyDefaultWithMandatory"
    621       attr { key: "c" value { s: "strawberry" } }
    622       attr { key: "a" value { s: "banana" } }
    623       attr { key: "b" value { s: "kiwi" } } )proto");
    624 
    625   Op(OpDefBuilder("AttrManyDefaultAndInferred")
    626          .Input("input: T")
    627          .Attr("T: {float, double}")
    628          .Attr("a: string")
    629          .Attr("b: list(string) >= 1")
    630          .Attr("c: bool = true")
    631          .Attr("d: float = 0.3")
    632          .Attr("e: string")
    633          .Attr("f: float = 0.25"));
    634 
    635   ExpectSuccess(Builder()
    636                     .Input(FakeInput(DT_FLOAT))
    637                     .Attr("a", "foo")
    638                     .Attr("e", "foo")
    639                     .Attr("b", std::vector<string>({"bar", "baz"}))
    640                     .Attr("f", 1.0f),
    641                 {DT_FLOAT}, {}, R"proto(
    642       op: "AttrManyDefaultAndInferred"
    643       input: "a"
    644       attr { key: "T" value { type: DT_FLOAT } }
    645       attr { key: "a" value { s: "foo" } }
    646       attr { key: "e" value { s: "foo" } }
    647       attr { key: "b" value { list { s: "bar" s: "baz" } } }
    648       attr { key: "f" value { f: 1.0 } }
    649       attr { key: "c" value { b: true } }
    650       attr { key: "d" value { f: 0.3 } } )proto");
    651 }
    652 
    653 TEST_F(NodeDefBuilderTest, AttrListDefault) {
    654   Op(OpDefBuilder("AttrListDefault").Attr("a: list(int) = [5, 15]"));
    655 
    656   ExpectSuccess(Builder(), {}, {}, R"proto(
    657       op: "AttrListDefault"
    658       attr { key: "a" value { list { i: [5, 15] } } } )proto");
    659 
    660   ExpectSuccess(Builder().Attr("a", {3}), {}, {}, R"proto(
    661       op: "AttrListDefault"
    662       attr { key: "a" value { list { i: 3 } } } )proto");
    663 
    664   ExpectSuccess(Builder().Attr("a", std::vector<int>()), {}, {}, R"proto(
    665       op: "AttrListDefault"
    666       attr { key: "a" value { list { } } } )proto");
    667 }
    668 
    669 TEST_F(NodeDefBuilderTest, AttrEmptyListDefault) {
    670   Op(OpDefBuilder("AttrEmptyListDefault").Attr("a: list(int) = []"));
    671 
    672   ExpectSuccess(Builder(), {}, {}, R"proto(
    673       op: "AttrEmptyListDefault"
    674       attr { key: "a" value { list { } } } )proto");
    675 
    676   ExpectSuccess(Builder().Attr("a", {3}), {}, {}, R"proto(
    677       op: "AttrEmptyListDefault"
    678       attr { key: "a" value { list { i: 3 } } } )proto");
    679 
    680   ExpectSuccess(Builder().Attr("a", std::vector<int>()), {}, {}, R"proto(
    681       op: "AttrEmptyListDefault"
    682       attr { key: "a" value { list { } } } )proto");
    683 }
    684 
    685 TEST_F(NodeDefBuilderTest, NIntsIn) {
    686   Op(OpDefBuilder("NIntsIn").Input("a: N*int32").Attr("N: int >= 2"));
    687 
    688   ExpectSuccess(Builder().Input(FakeInput(2)), {DT_INT32, DT_INT32}, {},
    689                 R"proto(
    690       op: "NIntsIn" input: ["a", "a:1"]
    691       attr { key: "N" value { i: 2 } } )proto");
    692 
    693   ExpectSuccess(Builder().Input(FakeInput(5, DT_INT32)),
    694                 {DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32}, {}, R"proto(
    695       op: "NIntsIn"
    696       input: ["a", "a:1", "a:2", "a:3", "a:4"]
    697       attr { key: "N" value { i: 5 } } )proto");
    698 
    699   ExpectFailures(Builder().Input(FakeInput(2, DT_STRING)),
    700                  {"2 errors while building NodeDef",
    701                   "Input 'a' passed string expected int32"});
    702 
    703   ExpectInvalid(Builder().Input(FakeInput(1)),
    704                 "Value for attr 'N' of 1 must be at least minimum 2");
    705 
    706   ExpectFailures(
    707       Builder().Input(FakeInput(DT_INT32)),
    708       {"2 errors while building NodeDef",
    709        "Could not infer length of input 'a': No attr named 'N' in NodeDef:",
    710        "0 inputs specified of 1 inputs in Op"});
    711 
    712   ExpectFailure(Builder().Input({{"in", 0, DT_INT32}, {"in", 1, DT_STRING}}),
    713                 "Input 'a' passed string expected int32 while");
    714 
    715   ExpectFailures(
    716       Builder().Input(FakeInput()),
    717       {"2 errors while building NodeDef",
    718        "Could not infer length of input 'a': No attr named 'N' in NodeDef:",
    719        "0 inputs specified of 1 inputs in Op"});
    720 }
    721 
    722 TEST_F(NodeDefBuilderTest, NPolymorphicIn) {
    723   Op(OpDefBuilder("NPolymorphicIn")
    724          .Input("a: N*T")
    725          .Attr("T: type")
    726          .Attr("N: int >= 2"));
    727 
    728   ExpectSuccess(Builder().Input(FakeInput(2, DT_INT32)), {DT_INT32, DT_INT32},
    729                 {}, R"proto(
    730       op: "NPolymorphicIn" input: ["a", "a:1"]
    731       attr { key: "N" value { i: 2 } }
    732       attr { key: "T" value { type: DT_INT32 } } )proto");
    733 
    734   ExpectSuccess(Builder().Input(FakeInput(3, DT_STRING)),
    735                 {DT_STRING, DT_STRING, DT_STRING}, {}, R"proto(
    736       op: "NPolymorphicIn"
    737       input: ["a", "a:1", "a:2"]
    738       attr { key: "N" value { i: 3 } }
    739       attr { key: "T" value { type: DT_STRING } } )proto");
    740 
    741   ExpectFailures(
    742       Builder().Input(FakeInput(2)),
    743       {"2 errors while building NodeDef",
    744        "Could not infer type for input 'a': No attr named 'T' in NodeDef:",
    745        "0 inputs specified of 1 inputs in Op"});
    746 
    747   ExpectFailure(Builder().Input(FakeInput({DT_INT32, DT_STRING})),
    748                 "Input 'a' passed string expected int32 while");
    749 
    750   ExpectFailure(Builder().Input({{"in", 0, DT_INT32}, {"in", 1, DT_STRING}}),
    751                 "Input 'a' passed string expected int32 while");
    752 
    753   ExpectInvalid(Builder().Input(FakeInput(1, DT_INT32)),
    754                 "Value for attr 'N' of 1 must be at least minimum 2");
    755 
    756   ExpectFailure(Builder().Input("in", 0, DT_INT32),
    757                 "Single tensor passed to 'a', expected list while");
    758 }
    759 
    760 TEST_F(NodeDefBuilderTest, NPolymorphicRestrictIn) {
    761   Op(OpDefBuilder("NPolymorphicRestrictIn")
    762          .Input("a: N*T")
    763          .Attr("T: {string, bool}")
    764          .Attr("N: int >= 2"));
    765 
    766   ExpectSuccess(Builder().Input(FakeInput(2, DT_BOOL)), {DT_BOOL, DT_BOOL}, {},
    767                 R"proto(
    768       op: "NPolymorphicRestrictIn" input: ["a", "a:1"]
    769       attr { key: "N" value { i: 2 } }
    770       attr { key: "T" value { type: DT_BOOL } } )proto");
    771 
    772   ExpectSuccess(Builder().Input(FakeInput(3, DT_STRING)),
    773                 {DT_STRING, DT_STRING, DT_STRING}, {}, R"proto(
    774       op: "NPolymorphicRestrictIn"
    775       input: ["a", "a:1", "a:2"]
    776       attr { key: "N" value { i: 3 } }
    777       attr { key: "T" value { type: DT_STRING } } )proto");
    778 
    779   ExpectInvalid(Builder().Input(FakeInput(2, DT_INT32)),
    780                 "Value for attr 'T' of int32 is not in the list of allowed "
    781                 "values: string, bool");
    782 }
    783 
    784 TEST_F(NodeDefBuilderTest, NInTwice) {
    785   Op(OpDefBuilder("NInTwice")
    786          .Input("a: N*int32")
    787          .Input("b: N*string")
    788          .Attr("N: int >= 0"));
    789 
    790   ExpectSuccess(Builder().Input(FakeInput(2)).Input(FakeInput(2)),
    791                 {DT_INT32, DT_INT32, DT_STRING, DT_STRING}, {}, R"proto(
    792       op: "NInTwice"
    793       input: ["a", "a:1", "b", "b:1"]
    794       attr { key: "N" value { i: 2 } } )proto");
    795 
    796   ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput()), {}, {},
    797                 R"proto(
    798       op: "NInTwice" attr { key: "N" value { i: 0 } } )proto");
    799 
    800   ExpectFailure(Builder().Input(FakeInput(3)).Input(FakeInput(1)),
    801                 "Inconsistent values for attr 'N' 3 vs. 1 while");
    802 }
    803 
    804 TEST_F(NodeDefBuilderTest, NInPolymorphicTwice) {
    805   Op(OpDefBuilder("NInPolymorphicTwice")
    806          .Input("a: N*T")
    807          .Input("b: N*T")
    808          .Attr("T: type")
    809          .Attr("N: int >= 0"));
    810 
    811   ExpectSuccess(Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput()),
    812                 {DT_INT32, DT_INT32, DT_INT32, DT_INT32}, {}, R"proto(
    813       op: "NInPolymorphicTwice"
    814       input: ["a", "a:1", "b", "b:1"]
    815       attr { key: "N" value { i: 2 } }
    816       attr { key: "T" value { type: DT_INT32 } } )proto");
    817 
    818   ExpectFailure(
    819       Builder().Input(FakeInput(3, DT_INT32)).Input(FakeInput(1, DT_INT32)),
    820       "Inconsistent values for attr 'N' 3 vs. 1 while");
    821 
    822   ExpectFailure(Builder().Input(FakeInput(3, DT_INT32)).Input(FakeInput(1)),
    823                 "Inconsistent values for attr 'N' 3 vs. 1 while");
    824 
    825   ExpectFailure(
    826       Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(2, DT_STRING)),
    827       "Inconsistent values for attr 'T' DT_INT32 vs. DT_STRING while");
    828 
    829   ExpectFailure(
    830       Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(DT_STRING)),
    831       "Inconsistent values for attr 'T' DT_INT32 vs. DT_STRING while");
    832 }
    833 
    834 TEST_F(NodeDefBuilderTest, NInTwoTypeVariables) {
    835   Op(OpDefBuilder("NInTwoTypeVariables")
    836          .Input("a: N*S")
    837          .Input("b: N*T")
    838          .Attr("S: type")
    839          .Attr("T: type")
    840          .Attr("N: int >= 0"));
    841 
    842   ExpectSuccess(
    843       Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(2, DT_BOOL)),
    844       {DT_INT32, DT_INT32, DT_BOOL, DT_BOOL}, {}, R"proto(
    845       op: "NInTwoTypeVariables"
    846       input: ["a", "a:1", "b", "b:1"]
    847       attr { key: "N" value { i: 2 } }
    848       attr { key: "S" value { type: DT_INT32 } }
    849       attr { key: "T" value { type: DT_BOOL } } )proto");
    850 
    851   ExpectSuccess(
    852       Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(DT_BOOL)),
    853       {DT_INT32, DT_INT32, DT_BOOL, DT_BOOL}, {}, R"proto(
    854       op: "NInTwoTypeVariables"
    855       input: ["a", "a:1", "b", "b:1"]
    856       attr { key: "N" value { i: 2 } }
    857       attr { key: "S" value { type: DT_INT32 } }
    858       attr { key: "T" value { type: DT_BOOL } } )proto");
    859 
    860   ExpectFailure(
    861       Builder().Input(FakeInput(3, DT_INT32)).Input(FakeInput(1, DT_STRING)),
    862       "Inconsistent values for attr 'N' 3 vs. 1 while");
    863 }
    864 
    865 TEST_F(NodeDefBuilderTest, InPolymorphicTwice) {
    866   Op(OpDefBuilder("InPolymorphicTwice")
    867          .Input("a: N*T")
    868          .Input("b: M*T")
    869          .Attr("T: type")
    870          .Attr("N: int >= 0")
    871          .Attr("M: int >= 0"));
    872 
    873   ExpectSuccess(
    874       Builder().Input(FakeInput(1, DT_INT32)).Input(FakeInput(3, DT_INT32)),
    875       {DT_INT32, DT_INT32, DT_INT32, DT_INT32}, {}, R"proto(
    876       op: "InPolymorphicTwice"
    877       input: ["a", "b", "b:1", "b:2"]
    878       attr { key: "N" value { i: 1 } }
    879       attr { key: "T" value { type: DT_INT32 } }
    880       attr { key: "M" value { i: 3 } } )proto");
    881 
    882   ExpectSuccess(Builder().Input(FakeInput(1, DT_BOOL)).Input(FakeInput(0)),
    883                 {DT_BOOL}, {}, R"proto(
    884       op: "InPolymorphicTwice" input: "a"
    885       attr { key: "N" value { i: 1 } }
    886       attr { key: "T" value { type: DT_BOOL } }
    887       attr { key: "M" value { i: 0 } } )proto");
    888 
    889   ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput(1, DT_BOOL)),
    890                 {DT_BOOL}, {}, R"proto(
    891       op: "InPolymorphicTwice" input: "b"
    892       attr { key: "N" value { i: 0 } }
    893       attr { key: "M" value { i: 1 } }
    894       attr { key: "T" value { type: DT_BOOL } } )proto");
    895 
    896   ExpectFailure(
    897       Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(2, DT_STRING)),
    898       "Inconsistent values for attr 'T' DT_INT32 vs. DT_STRING while");
    899 }
    900 
    901 TEST_F(NodeDefBuilderTest, NIntsOut) {
    902   Op(OpDefBuilder("NIntsOut").Output("a: N*int32").Attr("N: int >= 2"));
    903 
    904   ExpectSuccess(Builder().Attr("N", 2), {}, {DT_INT32, DT_INT32}, R"proto(
    905       op: "NIntsOut"
    906       attr { key: "N" value { i: 2 } } )proto");
    907 
    908   ExpectSuccess(Builder().Attr("N", 3), {}, {DT_INT32, DT_INT32, DT_INT32},
    909                 R"proto(
    910       op: "NIntsOut"
    911       attr { key: "N" value { i: 3 } } )proto");
    912 
    913   ExpectInvalid(Builder().Attr("N", 1),
    914                 "Value for attr 'N' of 1 must be at least minimum 2");
    915 
    916   ExpectInvalid(
    917       Builder().Attr("N", {3}),
    918       "AttrValue had value with type 'list(int)' when 'int' expected");
    919 
    920   ExpectInvalid(Builder(), "NodeDef missing attr 'N' from");
    921 }
    922 
    923 TEST_F(NodeDefBuilderTest, NIntsOutDefault) {
    924   Op(OpDefBuilder("NIntsOutDefault")
    925          .Output("a: N*int32")
    926          .Attr("N: int >= 2 = 3"));
    927 
    928   ExpectSuccess(Builder(), {}, {DT_INT32, DT_INT32, DT_INT32}, R"proto(
    929       op: "NIntsOutDefault"
    930       attr { key: "N" value { i: 3 } } )proto");
    931 
    932   ExpectSuccess(Builder().Attr("N", 2), {}, {DT_INT32, DT_INT32}, R"proto(
    933       op: "NIntsOutDefault"
    934       attr { key: "N" value { i: 2 } } )proto");
    935 }
    936 
    937 TEST_F(NodeDefBuilderTest, NPolymorphicOut) {
    938   Op(OpDefBuilder("NPolymorphicOut")
    939          .Output("a: N*T")
    940          .Attr("T: type")
    941          .Attr("N: int >= 2"));
    942 
    943   ExpectSuccess(Builder().Attr("T", DT_INT32).Attr("N", 2), {},
    944                 {DT_INT32, DT_INT32}, R"proto(
    945       op: "NPolymorphicOut"
    946       attr { key: "T" value { type: DT_INT32 } }
    947       attr { key: "N" value { i: 2 } } )proto");
    948 
    949   ExpectSuccess(Builder().Attr("N", 3).Attr("T", DT_STRING), {},
    950                 {DT_STRING, DT_STRING, DT_STRING}, R"proto(
    951       op: "NPolymorphicOut"
    952       attr { key: "N" value { i: 3 } }
    953       attr { key: "T" value { type: DT_STRING } } )proto");
    954 
    955   ExpectInvalid(Builder().Attr("N", 1).Attr("T", DT_STRING),
    956                 "Value for attr 'N' of 1 must be at least minimum 2");
    957 
    958   ExpectInvalid(
    959       Builder().Attr("N", 3).Attr("T", {DT_STRING}),
    960       "AttrValue had value with type 'list(type)' when 'type' expected");
    961 }
    962 
    963 TEST_F(NodeDefBuilderTest, NPolymorphicOutDefault) {
    964   Op(OpDefBuilder("NPolymorphicOutDefault")
    965          .Output("a: N*T")
    966          .Attr("T: type = DT_BOOL")
    967          .Attr("N: int >= 2 = 2"));
    968 
    969   ExpectSuccess(Builder(), {}, {DT_BOOL, DT_BOOL}, R"proto(
    970       op: "NPolymorphicOutDefault"
    971       attr { key: "T" value { type: DT_BOOL } }
    972       attr { key: "N" value { i: 2 } } )proto");
    973 
    974   ExpectSuccess(Builder().Attr("N", 3), {}, {DT_BOOL, DT_BOOL, DT_BOOL},
    975                 R"proto(
    976       op: "NPolymorphicOutDefault"
    977       attr { key: "N" value { i: 3 } }
    978       attr { key: "T" value { type: DT_BOOL } } )proto");
    979 
    980   ExpectSuccess(Builder().Attr("T", DT_INT32), {}, {DT_INT32, DT_INT32},
    981                 R"proto(
    982       op: "NPolymorphicOutDefault"
    983       attr { key: "T" value { type: DT_INT32 } }
    984       attr { key: "N" value { i: 2 } } )proto");
    985 
    986   ExpectSuccess(Builder().Attr("N", 3).Attr("T", DT_INT32), {},
    987                 {DT_INT32, DT_INT32, DT_INT32}, R"proto(
    988       op: "NPolymorphicOutDefault"
    989       attr { key: "N" value { i: 3 } }
    990       attr { key: "T" value { type: DT_INT32 } } )proto");
    991 }
    992 
    993 TEST_F(NodeDefBuilderTest, NPolymorphicRestrictOut) {
    994   Op(OpDefBuilder("NPolymorphicRestrictOut")
    995          .Output("a: N*T")
    996          .Attr("T: {string, bool}")
    997          .Attr("N: int >= 2"));
    998 
    999   ExpectSuccess(Builder().Attr("N", 3).Attr("T", DT_BOOL), {},
   1000                 {DT_BOOL, DT_BOOL, DT_BOOL}, R"proto(
   1001       op: "NPolymorphicRestrictOut"
   1002       attr { key: "N" value { i: 3 } }
   1003       attr { key: "T" value { type: DT_BOOL } } )proto");
   1004 
   1005   ExpectInvalid(Builder().Attr("N", 3).Attr("T", DT_INT32),
   1006                 "Value for attr 'T' of int32 is not in the list of allowed "
   1007                 "values: string, bool");
   1008 }
   1009 
   1010 TEST_F(NodeDefBuilderTest, RefIn) {
   1011   Op(OpDefBuilder("RefIn").Input("a: Ref(int32)"));
   1012 
   1013   ExpectSuccess(Builder().Input(FakeInput(DT_INT32_REF)), {DT_INT32_REF}, {},
   1014                 R"proto(
   1015       op: "RefIn" input: "a" )proto");
   1016 
   1017   ExpectFailure(Builder().Input(FakeInput(DT_BOOL_REF)),
   1018                 "Input 'a' passed bool_ref expected int32_ref while");
   1019 
   1020   ExpectFailure(Builder().Input(FakeInput(DT_INT32)),
   1021                 "Input 'a' passed int32 expected int32_ref while");
   1022 }
   1023 
   1024 TEST_F(NodeDefBuilderTest, PolymorphicRefIn) {
   1025   Op(OpDefBuilder("PolymorphicRefIn").Input("a: Ref(T)").Attr("T: type"));
   1026 
   1027   ExpectSuccess(Builder().Input(FakeInput(DT_BOOL_REF)), {DT_BOOL_REF}, {},
   1028                 R"proto(
   1029       op: "PolymorphicRefIn" input: "a"
   1030       attr { key: "T" value { type: DT_BOOL } } )proto");
   1031 
   1032   ExpectFailure(Builder().Input(FakeInput(DT_BOOL)),
   1033                 "Input 'a' passed bool expected ref type while");
   1034 }
   1035 
   1036 TEST_F(NodeDefBuilderTest, RefOut) {
   1037   Op(OpDefBuilder("RefOut").Output("a: Ref(string)"));
   1038 
   1039   ExpectSuccess(Builder(), {}, {DT_STRING_REF}, R"proto(
   1040       op: "RefOut" )proto");
   1041 }
   1042 
   1043 TEST_F(NodeDefBuilderTest, PolymorphicRefOut) {
   1044   Op(OpDefBuilder("PolymorphicRefOut").Output("a: Ref(t)").Attr("t: type"));
   1045 
   1046   ExpectSuccess(Builder().Attr("t", DT_BOOL), {}, {DT_BOOL_REF}, R"proto(
   1047       op: "PolymorphicRefOut"
   1048       attr { key: "t" value { type: DT_BOOL } } )proto");
   1049 }
   1050 
   1051 TEST_F(NodeDefBuilderTest, SpecifyDevice) {
   1052   Op(OpDefBuilder("SpecifyDevice"));
   1053 
   1054   ExpectSuccess(Builder().Device("ADevice"), {}, {}, R"proto(
   1055       op: "SpecifyDevice" device: "ADevice" )proto");
   1056 }
   1057 
   1058 }  // namespace
   1059 }  // namespace tensorflow
   1060