1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <ctype.h> 17 #include <vector> 18 19 #include "tensorflow/core/platform/test.h" 20 #include "tensorflow/core/util/command_line_flags.h" 21 22 namespace tensorflow { 23 namespace { 24 // The returned array is only valid for the lifetime of the input vector. 25 // We're using const casting because we need to pass in an argv-style array of 26 // char* pointers for the API, even though we know they won't be altered. 27 std::vector<char *> CharPointerVectorFromStrings( 28 const std::vector<string> &strings) { 29 std::vector<char *> result; 30 result.reserve(strings.size()); 31 for (const string &string : strings) { 32 result.push_back(const_cast<char *>(string.c_str())); 33 } 34 return result; 35 } 36 } // namespace 37 38 TEST(CommandLineFlagsTest, BasicUsage) { 39 int some_int32_set_directly = 10; 40 int some_int32_set_via_hook = 20; 41 int64 some_int64_set_directly = 21474836470; // max int32 is 2147483647 42 int64 some_int64_set_via_hook = 21474836479; // max int32 is 2147483647 43 bool some_switch_set_directly = false; 44 bool some_switch_set_via_hook = true; 45 string some_name_set_directly = "something_a"; 46 string some_name_set_via_hook = "something_b"; 47 float some_float_set_directly = -23.23f; 48 float some_float_set_via_hook = -25.23f; 49 std::vector<string> argv_strings = {"program_name", 50 "--some_int32_set_directly=20", 51 "--some_int32_set_via_hook=50", 52 "--some_int64_set_directly=214748364700", 53 "--some_int64_set_via_hook=214748364710", 54 "--some_switch_set_directly", 55 "--some_switch_set_via_hook=false", 56 "--some_name_set_directly=somethingelse", 57 "--some_name_set_via_hook=anythingelse", 58 "--some_float_set_directly=42.0", 59 "--some_float_set_via_hook=43.0"}; 60 int argc = argv_strings.size(); 61 std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings); 62 bool parsed_ok = Flags::Parse( 63 &argc, argv_array.data(), 64 { 65 Flag("some_int32_set_directly", &some_int32_set_directly, 66 "some int32 set directly"), 67 Flag("some_int32_set_via_hook", 68 [&](int32 value) { 69 some_int32_set_via_hook = value; 70 return true; 71 }, 72 some_int32_set_via_hook, "some int32 set via hook"), 73 Flag("some_int64_set_directly", &some_int64_set_directly, 74 "some int64 set directly"), 75 Flag("some_int64_set_via_hook", 76 [&](int64 value) { 77 some_int64_set_via_hook = value; 78 return true; 79 }, 80 some_int64_set_via_hook, "some int64 set via hook"), 81 Flag("some_switch_set_directly", &some_switch_set_directly, 82 "some switch set directly"), 83 Flag("some_switch_set_via_hook", 84 [&](bool value) { 85 some_switch_set_via_hook = value; 86 return true; 87 }, 88 some_switch_set_via_hook, "some switch set via hook"), 89 Flag("some_name_set_directly", &some_name_set_directly, 90 "some name set directly"), 91 Flag("some_name_set_via_hook", 92 [&](string value) { 93 some_name_set_via_hook = std::move(value); 94 return true; 95 }, 96 some_name_set_via_hook, "some name set via hook"), 97 Flag("some_float_set_directly", &some_float_set_directly, 98 "some float set directly"), 99 Flag("some_float_set_via_hook", 100 [&](float value) { 101 some_float_set_via_hook = value; 102 return true; 103 }, 104 some_float_set_via_hook, "some float set via hook"), 105 }); 106 107 EXPECT_EQ(true, parsed_ok); 108 EXPECT_EQ(20, some_int32_set_directly); 109 EXPECT_EQ(50, some_int32_set_via_hook); 110 EXPECT_EQ(214748364700, some_int64_set_directly); 111 EXPECT_EQ(214748364710, some_int64_set_via_hook); 112 EXPECT_EQ(true, some_switch_set_directly); 113 EXPECT_EQ(false, some_switch_set_via_hook); 114 EXPECT_EQ("somethingelse", some_name_set_directly); 115 EXPECT_EQ("anythingelse", some_name_set_via_hook); 116 EXPECT_NEAR(42.0f, some_float_set_directly, 1e-5f); 117 EXPECT_NEAR(43.0f, some_float_set_via_hook, 1e-5f); 118 EXPECT_EQ(argc, 1); 119 } 120 121 TEST(CommandLineFlagsTest, BadIntValue) { 122 int some_int = 10; 123 int argc = 2; 124 std::vector<string> argv_strings = {"program_name", "--some_int=notanumber"}; 125 std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings); 126 bool parsed_ok = Flags::Parse(&argc, argv_array.data(), 127 {Flag("some_int", &some_int, "some int")}); 128 129 EXPECT_EQ(false, parsed_ok); 130 EXPECT_EQ(10, some_int); 131 EXPECT_EQ(argc, 1); 132 } 133 134 TEST(CommandLineFlagsTest, BadBoolValue) { 135 bool some_switch = false; 136 int argc = 2; 137 std::vector<string> argv_strings = {"program_name", "--some_switch=notabool"}; 138 std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings); 139 bool parsed_ok = 140 Flags::Parse(&argc, argv_array.data(), 141 {Flag("some_switch", &some_switch, "some switch")}); 142 143 EXPECT_EQ(false, parsed_ok); 144 EXPECT_EQ(false, some_switch); 145 EXPECT_EQ(argc, 1); 146 } 147 148 TEST(CommandLineFlagsTest, BadFloatValue) { 149 float some_float = -23.23f; 150 int argc = 2; 151 std::vector<string> argv_strings = {"program_name", 152 "--some_float=notanumber"}; 153 std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings); 154 bool parsed_ok = 155 Flags::Parse(&argc, argv_array.data(), 156 {Flag("some_float", &some_float, "some float")}); 157 158 EXPECT_EQ(false, parsed_ok); 159 EXPECT_NEAR(-23.23f, some_float, 1e-5f); 160 EXPECT_EQ(argc, 1); 161 } 162 163 TEST(CommandLineFlagsTest, FailedInt32Hook) { 164 int argc = 2; 165 std::vector<string> argv_strings = {"program_name", "--some_int32=200"}; 166 std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings); 167 bool parsed_ok = 168 Flags::Parse(&argc, argv_array.data(), 169 {Flag("some_int32", [](int32 value) { return false; }, 30, 170 "some int32")}); 171 172 EXPECT_EQ(false, parsed_ok); 173 EXPECT_EQ(argc, 1); 174 } 175 176 TEST(CommandLineFlagsTest, FailedInt64Hook) { 177 int argc = 2; 178 std::vector<string> argv_strings = {"program_name", "--some_int64=200"}; 179 std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings); 180 bool parsed_ok = 181 Flags::Parse(&argc, argv_array.data(), 182 {Flag("some_int64", [](int64 value) { return false; }, 30, 183 "some int64")}); 184 185 EXPECT_EQ(false, parsed_ok); 186 EXPECT_EQ(argc, 1); 187 } 188 189 TEST(CommandLineFlagsTest, FailedFloatHook) { 190 int argc = 2; 191 std::vector<string> argv_strings = {"program_name", "--some_float=200.0"}; 192 std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings); 193 bool parsed_ok = 194 Flags::Parse(&argc, argv_array.data(), 195 {Flag("some_float", [](float value) { return false; }, 30.0f, 196 "some float")}); 197 198 EXPECT_EQ(false, parsed_ok); 199 EXPECT_EQ(argc, 1); 200 } 201 202 TEST(CommandLineFlagsTest, FailedBoolHook) { 203 int argc = 2; 204 std::vector<string> argv_strings = {"program_name", "--some_switch=true"}; 205 std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings); 206 bool parsed_ok = 207 Flags::Parse(&argc, argv_array.data(), 208 {Flag("some_switch", [](bool value) { return false; }, false, 209 "some switch")}); 210 211 EXPECT_EQ(false, parsed_ok); 212 EXPECT_EQ(argc, 1); 213 } 214 215 TEST(CommandLineFlagsTest, FailedStringHook) { 216 int argc = 2; 217 std::vector<string> argv_strings = {"program_name", "--some_name=true"}; 218 std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings); 219 bool parsed_ok = Flags::Parse( 220 &argc, argv_array.data(), 221 {Flag("some_name", [](string value) { return false; }, "", "some name")}); 222 223 EXPECT_EQ(false, parsed_ok); 224 EXPECT_EQ(argc, 1); 225 } 226 227 TEST(CommandLineFlagsTest, RepeatedStringHook) { 228 int argc = 3; 229 std::vector<string> argv_strings = {"program_name", "--some_name=this", 230 "--some_name=that"}; 231 std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings); 232 int call_count = 0; 233 bool parsed_ok = Flags::Parse(&argc, argv_array.data(), 234 {Flag("some_name", 235 [&call_count](string value) { 236 call_count++; 237 return true; 238 }, 239 "", "some name")}); 240 241 EXPECT_EQ(true, parsed_ok); 242 EXPECT_EQ(argc, 1); 243 EXPECT_EQ(call_count, 2); 244 } 245 246 // Return whether str==pat, but allowing any whitespace in pat 247 // to match zero or more whitespace characters in str. 248 static bool MatchWithAnyWhitespace(const string &str, const string &pat) { 249 bool matching = true; 250 int pat_i = 0; 251 for (int str_i = 0; str_i != str.size() && matching; str_i++) { 252 if (isspace(str[str_i])) { 253 matching = (pat_i != pat.size() && isspace(pat[pat_i])); 254 } else { 255 while (pat_i != pat.size() && isspace(pat[pat_i])) { 256 pat_i++; 257 } 258 matching = (pat_i != pat.size() && str[str_i] == pat[pat_i++]); 259 } 260 } 261 while (pat_i != pat.size() && isspace(pat[pat_i])) { 262 pat_i++; 263 } 264 return (matching && pat_i == pat.size()); 265 } 266 267 TEST(CommandLineFlagsTest, UsageString) { 268 int some_int = 10; 269 int64 some_int64 = 21474836470; // max int32 is 2147483647 270 bool some_switch = false; 271 string some_name = "something"; 272 // Don't test float in this case, because precision is hard to predict and 273 // match against, and we don't want a flakey test. 274 const string tool_name = "some_tool_name"; 275 string usage = Flags::Usage(tool_name + "<flags>", 276 {Flag("some_int", &some_int, "some int"), 277 Flag("some_int64", &some_int64, "some int64"), 278 Flag("some_switch", &some_switch, "some switch"), 279 Flag("some_name", &some_name, "some name")}); 280 // Match the usage message, being sloppy about whitespace. 281 const char *expected_usage = 282 " usage: some_tool_name <flags>\n" 283 "Flags:\n" 284 "--some_int=10 int32 some int\n" 285 "--some_int64=21474836470 int64 some int64\n" 286 "--some_switch=false bool some switch\n" 287 "--some_name=\"something\" string some name\n"; 288 ASSERT_EQ(MatchWithAnyWhitespace(usage, expected_usage), true); 289 290 // Again but with no flags. 291 usage = Flags::Usage(tool_name, {}); 292 ASSERT_EQ(MatchWithAnyWhitespace(usage, " usage: some_tool_name\n"), true); 293 } 294 } // namespace tensorflow 295