Home | History | Annotate | Download | only in legacy_flags
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // Test for parse_flags_from_env.cc
     17 
     18 #include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
     19 
     20 #include <stdio.h>
     21 #include <stdlib.h>
     22 #include <vector>
     23 
     24 #include "tensorflow/compiler/xla/types.h"
     25 #include "tensorflow/core/lib/strings/stringprintf.h"
     26 #include "tensorflow/core/platform/logging.h"
     27 #include "tensorflow/core/platform/test.h"
     28 #include "tensorflow/core/platform/types.h"
     29 #include "tensorflow/core/util/command_line_flags.h"
     30 
     31 namespace xla {
     32 namespace legacy_flags {
     33 
     34 // Test that XLA flags can be set from the environment.
     35 // Failure messages are accompanied by the text in msg[].
     36 static void TestParseFlagsFromEnv(const char* msg) {
     37   // Initialize module under test.
     38   int* pargc;
     39   std::vector<char*>* pargv;
     40   ResetFlagsFromEnvForTesting(&pargc, &pargv);
     41 
     42   // Ensure that environment variable can be parsed when
     43   // no flags are expected.
     44   std::vector<tensorflow::Flag> empty_flag_list;
     45   bool parsed_ok = ParseFlagsFromEnv(empty_flag_list);
     46   CHECK(parsed_ok) << msg;
     47   const std::vector<char*>& argv_first = *pargv;
     48   CHECK_NE(argv_first[0], nullptr) << msg;
     49   int i = 0;
     50   while (argv_first[i] != nullptr) {
     51     i++;
     52   }
     53   CHECK_EQ(i, *pargc) << msg;
     54 
     55   // Check that actual flags can be parsed.
     56   bool simple = false;
     57   string with_value;
     58   string embedded_quotes;
     59   string single_quoted;
     60   string double_quoted;
     61   std::vector<tensorflow::Flag> flag_list = {
     62       tensorflow::Flag("simple", &simple, ""),
     63       tensorflow::Flag("with_value", &with_value, ""),
     64       tensorflow::Flag("embedded_quotes", &embedded_quotes, ""),
     65       tensorflow::Flag("single_quoted", &single_quoted, ""),
     66       tensorflow::Flag("double_quoted", &double_quoted, ""),
     67   };
     68   parsed_ok = ParseFlagsFromEnv(flag_list);
     69   CHECK_EQ(*pargc, 1) << msg;
     70   const std::vector<char*>& argv_second = *pargv;
     71   CHECK_NE(argv_second[0], nullptr) << msg;
     72   CHECK_EQ(argv_second[1], nullptr) << msg;
     73   CHECK(parsed_ok) << msg;
     74   CHECK(simple) << msg;
     75   CHECK_EQ(with_value, "a_value") << msg;
     76   CHECK_EQ(embedded_quotes, "single'double\"") << msg;
     77   CHECK_EQ(single_quoted, "single quoted \\\\ \n \"") << msg;
     78   CHECK_EQ(double_quoted, "double quoted \\ \n '\"") << msg;
     79 }
     80 
     81 // The flags settings to test.
     82 static const char kTestFlagString[] =
     83     "--simple "
     84     "--with_value=a_value "
     85     "--embedded_quotes=single'double\" "
     86     "--single_quoted='single quoted \\\\ \n \"' "
     87     "--double_quoted=\"double quoted \\\\ \n '\\\"\" ";
     88 
     89 // Test that the environent variable is parsed correctly.
     90 TEST(ParseFlagsFromEnv, Basic) {
     91   // Prepare environment.
     92   setenv("TF_XLA_FLAGS", kTestFlagString, true /*overwrite*/);
     93   TestParseFlagsFromEnv("(flags in environment variable)");
     94 }
     95 
     96 // Test that a file named by the environent variable is parsed correctly.
     97 TEST(ParseFlagsFromEnv, File) {
     98   // environment variables where  tmp dir may be specified.
     99   static const char* kTempVars[] = {"TEST_TMPDIR", "TMP"};
    100   static const char kTempDir[] = "/tmp";  // default temp dir if all else fails.
    101   const char* tmp_dir = nullptr;
    102   for (int i = 0; i != TF_ARRAYSIZE(kTempVars) && tmp_dir == nullptr; i++) {
    103     tmp_dir = getenv(kTempVars[i]);
    104   }
    105   if (tmp_dir == nullptr) {
    106     tmp_dir = kTempDir;
    107   }
    108   string tmp_file = tensorflow::strings::Printf("%s/parse_flags_from_env.%d",
    109                                                 tmp_dir, getpid());
    110   FILE* fp = fopen(tmp_file.c_str(), "w");
    111   CHECK_NE(fp, nullptr) << "can't write to " << tmp_file;
    112   for (int i = 0; kTestFlagString[i] != '\0'; i++) {
    113     putc(kTestFlagString[i], fp);
    114   }
    115   fflush(fp);
    116   CHECK_EQ(ferror(fp), 0) << "writes failed to " << tmp_file;
    117   fclose(fp);
    118   // Prepare environment.
    119   setenv("TF_XLA_FLAGS", tmp_file.c_str(), true /*overwrite*/);
    120   TestParseFlagsFromEnv("(flags in file)");
    121   unlink(tmp_file.c_str());
    122 }
    123 
    124 // Name of the test binary.
    125 static const char* binary_name;
    126 
    127 // Test that when we use both the environment variable and actual
    128 // commend line flags (when the latter is possible), the latter win.
    129 TEST(ParseFlagsFromEnv, EnvAndFlag) {
    130   static struct {
    131     const char* env;
    132     const char* arg;
    133     const char* expected_value;
    134   } test[] = {
    135       {nullptr, nullptr, "1\n"},
    136       {nullptr, "--int_flag=2", "2\n"},
    137       {"--int_flag=3", nullptr, "3\n"},
    138       {"--int_flag=3", "--int_flag=2", "2\n"},  // flag beats environment
    139   };
    140   for (int i = 0; i != TF_ARRAYSIZE(test); i++) {
    141     if (test[i].env != nullptr) {
    142       setenv("TF_XLA_FLAGS", test[i].env, true /*overwrite*/);
    143     }
    144     tensorflow::SubProcess child;
    145     std::vector<string> argv;
    146     argv.push_back(binary_name);
    147     argv.push_back("--recursing");
    148     if (test[i].arg != nullptr) {
    149       argv.push_back(test[i].arg);
    150     }
    151     child.SetProgram(binary_name, argv);
    152     child.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE);
    153     CHECK(child.Start()) << "test " << i;
    154     string stdout_str;
    155     int child_status = child.Communicate(nullptr, &stdout_str, nullptr);
    156     CHECK_EQ(child_status, 0) << "test " << i;
    157     CHECK_EQ(stdout_str, test[i].expected_value) << "test " << i;
    158   }
    159 }
    160 
    161 }  // namespace legacy_flags
    162 }  // namespace xla
    163 
    164 int main(int argc, char* argv[]) {
    165   // Save name of binary so that it may invoke itself.
    166   xla::legacy_flags::binary_name = argv[0];
    167   bool recursing = false;
    168   xla::int32 int_flag = 1;
    169   const std::vector<tensorflow::Flag> flag_list = {
    170       tensorflow::Flag("recursing", &recursing,
    171                        "Whether the binary is being invoked recusively."),
    172       tensorflow::Flag("int_flag", &int_flag, "An integer flag to test with"),
    173   };
    174   xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
    175   bool parse_ok = xla::legacy_flags::ParseFlagsFromEnv(flag_list);
    176   if (!parse_ok) {
    177     LOG(QFATAL) << "can't parse from environment\n" << usage;
    178   }
    179   parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
    180   if (!parse_ok) {
    181     LOG(QFATAL) << usage;
    182   }
    183   if (recursing) {
    184     printf("%d\n", int_flag);
    185     exit(0);
    186   }
    187   testing::InitGoogleTest(&argc, argv);
    188   return RUN_ALL_TESTS();
    189 }
    190