Home | History | Annotate | Download | only in cli
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for cli_config."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import json
     21 import os
     22 import shutil
     23 import tempfile
     24 
     25 from tensorflow.python.debug.cli import cli_config
     26 from tensorflow.python.framework import test_util
     27 from tensorflow.python.platform import gfile
     28 from tensorflow.python.platform import googletest
     29 
     30 
     31 class CLIConfigTest(test_util.TensorFlowTestCase):
     32 
     33   def setUp(self):
     34     self._tmp_dir = tempfile.mkdtemp()
     35     self._tmp_config_path = os.path.join(self._tmp_dir, ".tfdbg_config")
     36     self.assertFalse(gfile.Exists(self._tmp_config_path))
     37     super(CLIConfigTest, self).setUp()
     38 
     39   def tearDown(self):
     40     shutil.rmtree(self._tmp_dir)
     41     super(CLIConfigTest, self).tearDown()
     42 
     43   def testConstructCLIConfigWithoutFile(self):
     44     config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
     45     self.assertEqual(20, config.get("graph_recursion_depth"))
     46     self.assertEqual(True, config.get("mouse_mode"))
     47     with self.assertRaises(KeyError):
     48       config.get("property_that_should_not_exist")
     49     self.assertTrue(gfile.Exists(self._tmp_config_path))
     50 
     51   def testCLIConfigForwardCompatibilityTest(self):
     52     config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
     53     with open(self._tmp_config_path, "rt") as f:
     54       config_json = json.load(f)
     55     # Remove a field to simulate forward compatibility test.
     56     del config_json["graph_recursion_depth"]
     57     with open(self._tmp_config_path, "wt") as f:
     58       json.dump(config_json, f)
     59 
     60     config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
     61     self.assertEqual(20, config.get("graph_recursion_depth"))
     62 
     63   def testModifyConfigValue(self):
     64     config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
     65     config.set("graph_recursion_depth", 9)
     66     config.set("mouse_mode", False)
     67     self.assertEqual(9, config.get("graph_recursion_depth"))
     68     self.assertEqual(False, config.get("mouse_mode"))
     69 
     70   def testModifyConfigValueWithTypeCasting(self):
     71     config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
     72     config.set("graph_recursion_depth", "18")
     73     config.set("mouse_mode", "false")
     74     self.assertEqual(18, config.get("graph_recursion_depth"))
     75     self.assertEqual(False, config.get("mouse_mode"))
     76 
     77   def testModifyConfigValueWithTypeCastingFailure(self):
     78     config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
     79     with self.assertRaises(ValueError):
     80       config.set("mouse_mode", "maybe")
     81 
     82   def testLoadFromModifiedConfigFile(self):
     83     config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
     84     config.set("graph_recursion_depth", 9)
     85     config.set("mouse_mode", False)
     86     config2 = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
     87     self.assertEqual(9, config2.get("graph_recursion_depth"))
     88     self.assertEqual(False, config2.get("mouse_mode"))
     89 
     90   def testSummarizeFromConfig(self):
     91     config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
     92     output = config.summarize()
     93     self.assertEqual(
     94         ["Command-line configuration:",
     95          "",
     96          "  graph_recursion_depth: %d" % config.get("graph_recursion_depth"),
     97          "  mouse_mode: %s" % config.get("mouse_mode")], output.lines)
     98 
     99   def testSummarizeFromConfigWithHighlight(self):
    100     config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
    101     output = config.summarize(highlight="mouse_mode")
    102     self.assertEqual(
    103         ["Command-line configuration:",
    104          "",
    105          "  graph_recursion_depth: %d" % config.get("graph_recursion_depth"),
    106          "  mouse_mode: %s" % config.get("mouse_mode")], output.lines)
    107     self.assertEqual((2, 12, ["underline", "bold"]),
    108                      output.font_attr_segs[3][0])
    109     self.assertEqual((14, 18, "bold"), output.font_attr_segs[3][1])
    110 
    111   def testSetCallback(self):
    112     config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
    113 
    114     test_value = {"graph_recursion_depth": -1}
    115     def callback(config):
    116       test_value["graph_recursion_depth"] = config.get("graph_recursion_depth")
    117     config.set_callback("graph_recursion_depth", callback)
    118 
    119     config.set("graph_recursion_depth", config.get("graph_recursion_depth") - 1)
    120     self.assertEqual(test_value["graph_recursion_depth"],
    121                      config.get("graph_recursion_depth"))
    122 
    123   def testSetCallbackInvalidPropertyName(self):
    124     config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
    125 
    126     with self.assertRaises(KeyError):
    127       config.set_callback("nonexistent_property_name", print)
    128 
    129   def testSetCallbackNotCallable(self):
    130     config = cli_config.CLIConfig(config_file_path=self._tmp_config_path)
    131 
    132     with self.assertRaises(TypeError):
    133       config.set_callback("graph_recursion_depth", 1)
    134 
    135 
    136 if __name__ == "__main__":
    137   googletest.main()
    138