1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Test for parse_flags_from_env.cc 17 18 #include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" 19 20 #include <stdio.h> 21 #include <stdlib.h> 22 #include <vector> 23 24 #include "tensorflow/compiler/xla/types.h" 25 #include "tensorflow/core/lib/strings/stringprintf.h" 26 #include "tensorflow/core/platform/logging.h" 27 #include "tensorflow/core/platform/test.h" 28 #include "tensorflow/core/platform/types.h" 29 #include "tensorflow/core/util/command_line_flags.h" 30 31 namespace xla { 32 namespace legacy_flags { 33 34 // Test that XLA flags can be set from the environment. 35 // Failure messages are accompanied by the text in msg[]. 36 static void TestParseFlagsFromEnv(const char* msg) { 37 // Initialize module under test. 38 int* pargc; 39 std::vector<char*>* pargv; 40 ResetFlagsFromEnvForTesting(&pargc, &pargv); 41 42 // Ensure that environment variable can be parsed when 43 // no flags are expected. 44 std::vector<tensorflow::Flag> empty_flag_list; 45 bool parsed_ok = ParseFlagsFromEnv(empty_flag_list); 46 CHECK(parsed_ok) << msg; 47 const std::vector<char*>& argv_first = *pargv; 48 CHECK_NE(argv_first[0], nullptr) << msg; 49 int i = 0; 50 while (argv_first[i] != nullptr) { 51 i++; 52 } 53 CHECK_EQ(i, *pargc) << msg; 54 55 // Check that actual flags can be parsed. 56 bool simple = false; 57 string with_value; 58 string embedded_quotes; 59 string single_quoted; 60 string double_quoted; 61 std::vector<tensorflow::Flag> flag_list = { 62 tensorflow::Flag("simple", &simple, ""), 63 tensorflow::Flag("with_value", &with_value, ""), 64 tensorflow::Flag("embedded_quotes", &embedded_quotes, ""), 65 tensorflow::Flag("single_quoted", &single_quoted, ""), 66 tensorflow::Flag("double_quoted", &double_quoted, ""), 67 }; 68 parsed_ok = ParseFlagsFromEnv(flag_list); 69 CHECK_EQ(*pargc, 1) << msg; 70 const std::vector<char*>& argv_second = *pargv; 71 CHECK_NE(argv_second[0], nullptr) << msg; 72 CHECK_EQ(argv_second[1], nullptr) << msg; 73 CHECK(parsed_ok) << msg; 74 CHECK(simple) << msg; 75 CHECK_EQ(with_value, "a_value") << msg; 76 CHECK_EQ(embedded_quotes, "single'double\"") << msg; 77 CHECK_EQ(single_quoted, "single quoted \\\\ \n \"") << msg; 78 CHECK_EQ(double_quoted, "double quoted \\ \n '\"") << msg; 79 } 80 81 // The flags settings to test. 82 static const char kTestFlagString[] = 83 "--simple " 84 "--with_value=a_value " 85 "--embedded_quotes=single'double\" " 86 "--single_quoted='single quoted \\\\ \n \"' " 87 "--double_quoted=\"double quoted \\\\ \n '\\\"\" "; 88 89 // Test that the environent variable is parsed correctly. 90 TEST(ParseFlagsFromEnv, Basic) { 91 // Prepare environment. 92 setenv("TF_XLA_FLAGS", kTestFlagString, true /*overwrite*/); 93 TestParseFlagsFromEnv("(flags in environment variable)"); 94 } 95 96 // Test that a file named by the environent variable is parsed correctly. 97 TEST(ParseFlagsFromEnv, File) { 98 // environment variables where tmp dir may be specified. 99 static const char* kTempVars[] = {"TEST_TMPDIR", "TMP"}; 100 static const char kTempDir[] = "/tmp"; // default temp dir if all else fails. 101 const char* tmp_dir = nullptr; 102 for (int i = 0; i != TF_ARRAYSIZE(kTempVars) && tmp_dir == nullptr; i++) { 103 tmp_dir = getenv(kTempVars[i]); 104 } 105 if (tmp_dir == nullptr) { 106 tmp_dir = kTempDir; 107 } 108 string tmp_file = tensorflow::strings::Printf("%s/parse_flags_from_env.%d", 109 tmp_dir, getpid()); 110 FILE* fp = fopen(tmp_file.c_str(), "w"); 111 CHECK_NE(fp, nullptr) << "can't write to " << tmp_file; 112 for (int i = 0; kTestFlagString[i] != '\0'; i++) { 113 putc(kTestFlagString[i], fp); 114 } 115 fflush(fp); 116 CHECK_EQ(ferror(fp), 0) << "writes failed to " << tmp_file; 117 fclose(fp); 118 // Prepare environment. 119 setenv("TF_XLA_FLAGS", tmp_file.c_str(), true /*overwrite*/); 120 TestParseFlagsFromEnv("(flags in file)"); 121 unlink(tmp_file.c_str()); 122 } 123 124 // Name of the test binary. 125 static const char* binary_name; 126 127 // Test that when we use both the environment variable and actual 128 // commend line flags (when the latter is possible), the latter win. 129 TEST(ParseFlagsFromEnv, EnvAndFlag) { 130 static struct { 131 const char* env; 132 const char* arg; 133 const char* expected_value; 134 } test[] = { 135 {nullptr, nullptr, "1\n"}, 136 {nullptr, "--int_flag=2", "2\n"}, 137 {"--int_flag=3", nullptr, "3\n"}, 138 {"--int_flag=3", "--int_flag=2", "2\n"}, // flag beats environment 139 }; 140 for (int i = 0; i != TF_ARRAYSIZE(test); i++) { 141 if (test[i].env != nullptr) { 142 setenv("TF_XLA_FLAGS", test[i].env, true /*overwrite*/); 143 } 144 tensorflow::SubProcess child; 145 std::vector<string> argv; 146 argv.push_back(binary_name); 147 argv.push_back("--recursing"); 148 if (test[i].arg != nullptr) { 149 argv.push_back(test[i].arg); 150 } 151 child.SetProgram(binary_name, argv); 152 child.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE); 153 CHECK(child.Start()) << "test " << i; 154 string stdout_str; 155 int child_status = child.Communicate(nullptr, &stdout_str, nullptr); 156 CHECK_EQ(child_status, 0) << "test " << i; 157 CHECK_EQ(stdout_str, test[i].expected_value) << "test " << i; 158 } 159 } 160 161 } // namespace legacy_flags 162 } // namespace xla 163 164 int main(int argc, char* argv[]) { 165 // Save name of binary so that it may invoke itself. 166 xla::legacy_flags::binary_name = argv[0]; 167 bool recursing = false; 168 xla::int32 int_flag = 1; 169 const std::vector<tensorflow::Flag> flag_list = { 170 tensorflow::Flag("recursing", &recursing, 171 "Whether the binary is being invoked recusively."), 172 tensorflow::Flag("int_flag", &int_flag, "An integer flag to test with"), 173 }; 174 xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); 175 bool parse_ok = xla::legacy_flags::ParseFlagsFromEnv(flag_list); 176 if (!parse_ok) { 177 LOG(QFATAL) << "can't parse from environment\n" << usage; 178 } 179 parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); 180 if (!parse_ok) { 181 LOG(QFATAL) << usage; 182 } 183 if (recursing) { 184 printf("%d\n", int_flag); 185 exit(0); 186 } 187 testing::InitGoogleTest(&argc, argv); 188 return RUN_ALL_TESTS(); 189 } 190