Home | History | Annotate | Download | only in util
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <ctype.h>
     17 #include <vector>
     18 
     19 #include "tensorflow/core/platform/test.h"
     20 #include "tensorflow/core/util/command_line_flags.h"
     21 
     22 namespace tensorflow {
     23 namespace {
     24 // The returned array is only valid for the lifetime of the input vector.
     25 // We're using const casting because we need to pass in an argv-style array of
     26 // char* pointers for the API, even though we know they won't be altered.
     27 std::vector<char *> CharPointerVectorFromStrings(
     28     const std::vector<string> &strings) {
     29   std::vector<char *> result;
     30   result.reserve(strings.size());
     31   for (const string &string : strings) {
     32     result.push_back(const_cast<char *>(string.c_str()));
     33   }
     34   return result;
     35 }
     36 }  // namespace
     37 
     38 TEST(CommandLineFlagsTest, BasicUsage) {
     39   int some_int32_set_directly = 10;
     40   int some_int32_set_via_hook = 20;
     41   int64 some_int64_set_directly = 21474836470;  // max int32 is 2147483647
     42   int64 some_int64_set_via_hook = 21474836479;  // max int32 is 2147483647
     43   bool some_switch_set_directly = false;
     44   bool some_switch_set_via_hook = true;
     45   string some_name_set_directly = "something_a";
     46   string some_name_set_via_hook = "something_b";
     47   float some_float_set_directly = -23.23f;
     48   float some_float_set_via_hook = -25.23f;
     49   std::vector<string> argv_strings = {"program_name",
     50                                       "--some_int32_set_directly=20",
     51                                       "--some_int32_set_via_hook=50",
     52                                       "--some_int64_set_directly=214748364700",
     53                                       "--some_int64_set_via_hook=214748364710",
     54                                       "--some_switch_set_directly",
     55                                       "--some_switch_set_via_hook=false",
     56                                       "--some_name_set_directly=somethingelse",
     57                                       "--some_name_set_via_hook=anythingelse",
     58                                       "--some_float_set_directly=42.0",
     59                                       "--some_float_set_via_hook=43.0"};
     60   int argc = argv_strings.size();
     61   std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
     62   bool parsed_ok = Flags::Parse(
     63       &argc, argv_array.data(),
     64       {
     65           Flag("some_int32_set_directly", &some_int32_set_directly,
     66                "some int32 set directly"),
     67           Flag("some_int32_set_via_hook",
     68                [&](int32 value) {
     69                  some_int32_set_via_hook = value;
     70                  return true;
     71                },
     72                some_int32_set_via_hook, "some int32 set via hook"),
     73           Flag("some_int64_set_directly", &some_int64_set_directly,
     74                "some int64 set directly"),
     75           Flag("some_int64_set_via_hook",
     76                [&](int64 value) {
     77                  some_int64_set_via_hook = value;
     78                  return true;
     79                },
     80                some_int64_set_via_hook, "some int64 set via hook"),
     81           Flag("some_switch_set_directly", &some_switch_set_directly,
     82                "some switch set directly"),
     83           Flag("some_switch_set_via_hook",
     84                [&](bool value) {
     85                  some_switch_set_via_hook = value;
     86                  return true;
     87                },
     88                some_switch_set_via_hook, "some switch set via hook"),
     89           Flag("some_name_set_directly", &some_name_set_directly,
     90                "some name set directly"),
     91           Flag("some_name_set_via_hook",
     92                [&](string value) {
     93                  some_name_set_via_hook = std::move(value);
     94                  return true;
     95                },
     96                some_name_set_via_hook, "some name set via hook"),
     97           Flag("some_float_set_directly", &some_float_set_directly,
     98                "some float set directly"),
     99           Flag("some_float_set_via_hook",
    100                [&](float value) {
    101                  some_float_set_via_hook = value;
    102                  return true;
    103                },
    104                some_float_set_via_hook, "some float set via hook"),
    105       });
    106 
    107   EXPECT_EQ(true, parsed_ok);
    108   EXPECT_EQ(20, some_int32_set_directly);
    109   EXPECT_EQ(50, some_int32_set_via_hook);
    110   EXPECT_EQ(214748364700, some_int64_set_directly);
    111   EXPECT_EQ(214748364710, some_int64_set_via_hook);
    112   EXPECT_EQ(true, some_switch_set_directly);
    113   EXPECT_EQ(false, some_switch_set_via_hook);
    114   EXPECT_EQ("somethingelse", some_name_set_directly);
    115   EXPECT_EQ("anythingelse", some_name_set_via_hook);
    116   EXPECT_NEAR(42.0f, some_float_set_directly, 1e-5f);
    117   EXPECT_NEAR(43.0f, some_float_set_via_hook, 1e-5f);
    118   EXPECT_EQ(argc, 1);
    119 }
    120 
    121 TEST(CommandLineFlagsTest, BadIntValue) {
    122   int some_int = 10;
    123   int argc = 2;
    124   std::vector<string> argv_strings = {"program_name", "--some_int=notanumber"};
    125   std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
    126   bool parsed_ok = Flags::Parse(&argc, argv_array.data(),
    127                                 {Flag("some_int", &some_int, "some int")});
    128 
    129   EXPECT_EQ(false, parsed_ok);
    130   EXPECT_EQ(10, some_int);
    131   EXPECT_EQ(argc, 1);
    132 }
    133 
    134 TEST(CommandLineFlagsTest, BadBoolValue) {
    135   bool some_switch = false;
    136   int argc = 2;
    137   std::vector<string> argv_strings = {"program_name", "--some_switch=notabool"};
    138   std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
    139   bool parsed_ok =
    140       Flags::Parse(&argc, argv_array.data(),
    141                    {Flag("some_switch", &some_switch, "some switch")});
    142 
    143   EXPECT_EQ(false, parsed_ok);
    144   EXPECT_EQ(false, some_switch);
    145   EXPECT_EQ(argc, 1);
    146 }
    147 
    148 TEST(CommandLineFlagsTest, BadFloatValue) {
    149   float some_float = -23.23f;
    150   int argc = 2;
    151   std::vector<string> argv_strings = {"program_name",
    152                                       "--some_float=notanumber"};
    153   std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
    154   bool parsed_ok =
    155       Flags::Parse(&argc, argv_array.data(),
    156                    {Flag("some_float", &some_float, "some float")});
    157 
    158   EXPECT_EQ(false, parsed_ok);
    159   EXPECT_NEAR(-23.23f, some_float, 1e-5f);
    160   EXPECT_EQ(argc, 1);
    161 }
    162 
    163 TEST(CommandLineFlagsTest, FailedInt32Hook) {
    164   int argc = 2;
    165   std::vector<string> argv_strings = {"program_name", "--some_int32=200"};
    166   std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
    167   bool parsed_ok =
    168       Flags::Parse(&argc, argv_array.data(),
    169                    {Flag("some_int32", [](int32 value) { return false; }, 30,
    170                          "some int32")});
    171 
    172   EXPECT_EQ(false, parsed_ok);
    173   EXPECT_EQ(argc, 1);
    174 }
    175 
    176 TEST(CommandLineFlagsTest, FailedInt64Hook) {
    177   int argc = 2;
    178   std::vector<string> argv_strings = {"program_name", "--some_int64=200"};
    179   std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
    180   bool parsed_ok =
    181       Flags::Parse(&argc, argv_array.data(),
    182                    {Flag("some_int64", [](int64 value) { return false; }, 30,
    183                          "some int64")});
    184 
    185   EXPECT_EQ(false, parsed_ok);
    186   EXPECT_EQ(argc, 1);
    187 }
    188 
    189 TEST(CommandLineFlagsTest, FailedFloatHook) {
    190   int argc = 2;
    191   std::vector<string> argv_strings = {"program_name", "--some_float=200.0"};
    192   std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
    193   bool parsed_ok =
    194       Flags::Parse(&argc, argv_array.data(),
    195                    {Flag("some_float", [](float value) { return false; }, 30.0f,
    196                          "some float")});
    197 
    198   EXPECT_EQ(false, parsed_ok);
    199   EXPECT_EQ(argc, 1);
    200 }
    201 
    202 TEST(CommandLineFlagsTest, FailedBoolHook) {
    203   int argc = 2;
    204   std::vector<string> argv_strings = {"program_name", "--some_switch=true"};
    205   std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
    206   bool parsed_ok =
    207       Flags::Parse(&argc, argv_array.data(),
    208                    {Flag("some_switch", [](bool value) { return false; }, false,
    209                          "some switch")});
    210 
    211   EXPECT_EQ(false, parsed_ok);
    212   EXPECT_EQ(argc, 1);
    213 }
    214 
    215 TEST(CommandLineFlagsTest, FailedStringHook) {
    216   int argc = 2;
    217   std::vector<string> argv_strings = {"program_name", "--some_name=true"};
    218   std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
    219   bool parsed_ok = Flags::Parse(
    220       &argc, argv_array.data(),
    221       {Flag("some_name", [](string value) { return false; }, "", "some name")});
    222 
    223   EXPECT_EQ(false, parsed_ok);
    224   EXPECT_EQ(argc, 1);
    225 }
    226 
    227 TEST(CommandLineFlagsTest, RepeatedStringHook) {
    228   int argc = 3;
    229   std::vector<string> argv_strings = {"program_name", "--some_name=this",
    230                                       "--some_name=that"};
    231   std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
    232   int call_count = 0;
    233   bool parsed_ok = Flags::Parse(&argc, argv_array.data(),
    234                                 {Flag("some_name",
    235                                       [&call_count](string value) {
    236                                         call_count++;
    237                                         return true;
    238                                       },
    239                                       "", "some name")});
    240 
    241   EXPECT_EQ(true, parsed_ok);
    242   EXPECT_EQ(argc, 1);
    243   EXPECT_EQ(call_count, 2);
    244 }
    245 
    246 // Return whether str==pat, but allowing any whitespace in pat
    247 // to match zero or more whitespace characters in str.
    248 static bool MatchWithAnyWhitespace(const string &str, const string &pat) {
    249   bool matching = true;
    250   int pat_i = 0;
    251   for (int str_i = 0; str_i != str.size() && matching; str_i++) {
    252     if (isspace(str[str_i])) {
    253       matching = (pat_i != pat.size() && isspace(pat[pat_i]));
    254     } else {
    255       while (pat_i != pat.size() && isspace(pat[pat_i])) {
    256         pat_i++;
    257       }
    258       matching = (pat_i != pat.size() && str[str_i] == pat[pat_i++]);
    259     }
    260   }
    261   while (pat_i != pat.size() && isspace(pat[pat_i])) {
    262     pat_i++;
    263   }
    264   return (matching && pat_i == pat.size());
    265 }
    266 
    267 TEST(CommandLineFlagsTest, UsageString) {
    268   int some_int = 10;
    269   int64 some_int64 = 21474836470;  // max int32 is 2147483647
    270   bool some_switch = false;
    271   string some_name = "something";
    272   // Don't test float in this case, because precision is hard to predict and
    273   // match against, and we don't want a flakey test.
    274   const string tool_name = "some_tool_name";
    275   string usage = Flags::Usage(tool_name + "<flags>",
    276                               {Flag("some_int", &some_int, "some int"),
    277                                Flag("some_int64", &some_int64, "some int64"),
    278                                Flag("some_switch", &some_switch, "some switch"),
    279                                Flag("some_name", &some_name, "some name")});
    280   // Match the usage message, being sloppy about whitespace.
    281   const char *expected_usage =
    282       " usage: some_tool_name <flags>\n"
    283       "Flags:\n"
    284       "--some_int=10 int32 some int\n"
    285       "--some_int64=21474836470 int64 some int64\n"
    286       "--some_switch=false bool some switch\n"
    287       "--some_name=\"something\" string some name\n";
    288   ASSERT_EQ(MatchWithAnyWhitespace(usage, expected_usage), true);
    289 
    290   // Again but with no flags.
    291   usage = Flags::Usage(tool_name, {});
    292   ASSERT_EQ(MatchWithAnyWhitespace(usage, " usage: some_tool_name\n"), true);
    293 }
    294 }  // namespace tensorflow
    295