1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for cli_config.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import json 21 import os 22 import shutil 23 import tempfile 24 25 from tensorflow.python.debug.cli import cli_config 26 from tensorflow.python.framework import test_util 27 from tensorflow.python.platform import gfile 28 from tensorflow.python.platform import googletest 29 30 31 class CLIConfigTest(test_util.TensorFlowTestCase): 32 33 def setUp(self): 34 self._tmp_dir = tempfile.mkdtemp() 35 self._tmp_config_path = os.path.join(self._tmp_dir, ".tfdbg_config") 36 self.assertFalse(gfile.Exists(self._tmp_config_path)) 37 super(CLIConfigTest, self).setUp() 38 39 def tearDown(self): 40 shutil.rmtree(self._tmp_dir) 41 super(CLIConfigTest, self).tearDown() 42 43 def testConstructCLIConfigWithoutFile(self): 44 config = cli_config.CLIConfig(config_file_path=self._tmp_config_path) 45 self.assertEqual(20, config.get("graph_recursion_depth")) 46 self.assertEqual(True, config.get("mouse_mode")) 47 with self.assertRaises(KeyError): 48 config.get("property_that_should_not_exist") 49 self.assertTrue(gfile.Exists(self._tmp_config_path)) 50 51 def testCLIConfigForwardCompatibilityTest(self): 52 config = cli_config.CLIConfig(config_file_path=self._tmp_config_path) 53 with open(self._tmp_config_path, "rt") as f: 54 config_json = json.load(f) 55 # Remove a field to simulate forward compatibility test. 56 del config_json["graph_recursion_depth"] 57 with open(self._tmp_config_path, "wt") as f: 58 json.dump(config_json, f) 59 60 config = cli_config.CLIConfig(config_file_path=self._tmp_config_path) 61 self.assertEqual(20, config.get("graph_recursion_depth")) 62 63 def testModifyConfigValue(self): 64 config = cli_config.CLIConfig(config_file_path=self._tmp_config_path) 65 config.set("graph_recursion_depth", 9) 66 config.set("mouse_mode", False) 67 self.assertEqual(9, config.get("graph_recursion_depth")) 68 self.assertEqual(False, config.get("mouse_mode")) 69 70 def testModifyConfigValueWithTypeCasting(self): 71 config = cli_config.CLIConfig(config_file_path=self._tmp_config_path) 72 config.set("graph_recursion_depth", "18") 73 config.set("mouse_mode", "false") 74 self.assertEqual(18, config.get("graph_recursion_depth")) 75 self.assertEqual(False, config.get("mouse_mode")) 76 77 def testModifyConfigValueWithTypeCastingFailure(self): 78 config = cli_config.CLIConfig(config_file_path=self._tmp_config_path) 79 with self.assertRaises(ValueError): 80 config.set("mouse_mode", "maybe") 81 82 def testLoadFromModifiedConfigFile(self): 83 config = cli_config.CLIConfig(config_file_path=self._tmp_config_path) 84 config.set("graph_recursion_depth", 9) 85 config.set("mouse_mode", False) 86 config2 = cli_config.CLIConfig(config_file_path=self._tmp_config_path) 87 self.assertEqual(9, config2.get("graph_recursion_depth")) 88 self.assertEqual(False, config2.get("mouse_mode")) 89 90 def testSummarizeFromConfig(self): 91 config = cli_config.CLIConfig(config_file_path=self._tmp_config_path) 92 output = config.summarize() 93 self.assertEqual( 94 ["Command-line configuration:", 95 "", 96 " graph_recursion_depth: %d" % config.get("graph_recursion_depth"), 97 " mouse_mode: %s" % config.get("mouse_mode")], output.lines) 98 99 def testSummarizeFromConfigWithHighlight(self): 100 config = cli_config.CLIConfig(config_file_path=self._tmp_config_path) 101 output = config.summarize(highlight="mouse_mode") 102 self.assertEqual( 103 ["Command-line configuration:", 104 "", 105 " graph_recursion_depth: %d" % config.get("graph_recursion_depth"), 106 " mouse_mode: %s" % config.get("mouse_mode")], output.lines) 107 self.assertEqual((2, 12, ["underline", "bold"]), 108 output.font_attr_segs[3][0]) 109 self.assertEqual((14, 18, "bold"), output.font_attr_segs[3][1]) 110 111 def testSetCallback(self): 112 config = cli_config.CLIConfig(config_file_path=self._tmp_config_path) 113 114 test_value = {"graph_recursion_depth": -1} 115 def callback(config): 116 test_value["graph_recursion_depth"] = config.get("graph_recursion_depth") 117 config.set_callback("graph_recursion_depth", callback) 118 119 config.set("graph_recursion_depth", config.get("graph_recursion_depth") - 1) 120 self.assertEqual(test_value["graph_recursion_depth"], 121 config.get("graph_recursion_depth")) 122 123 def testSetCallbackInvalidPropertyName(self): 124 config = cli_config.CLIConfig(config_file_path=self._tmp_config_path) 125 126 with self.assertRaises(KeyError): 127 config.set_callback("nonexistent_property_name", print) 128 129 def testSetCallbackNotCallable(self): 130 config = cli_config.CLIConfig(config_file_path=self._tmp_config_path) 131 132 with self.assertRaises(TypeError): 133 config.set_callback("graph_recursion_depth", 1) 134 135 136 if __name__ == "__main__": 137 googletest.main() 138