Home | History | Annotate | Download | only in graph
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/graph/validate.h"
     17 
     18 #include <string>
     19 
     20 #include "tensorflow/core/framework/graph.pb.h"
     21 #include "tensorflow/core/framework/graph_def_util.h"
     22 #include "tensorflow/core/framework/op_def_builder.h"
     23 #include "tensorflow/core/graph/graph.h"
     24 #include "tensorflow/core/graph/graph_def_builder.h"
     25 #include "tensorflow/core/graph/subgraph.h"
     26 #include "tensorflow/core/kernels/ops_util.h"
     27 #include "tensorflow/core/lib/core/status.h"
     28 #include "tensorflow/core/lib/core/status_test_util.h"
     29 #include "tensorflow/core/platform/test.h"
     30 
     31 namespace tensorflow {
     32 namespace {
     33 
     34 REGISTER_OP("FloatInput").Output("o: float");
     35 REGISTER_OP("Int32Input").Output("o: int32");
     36 
     37 TEST(ValidateGraphDefTest, TestValidGraph) {
     38   const string graph_def_str =
     39       "node { name: 'A' op: 'FloatInput' }"
     40       "node { name: 'B' op: 'FloatInput' }"
     41       "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
     42       " input: ['A', 'B'] }";
     43   GraphDef graph_def;
     44   auto parser = protobuf::TextFormat::Parser();
     45   CHECK(parser.MergeFromString(graph_def_str, &graph_def)) << graph_def_str;
     46   TF_ASSERT_OK(graph::ValidateGraphDef(graph_def, *OpRegistry::Global()));
     47 }
     48 
     49 TEST(ValidateGraphDefTest, GraphWithUnspecifiedDefaultAttr) {
     50   const string graph_def_str =
     51       "node { name: 'A' op: 'FloatInput' }"
     52       "node { name: 'B' op: 'Int32Input' }"
     53       "node { "
     54       "       name: 'C' op: 'Sum' "
     55       "       attr { key: 'T' value { type: DT_FLOAT } }"
     56       "       input: ['A', 'B'] "
     57       "}";
     58   GraphDef graph_def;
     59   auto parser = protobuf::TextFormat::Parser();
     60   CHECK(parser.MergeFromString(graph_def_str, &graph_def)) << graph_def_str;
     61   Status s = graph::ValidateGraphDef(graph_def, *OpRegistry::Global());
     62   EXPECT_FALSE(s.ok());
     63   EXPECT_TRUE(StringPiece(s.ToString()).contains("NodeDef missing attr"));
     64 
     65   // Add the defaults.
     66   TF_ASSERT_OK(AddDefaultAttrsToGraphDef(&graph_def, *OpRegistry::Global(), 0));
     67 
     68   // Validation should succeed.
     69   TF_ASSERT_OK(graph::ValidateGraphDef(graph_def, *OpRegistry::Global()));
     70 }
     71 
     72 TEST(ValidateGraphDefTest, GraphWithUnspecifiedRequiredAttr) {
     73   // "DstT" attribute is missing.
     74   const string graph_def_str =
     75       "node { name: 'A' op: 'FloatInput' }"
     76       "node { "
     77       "       name: 'B' op: 'Cast' "
     78       "       attr { key: 'SrcT' value { type: DT_FLOAT } }"
     79       "       input: ['A'] "
     80       "}";
     81   GraphDef graph_def;
     82   auto parser = protobuf::TextFormat::Parser();
     83   CHECK(parser.MergeFromString(graph_def_str, &graph_def)) << graph_def_str;
     84   Status s = graph::ValidateGraphDef(graph_def, *OpRegistry::Global());
     85   EXPECT_FALSE(s.ok());
     86   EXPECT_TRUE(StringPiece(s.ToString()).contains("NodeDef missing attr"));
     87 
     88   // Add the defaults.
     89   TF_ASSERT_OK(AddDefaultAttrsToGraphDef(&graph_def, *OpRegistry::Global(), 0));
     90 
     91   // Validation should still fail.
     92   s = graph::ValidateGraphDef(graph_def, *OpRegistry::Global());
     93   EXPECT_FALSE(s.ok());
     94   EXPECT_TRUE(StringPiece(s.ToString()).contains("NodeDef missing attr"));
     95 }
     96 
     97 TEST(ValidateGraphDefAgainstOpListTest, GraphWithOpOnlyInOpList) {
     98   OpRegistrationData op_reg_data;
     99   TF_ASSERT_OK(OpDefBuilder("UniqueSnowflake").Finalize(&op_reg_data));
    100   OpList op_list;
    101   *op_list.add_op() = op_reg_data.op_def;
    102   const string graph_def_str = "node { name: 'A' op: 'UniqueSnowflake' }";
    103   GraphDef graph_def;
    104   auto parser = protobuf::TextFormat::Parser();
    105   CHECK(parser.MergeFromString(graph_def_str, &graph_def)) << graph_def_str;
    106   TF_ASSERT_OK(graph::ValidateGraphDefAgainstOpList(graph_def, op_list));
    107 }
    108 
    109 TEST(ValidateGraphDefAgainstOpListTest, GraphWithGlobalOpNotInOpList) {
    110   OpRegistrationData op_reg_data;
    111   TF_ASSERT_OK(OpDefBuilder("NotAnywhere").Finalize(&op_reg_data));
    112   OpList op_list;
    113   *op_list.add_op() = op_reg_data.op_def;
    114   const string graph_def_str = "node { name: 'A' op: 'FloatInput' }";
    115   GraphDef graph_def;
    116   auto parser = protobuf::TextFormat::Parser();
    117   CHECK(parser.MergeFromString(graph_def_str, &graph_def)) << graph_def_str;
    118   ASSERT_FALSE(graph::ValidateGraphDefAgainstOpList(graph_def, op_list).ok());
    119 }
    120 
    121 REGISTER_OP("HasDocs").Doc("This is in the summary.");
    122 
    123 TEST(GetOpListForValidationTest, ShouldStripDocs) {
    124   bool found_float = false;
    125   bool found_int32 = false;
    126   bool found_has_docs = false;
    127   OpList op_list;
    128   graph::GetOpListForValidation(&op_list);
    129   for (const OpDef& op_def : op_list.op()) {
    130     if (op_def.name() == "FloatInput") {
    131       EXPECT_FALSE(found_float);
    132       found_float = true;
    133     }
    134     if (op_def.name() == "Int32Input") {
    135       EXPECT_FALSE(found_int32);
    136       found_int32 = true;
    137     }
    138     if (op_def.name() == "HasDocs") {
    139       EXPECT_FALSE(found_has_docs);
    140       found_has_docs = true;
    141       EXPECT_TRUE(op_def.summary().empty());
    142     }
    143   }
    144   EXPECT_TRUE(found_float);
    145   EXPECT_TRUE(found_int32);
    146   EXPECT_TRUE(found_has_docs);
    147 }
    148 
    149 }  // namespace
    150 }  // namespace tensorflow
    151