1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Unit tests for the shared functions and classes for tfdbg CLI.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 from collections import namedtuple 21 22 from tensorflow.python.debug.cli import cli_shared 23 from tensorflow.python.debug.cli import debugger_cli_common 24 from tensorflow.python.framework import constant_op 25 from tensorflow.python.framework import errors 26 from tensorflow.python.framework import ops 27 from tensorflow.python.framework import sparse_tensor 28 from tensorflow.python.framework import test_util 29 from tensorflow.python.ops import variables 30 from tensorflow.python.platform import googletest 31 32 33 class BytesToReadableStrTest(test_util.TensorFlowTestCase): 34 35 def testNoneSizeWorks(self): 36 self.assertEqual(str(None), cli_shared.bytes_to_readable_str(None)) 37 38 def testSizesBelowOneKiloByteWorks(self): 39 self.assertEqual("0", cli_shared.bytes_to_readable_str(0)) 40 self.assertEqual("500", cli_shared.bytes_to_readable_str(500)) 41 self.assertEqual("1023", cli_shared.bytes_to_readable_str(1023)) 42 43 def testSizesBetweenOneKiloByteandOneMegaByteWorks(self): 44 self.assertEqual("1.00k", cli_shared.bytes_to_readable_str(1024)) 45 self.assertEqual("2.40k", cli_shared.bytes_to_readable_str(int(1024 * 2.4))) 46 self.assertEqual("1023.00k", cli_shared.bytes_to_readable_str(1024 * 1023)) 47 48 def testSizesBetweenOneMegaByteandOneGigaByteWorks(self): 49 self.assertEqual("1.00M", cli_shared.bytes_to_readable_str(1024**2)) 50 self.assertEqual("2.40M", 51 cli_shared.bytes_to_readable_str(int(1024**2 * 2.4))) 52 self.assertEqual("1023.00M", 53 cli_shared.bytes_to_readable_str(1024**2 * 1023)) 54 55 def testSizeAboveOneGigaByteWorks(self): 56 self.assertEqual("1.00G", cli_shared.bytes_to_readable_str(1024**3)) 57 self.assertEqual("2000.00G", 58 cli_shared.bytes_to_readable_str(1024**3 * 2000)) 59 60 def testReadableStrIncludesBAtTheEndOnRequest(self): 61 self.assertEqual("0B", cli_shared.bytes_to_readable_str(0, include_b=True)) 62 self.assertEqual( 63 "1.00kB", cli_shared.bytes_to_readable_str( 64 1024, include_b=True)) 65 self.assertEqual( 66 "1.00MB", cli_shared.bytes_to_readable_str( 67 1024**2, include_b=True)) 68 self.assertEqual( 69 "1.00GB", cli_shared.bytes_to_readable_str( 70 1024**3, include_b=True)) 71 72 73 class TimeToReadableStrTest(test_util.TensorFlowTestCase): 74 75 def testNoneTimeWorks(self): 76 self.assertEqual("0", cli_shared.time_to_readable_str(None)) 77 78 def testMicrosecondsTime(self): 79 self.assertEqual("40us", cli_shared.time_to_readable_str(40)) 80 81 def testMillisecondTime(self): 82 self.assertEqual("40ms", cli_shared.time_to_readable_str(40e3)) 83 84 def testSecondTime(self): 85 self.assertEqual("40s", cli_shared.time_to_readable_str(40e6)) 86 87 def testForceTimeUnit(self): 88 self.assertEqual("40s", 89 cli_shared.time_to_readable_str( 90 40e6, force_time_unit=cli_shared.TIME_UNIT_S)) 91 self.assertEqual("40000ms", 92 cli_shared.time_to_readable_str( 93 40e6, force_time_unit=cli_shared.TIME_UNIT_MS)) 94 self.assertEqual("40000000us", 95 cli_shared.time_to_readable_str( 96 40e6, force_time_unit=cli_shared.TIME_UNIT_US)) 97 self.assertEqual("4e-05s", 98 cli_shared.time_to_readable_str( 99 40, force_time_unit=cli_shared.TIME_UNIT_S)) 100 self.assertEqual("0", 101 cli_shared.time_to_readable_str( 102 0, force_time_unit=cli_shared.TIME_UNIT_S)) 103 104 with self.assertRaisesRegexp(ValueError, r"Invalid time unit: ks"): 105 cli_shared.time_to_readable_str(100, force_time_unit="ks") 106 107 108 class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase): 109 110 def setUp(self): 111 self.const_a = constant_op.constant(11.0, name="a") 112 self.const_b = constant_op.constant(22.0, name="b") 113 self.const_c = constant_op.constant(33.0, name="c") 114 115 self.sparse_d = sparse_tensor.SparseTensor( 116 indices=[[0, 0], [1, 1]], values=[1.0, 2.0], dense_shape=[3, 3]) 117 118 def tearDown(self): 119 ops.reset_default_graph() 120 121 def testSingleFetchNoFeeds(self): 122 run_start_intro = cli_shared.get_run_start_intro(12, self.const_a, None, {}) 123 124 # Verify line about run() call number. 125 self.assertTrue(run_start_intro.lines[1].endswith("run() call #12:")) 126 127 # Verify line about fetch. 128 const_a_name_line = run_start_intro.lines[4] 129 self.assertEqual(self.const_a.name, const_a_name_line.strip()) 130 131 # Verify line about feeds. 132 feeds_line = run_start_intro.lines[7] 133 self.assertEqual("(Empty)", feeds_line.strip()) 134 135 # Verify lines about possible commands and their font attributes. 136 self.assertEqual("run:", run_start_intro.lines[11][2:]) 137 annot = run_start_intro.font_attr_segs[11][0] 138 self.assertEqual(2, annot[0]) 139 self.assertEqual(5, annot[1]) 140 self.assertEqual("run", annot[2][0].content) 141 self.assertEqual("bold", annot[2][1]) 142 annot = run_start_intro.font_attr_segs[13][0] 143 self.assertEqual(2, annot[0]) 144 self.assertEqual(8, annot[1]) 145 self.assertEqual("run -n", annot[2][0].content) 146 self.assertEqual("bold", annot[2][1]) 147 self.assertEqual("run -t <T>:", run_start_intro.lines[15][2:]) 148 self.assertEqual([(2, 12, "bold")], run_start_intro.font_attr_segs[15]) 149 self.assertEqual("run -f <filter_name>:", run_start_intro.lines[17][2:]) 150 self.assertEqual([(2, 22, "bold")], run_start_intro.font_attr_segs[17]) 151 annot = run_start_intro.font_attr_segs[21][0] 152 self.assertEqual(2, annot[0]) 153 self.assertEqual(16, annot[1]) 154 self.assertEqual("invoke_stepper", annot[2][0].content) 155 156 # Verify short description. 157 description = cli_shared.get_run_short_description(12, self.const_a, None) 158 self.assertEqual("run #12: 1 fetch (a:0); 0 feeds", description) 159 160 # Verify the main menu associated with the run_start_intro. 161 self.assertIn(debugger_cli_common.MAIN_MENU_KEY, 162 run_start_intro.annotations) 163 menu = run_start_intro.annotations[debugger_cli_common.MAIN_MENU_KEY] 164 self.assertEqual("run", menu.caption_to_item("run").content) 165 self.assertEqual("invoke_stepper", 166 menu.caption_to_item("invoke_stepper").content) 167 self.assertEqual("exit", menu.caption_to_item("exit").content) 168 169 def testSparseTensorAsFeedShouldHandleNoNameAttribute(self): 170 sparse_feed_val = ([[0, 0], [1, 1]], [10.0, 20.0]) 171 run_start_intro = cli_shared.get_run_start_intro( 172 1, self.sparse_d, {self.sparse_d: sparse_feed_val}, {}) 173 self.assertEqual(str(self.sparse_d), run_start_intro.lines[7].strip()) 174 175 short_description = cli_shared.get_run_short_description( 176 1, self.sparse_d, {self.sparse_d: sparse_feed_val}) 177 self.assertEqual( 178 "run #1: 1 fetch; 1 feed (%s)" % self.sparse_d, short_description) 179 180 def testSparseTensorAsFetchShouldHandleNoNameAttribute(self): 181 run_start_intro = cli_shared.get_run_start_intro(1, self.sparse_d, None, {}) 182 self.assertEqual(str(self.sparse_d), run_start_intro.lines[4].strip()) 183 184 def testTwoFetchesListNoFeeds(self): 185 fetches = [self.const_a, self.const_b] 186 run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {}) 187 188 const_a_name_line = run_start_intro.lines[4] 189 const_b_name_line = run_start_intro.lines[5] 190 self.assertEqual(self.const_a.name, const_a_name_line.strip()) 191 self.assertEqual(self.const_b.name, const_b_name_line.strip()) 192 193 feeds_line = run_start_intro.lines[8] 194 self.assertEqual("(Empty)", feeds_line.strip()) 195 196 # Verify short description. 197 description = cli_shared.get_run_short_description(1, fetches, None) 198 self.assertEqual("run #1: 2 fetches; 0 feeds", description) 199 200 def testNestedListAsFetches(self): 201 fetches = [self.const_c, [self.const_a, self.const_b]] 202 run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {}) 203 204 # Verify lines about the fetches. 205 self.assertEqual(self.const_c.name, run_start_intro.lines[4].strip()) 206 self.assertEqual(self.const_a.name, run_start_intro.lines[5].strip()) 207 self.assertEqual(self.const_b.name, run_start_intro.lines[6].strip()) 208 209 # Verify short description. 210 description = cli_shared.get_run_short_description(1, fetches, None) 211 self.assertEqual("run #1: 3 fetches; 0 feeds", description) 212 213 def testNestedDictAsFetches(self): 214 fetches = {"c": self.const_c, "ab": {"a": self.const_a, "b": self.const_b}} 215 run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {}) 216 217 # Verify lines about the fetches. The ordering of the dict keys is 218 # indeterminate. 219 fetch_names = set() 220 fetch_names.add(run_start_intro.lines[4].strip()) 221 fetch_names.add(run_start_intro.lines[5].strip()) 222 fetch_names.add(run_start_intro.lines[6].strip()) 223 224 self.assertEqual({"a:0", "b:0", "c:0"}, fetch_names) 225 226 # Verify short description. 227 description = cli_shared.get_run_short_description(1, fetches, None) 228 self.assertEqual("run #1: 3 fetches; 0 feeds", description) 229 230 def testTwoFetchesAsTupleNoFeeds(self): 231 fetches = (self.const_a, self.const_b) 232 run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {}) 233 234 const_a_name_line = run_start_intro.lines[4] 235 const_b_name_line = run_start_intro.lines[5] 236 self.assertEqual(self.const_a.name, const_a_name_line.strip()) 237 self.assertEqual(self.const_b.name, const_b_name_line.strip()) 238 239 feeds_line = run_start_intro.lines[8] 240 self.assertEqual("(Empty)", feeds_line.strip()) 241 242 # Verify short description. 243 description = cli_shared.get_run_short_description(1, fetches, None) 244 self.assertEqual("run #1: 2 fetches; 0 feeds", description) 245 246 def testTwoFetchesAsNamedTupleNoFeeds(self): 247 fetches_namedtuple = namedtuple("fetches", "x y") 248 fetches = fetches_namedtuple(self.const_b, self.const_c) 249 run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {}) 250 251 const_b_name_line = run_start_intro.lines[4] 252 const_c_name_line = run_start_intro.lines[5] 253 self.assertEqual(self.const_b.name, const_b_name_line.strip()) 254 self.assertEqual(self.const_c.name, const_c_name_line.strip()) 255 256 feeds_line = run_start_intro.lines[8] 257 self.assertEqual("(Empty)", feeds_line.strip()) 258 259 # Verify short description. 260 description = cli_shared.get_run_short_description(1, fetches, None) 261 self.assertEqual("run #1: 2 fetches; 0 feeds", description) 262 263 def testWithFeedDict(self): 264 feed_dict = { 265 self.const_a: 10.0, 266 self.const_b: 20.0, 267 } 268 269 run_start_intro = cli_shared.get_run_start_intro(1, self.const_c, feed_dict, 270 {}) 271 272 const_c_name_line = run_start_intro.lines[4] 273 self.assertEqual(self.const_c.name, const_c_name_line.strip()) 274 275 # Verify lines about the feed dict. 276 feed_a_line = run_start_intro.lines[7] 277 feed_b_line = run_start_intro.lines[8] 278 self.assertEqual(self.const_a.name, feed_a_line.strip()) 279 self.assertEqual(self.const_b.name, feed_b_line.strip()) 280 281 # Verify short description. 282 description = cli_shared.get_run_short_description(1, self.const_c, 283 feed_dict) 284 self.assertEqual("run #1: 1 fetch (c:0); 2 feeds", description) 285 286 def testTensorFilters(self): 287 feed_dict = {self.const_a: 10.0} 288 tensor_filters = { 289 "filter_a": lambda x: True, 290 "filter_b": lambda x: False, 291 } 292 293 run_start_intro = cli_shared.get_run_start_intro(1, self.const_c, feed_dict, 294 tensor_filters) 295 296 # Verify the listed names of the tensor filters. 297 filter_names = set() 298 filter_names.add(run_start_intro.lines[20].split(" ")[-1]) 299 filter_names.add(run_start_intro.lines[21].split(" ")[-1]) 300 301 self.assertEqual({"filter_a", "filter_b"}, filter_names) 302 303 # Verify short description. 304 description = cli_shared.get_run_short_description(1, self.const_c, 305 feed_dict) 306 self.assertEqual("run #1: 1 fetch (c:0); 1 feed (a:0)", description) 307 308 # Verify the command links for the two filters. 309 command_set = set() 310 annot = run_start_intro.font_attr_segs[20][0] 311 command_set.add(annot[2].content) 312 annot = run_start_intro.font_attr_segs[21][0] 313 command_set.add(annot[2].content) 314 self.assertEqual({"run -f filter_a", "run -f filter_b"}, command_set) 315 316 def testGetRunShortDescriptionWorksForTensorFeedKey(self): 317 short_description = cli_shared.get_run_short_description( 318 1, self.const_a, {self.const_a: 42.0}) 319 self.assertEqual("run #1: 1 fetch (a:0); 1 feed (a:0)", short_description) 320 321 def testGetRunShortDescriptionWorksForUnicodeFeedKey(self): 322 short_description = cli_shared.get_run_short_description( 323 1, self.const_a, {u"foo": 42.0}) 324 self.assertEqual("run #1: 1 fetch (a:0); 1 feed (foo)", short_description) 325 326 327 class GetErrorIntroTest(test_util.TensorFlowTestCase): 328 329 def setUp(self): 330 self.var_a = variables.Variable(42.0, name="a") 331 332 def tearDown(self): 333 ops.reset_default_graph() 334 335 def testShapeError(self): 336 tf_error = errors.OpError(None, self.var_a.initializer, "foo description", 337 None) 338 339 error_intro = cli_shared.get_error_intro(tf_error) 340 341 self.assertEqual("!!! An error occurred during the run !!!", 342 error_intro.lines[1]) 343 self.assertEqual([(0, len(error_intro.lines[1]), "blink")], 344 error_intro.font_attr_segs[1]) 345 346 self.assertEqual(2, error_intro.lines[4].index("ni -a -d -t a/Assign")) 347 self.assertEqual(2, error_intro.font_attr_segs[4][0][0]) 348 self.assertEqual(22, error_intro.font_attr_segs[4][0][1]) 349 self.assertEqual("ni -a -d -t a/Assign", 350 error_intro.font_attr_segs[4][0][2][0].content) 351 self.assertEqual("bold", error_intro.font_attr_segs[4][0][2][1]) 352 353 self.assertEqual(2, error_intro.lines[6].index("li -r a/Assign")) 354 self.assertEqual(2, error_intro.font_attr_segs[6][0][0]) 355 self.assertEqual(16, error_intro.font_attr_segs[6][0][1]) 356 self.assertEqual("li -r a/Assign", 357 error_intro.font_attr_segs[6][0][2][0].content) 358 self.assertEqual("bold", error_intro.font_attr_segs[6][0][2][1]) 359 360 self.assertEqual(2, error_intro.lines[8].index("lt")) 361 self.assertEqual(2, error_intro.font_attr_segs[8][0][0]) 362 self.assertEqual(4, error_intro.font_attr_segs[8][0][1]) 363 self.assertEqual("lt", error_intro.font_attr_segs[8][0][2][0].content) 364 self.assertEqual("bold", error_intro.font_attr_segs[8][0][2][1]) 365 366 self.assertStartsWith(error_intro.lines[11], "Op name:") 367 self.assertTrue(error_intro.lines[11].endswith("a/Assign")) 368 369 self.assertStartsWith(error_intro.lines[12], "Error type:") 370 self.assertTrue(error_intro.lines[12].endswith(str(type(tf_error)))) 371 372 self.assertEqual("Details:", error_intro.lines[14]) 373 self.assertStartsWith(error_intro.lines[15], "foo description") 374 375 376 if __name__ == "__main__": 377 googletest.main() 378