Home | History | Annotate | Download | only in cli
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Unit tests for the shared functions and classes for tfdbg CLI."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from collections import namedtuple
     21 
     22 from tensorflow.python.debug.cli import cli_shared
     23 from tensorflow.python.debug.cli import debugger_cli_common
     24 from tensorflow.python.framework import constant_op
     25 from tensorflow.python.framework import errors
     26 from tensorflow.python.framework import ops
     27 from tensorflow.python.framework import sparse_tensor
     28 from tensorflow.python.framework import test_util
     29 from tensorflow.python.ops import variables
     30 from tensorflow.python.platform import googletest
     31 
     32 
     33 class BytesToReadableStrTest(test_util.TensorFlowTestCase):
     34 
     35   def testNoneSizeWorks(self):
     36     self.assertEqual(str(None), cli_shared.bytes_to_readable_str(None))
     37 
     38   def testSizesBelowOneKiloByteWorks(self):
     39     self.assertEqual("0", cli_shared.bytes_to_readable_str(0))
     40     self.assertEqual("500", cli_shared.bytes_to_readable_str(500))
     41     self.assertEqual("1023", cli_shared.bytes_to_readable_str(1023))
     42 
     43   def testSizesBetweenOneKiloByteandOneMegaByteWorks(self):
     44     self.assertEqual("1.00k", cli_shared.bytes_to_readable_str(1024))
     45     self.assertEqual("2.40k", cli_shared.bytes_to_readable_str(int(1024 * 2.4)))
     46     self.assertEqual("1023.00k", cli_shared.bytes_to_readable_str(1024 * 1023))
     47 
     48   def testSizesBetweenOneMegaByteandOneGigaByteWorks(self):
     49     self.assertEqual("1.00M", cli_shared.bytes_to_readable_str(1024**2))
     50     self.assertEqual("2.40M",
     51                      cli_shared.bytes_to_readable_str(int(1024**2 * 2.4)))
     52     self.assertEqual("1023.00M",
     53                      cli_shared.bytes_to_readable_str(1024**2 * 1023))
     54 
     55   def testSizeAboveOneGigaByteWorks(self):
     56     self.assertEqual("1.00G", cli_shared.bytes_to_readable_str(1024**3))
     57     self.assertEqual("2000.00G",
     58                      cli_shared.bytes_to_readable_str(1024**3 * 2000))
     59 
     60   def testReadableStrIncludesBAtTheEndOnRequest(self):
     61     self.assertEqual("0B", cli_shared.bytes_to_readable_str(0, include_b=True))
     62     self.assertEqual(
     63         "1.00kB", cli_shared.bytes_to_readable_str(
     64             1024, include_b=True))
     65     self.assertEqual(
     66         "1.00MB", cli_shared.bytes_to_readable_str(
     67             1024**2, include_b=True))
     68     self.assertEqual(
     69         "1.00GB", cli_shared.bytes_to_readable_str(
     70             1024**3, include_b=True))
     71 
     72 
     73 class TimeToReadableStrTest(test_util.TensorFlowTestCase):
     74 
     75   def testNoneTimeWorks(self):
     76     self.assertEqual("0", cli_shared.time_to_readable_str(None))
     77 
     78   def testMicrosecondsTime(self):
     79     self.assertEqual("40us", cli_shared.time_to_readable_str(40))
     80 
     81   def testMillisecondTime(self):
     82     self.assertEqual("40ms", cli_shared.time_to_readable_str(40e3))
     83 
     84   def testSecondTime(self):
     85     self.assertEqual("40s", cli_shared.time_to_readable_str(40e6))
     86 
     87   def testForceTimeUnit(self):
     88     self.assertEqual("40s",
     89                      cli_shared.time_to_readable_str(
     90                          40e6, force_time_unit=cli_shared.TIME_UNIT_S))
     91     self.assertEqual("40000ms",
     92                      cli_shared.time_to_readable_str(
     93                          40e6, force_time_unit=cli_shared.TIME_UNIT_MS))
     94     self.assertEqual("40000000us",
     95                      cli_shared.time_to_readable_str(
     96                          40e6, force_time_unit=cli_shared.TIME_UNIT_US))
     97     self.assertEqual("4e-05s",
     98                      cli_shared.time_to_readable_str(
     99                          40, force_time_unit=cli_shared.TIME_UNIT_S))
    100     self.assertEqual("0",
    101                      cli_shared.time_to_readable_str(
    102                          0, force_time_unit=cli_shared.TIME_UNIT_S))
    103 
    104     with self.assertRaisesRegexp(ValueError, r"Invalid time unit: ks"):
    105       cli_shared.time_to_readable_str(100, force_time_unit="ks")
    106 
    107 
    108 class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
    109 
    110   def setUp(self):
    111     self.const_a = constant_op.constant(11.0, name="a")
    112     self.const_b = constant_op.constant(22.0, name="b")
    113     self.const_c = constant_op.constant(33.0, name="c")
    114 
    115     self.sparse_d = sparse_tensor.SparseTensor(
    116         indices=[[0, 0], [1, 1]], values=[1.0, 2.0], dense_shape=[3, 3])
    117 
    118   def tearDown(self):
    119     ops.reset_default_graph()
    120 
    121   def testSingleFetchNoFeeds(self):
    122     run_start_intro = cli_shared.get_run_start_intro(12, self.const_a, None, {})
    123 
    124     # Verify line about run() call number.
    125     self.assertTrue(run_start_intro.lines[1].endswith("run() call #12:"))
    126 
    127     # Verify line about fetch.
    128     const_a_name_line = run_start_intro.lines[4]
    129     self.assertEqual(self.const_a.name, const_a_name_line.strip())
    130 
    131     # Verify line about feeds.
    132     feeds_line = run_start_intro.lines[7]
    133     self.assertEqual("(Empty)", feeds_line.strip())
    134 
    135     # Verify lines about possible commands and their font attributes.
    136     self.assertEqual("run:", run_start_intro.lines[11][2:])
    137     annot = run_start_intro.font_attr_segs[11][0]
    138     self.assertEqual(2, annot[0])
    139     self.assertEqual(5, annot[1])
    140     self.assertEqual("run", annot[2][0].content)
    141     self.assertEqual("bold", annot[2][1])
    142     annot = run_start_intro.font_attr_segs[13][0]
    143     self.assertEqual(2, annot[0])
    144     self.assertEqual(8, annot[1])
    145     self.assertEqual("run -n", annot[2][0].content)
    146     self.assertEqual("bold", annot[2][1])
    147     self.assertEqual("run -t <T>:", run_start_intro.lines[15][2:])
    148     self.assertEqual([(2, 12, "bold")], run_start_intro.font_attr_segs[15])
    149     self.assertEqual("run -f <filter_name>:", run_start_intro.lines[17][2:])
    150     self.assertEqual([(2, 22, "bold")], run_start_intro.font_attr_segs[17])
    151     annot = run_start_intro.font_attr_segs[21][0]
    152     self.assertEqual(2, annot[0])
    153     self.assertEqual(16, annot[1])
    154     self.assertEqual("invoke_stepper", annot[2][0].content)
    155 
    156     # Verify short description.
    157     description = cli_shared.get_run_short_description(12, self.const_a, None)
    158     self.assertEqual("run #12: 1 fetch (a:0); 0 feeds", description)
    159 
    160     # Verify the main menu associated with the run_start_intro.
    161     self.assertIn(debugger_cli_common.MAIN_MENU_KEY,
    162                   run_start_intro.annotations)
    163     menu = run_start_intro.annotations[debugger_cli_common.MAIN_MENU_KEY]
    164     self.assertEqual("run", menu.caption_to_item("run").content)
    165     self.assertEqual("invoke_stepper",
    166                      menu.caption_to_item("invoke_stepper").content)
    167     self.assertEqual("exit", menu.caption_to_item("exit").content)
    168 
    169   def testSparseTensorAsFeedShouldHandleNoNameAttribute(self):
    170     sparse_feed_val = ([[0, 0], [1, 1]], [10.0, 20.0])
    171     run_start_intro = cli_shared.get_run_start_intro(
    172         1, self.sparse_d, {self.sparse_d: sparse_feed_val}, {})
    173     self.assertEqual(str(self.sparse_d), run_start_intro.lines[7].strip())
    174 
    175     short_description = cli_shared.get_run_short_description(
    176         1, self.sparse_d, {self.sparse_d: sparse_feed_val})
    177     self.assertEqual(
    178         "run #1: 1 fetch; 1 feed (%s)" % self.sparse_d, short_description)
    179 
    180   def testSparseTensorAsFetchShouldHandleNoNameAttribute(self):
    181     run_start_intro = cli_shared.get_run_start_intro(1, self.sparse_d, None, {})
    182     self.assertEqual(str(self.sparse_d), run_start_intro.lines[4].strip())
    183 
    184   def testTwoFetchesListNoFeeds(self):
    185     fetches = [self.const_a, self.const_b]
    186     run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})
    187 
    188     const_a_name_line = run_start_intro.lines[4]
    189     const_b_name_line = run_start_intro.lines[5]
    190     self.assertEqual(self.const_a.name, const_a_name_line.strip())
    191     self.assertEqual(self.const_b.name, const_b_name_line.strip())
    192 
    193     feeds_line = run_start_intro.lines[8]
    194     self.assertEqual("(Empty)", feeds_line.strip())
    195 
    196     # Verify short description.
    197     description = cli_shared.get_run_short_description(1, fetches, None)
    198     self.assertEqual("run #1: 2 fetches; 0 feeds", description)
    199 
    200   def testNestedListAsFetches(self):
    201     fetches = [self.const_c, [self.const_a, self.const_b]]
    202     run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})
    203 
    204     # Verify lines about the fetches.
    205     self.assertEqual(self.const_c.name, run_start_intro.lines[4].strip())
    206     self.assertEqual(self.const_a.name, run_start_intro.lines[5].strip())
    207     self.assertEqual(self.const_b.name, run_start_intro.lines[6].strip())
    208 
    209     # Verify short description.
    210     description = cli_shared.get_run_short_description(1, fetches, None)
    211     self.assertEqual("run #1: 3 fetches; 0 feeds", description)
    212 
    213   def testNestedDictAsFetches(self):
    214     fetches = {"c": self.const_c, "ab": {"a": self.const_a, "b": self.const_b}}
    215     run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})
    216 
    217     # Verify lines about the fetches. The ordering of the dict keys is
    218     # indeterminate.
    219     fetch_names = set()
    220     fetch_names.add(run_start_intro.lines[4].strip())
    221     fetch_names.add(run_start_intro.lines[5].strip())
    222     fetch_names.add(run_start_intro.lines[6].strip())
    223 
    224     self.assertEqual({"a:0", "b:0", "c:0"}, fetch_names)
    225 
    226     # Verify short description.
    227     description = cli_shared.get_run_short_description(1, fetches, None)
    228     self.assertEqual("run #1: 3 fetches; 0 feeds", description)
    229 
    230   def testTwoFetchesAsTupleNoFeeds(self):
    231     fetches = (self.const_a, self.const_b)
    232     run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})
    233 
    234     const_a_name_line = run_start_intro.lines[4]
    235     const_b_name_line = run_start_intro.lines[5]
    236     self.assertEqual(self.const_a.name, const_a_name_line.strip())
    237     self.assertEqual(self.const_b.name, const_b_name_line.strip())
    238 
    239     feeds_line = run_start_intro.lines[8]
    240     self.assertEqual("(Empty)", feeds_line.strip())
    241 
    242     # Verify short description.
    243     description = cli_shared.get_run_short_description(1, fetches, None)
    244     self.assertEqual("run #1: 2 fetches; 0 feeds", description)
    245 
    246   def testTwoFetchesAsNamedTupleNoFeeds(self):
    247     fetches_namedtuple = namedtuple("fetches", "x y")
    248     fetches = fetches_namedtuple(self.const_b, self.const_c)
    249     run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})
    250 
    251     const_b_name_line = run_start_intro.lines[4]
    252     const_c_name_line = run_start_intro.lines[5]
    253     self.assertEqual(self.const_b.name, const_b_name_line.strip())
    254     self.assertEqual(self.const_c.name, const_c_name_line.strip())
    255 
    256     feeds_line = run_start_intro.lines[8]
    257     self.assertEqual("(Empty)", feeds_line.strip())
    258 
    259     # Verify short description.
    260     description = cli_shared.get_run_short_description(1, fetches, None)
    261     self.assertEqual("run #1: 2 fetches; 0 feeds", description)
    262 
    263   def testWithFeedDict(self):
    264     feed_dict = {
    265         self.const_a: 10.0,
    266         self.const_b: 20.0,
    267     }
    268 
    269     run_start_intro = cli_shared.get_run_start_intro(1, self.const_c, feed_dict,
    270                                                      {})
    271 
    272     const_c_name_line = run_start_intro.lines[4]
    273     self.assertEqual(self.const_c.name, const_c_name_line.strip())
    274 
    275     # Verify lines about the feed dict.
    276     feed_a_line = run_start_intro.lines[7]
    277     feed_b_line = run_start_intro.lines[8]
    278     self.assertEqual(self.const_a.name, feed_a_line.strip())
    279     self.assertEqual(self.const_b.name, feed_b_line.strip())
    280 
    281     # Verify short description.
    282     description = cli_shared.get_run_short_description(1, self.const_c,
    283                                                        feed_dict)
    284     self.assertEqual("run #1: 1 fetch (c:0); 2 feeds", description)
    285 
    286   def testTensorFilters(self):
    287     feed_dict = {self.const_a: 10.0}
    288     tensor_filters = {
    289         "filter_a": lambda x: True,
    290         "filter_b": lambda x: False,
    291     }
    292 
    293     run_start_intro = cli_shared.get_run_start_intro(1, self.const_c, feed_dict,
    294                                                      tensor_filters)
    295 
    296     # Verify the listed names of the tensor filters.
    297     filter_names = set()
    298     filter_names.add(run_start_intro.lines[20].split(" ")[-1])
    299     filter_names.add(run_start_intro.lines[21].split(" ")[-1])
    300 
    301     self.assertEqual({"filter_a", "filter_b"}, filter_names)
    302 
    303     # Verify short description.
    304     description = cli_shared.get_run_short_description(1, self.const_c,
    305                                                        feed_dict)
    306     self.assertEqual("run #1: 1 fetch (c:0); 1 feed (a:0)", description)
    307 
    308     # Verify the command links for the two filters.
    309     command_set = set()
    310     annot = run_start_intro.font_attr_segs[20][0]
    311     command_set.add(annot[2].content)
    312     annot = run_start_intro.font_attr_segs[21][0]
    313     command_set.add(annot[2].content)
    314     self.assertEqual({"run -f filter_a", "run -f filter_b"}, command_set)
    315 
    316   def testGetRunShortDescriptionWorksForTensorFeedKey(self):
    317     short_description = cli_shared.get_run_short_description(
    318         1, self.const_a, {self.const_a: 42.0})
    319     self.assertEqual("run #1: 1 fetch (a:0); 1 feed (a:0)", short_description)
    320 
    321   def testGetRunShortDescriptionWorksForUnicodeFeedKey(self):
    322     short_description = cli_shared.get_run_short_description(
    323         1, self.const_a, {u"foo": 42.0})
    324     self.assertEqual("run #1: 1 fetch (a:0); 1 feed (foo)", short_description)
    325 
    326 
    327 class GetErrorIntroTest(test_util.TensorFlowTestCase):
    328 
    329   def setUp(self):
    330     self.var_a = variables.Variable(42.0, name="a")
    331 
    332   def tearDown(self):
    333     ops.reset_default_graph()
    334 
    335   def testShapeError(self):
    336     tf_error = errors.OpError(None, self.var_a.initializer, "foo description",
    337                               None)
    338 
    339     error_intro = cli_shared.get_error_intro(tf_error)
    340 
    341     self.assertEqual("!!! An error occurred during the run !!!",
    342                      error_intro.lines[1])
    343     self.assertEqual([(0, len(error_intro.lines[1]), "blink")],
    344                      error_intro.font_attr_segs[1])
    345 
    346     self.assertEqual(2, error_intro.lines[4].index("ni -a -d -t a/Assign"))
    347     self.assertEqual(2, error_intro.font_attr_segs[4][0][0])
    348     self.assertEqual(22, error_intro.font_attr_segs[4][0][1])
    349     self.assertEqual("ni -a -d -t a/Assign",
    350                      error_intro.font_attr_segs[4][0][2][0].content)
    351     self.assertEqual("bold", error_intro.font_attr_segs[4][0][2][1])
    352 
    353     self.assertEqual(2, error_intro.lines[6].index("li -r a/Assign"))
    354     self.assertEqual(2, error_intro.font_attr_segs[6][0][0])
    355     self.assertEqual(16, error_intro.font_attr_segs[6][0][1])
    356     self.assertEqual("li -r a/Assign",
    357                      error_intro.font_attr_segs[6][0][2][0].content)
    358     self.assertEqual("bold", error_intro.font_attr_segs[6][0][2][1])
    359 
    360     self.assertEqual(2, error_intro.lines[8].index("lt"))
    361     self.assertEqual(2, error_intro.font_attr_segs[8][0][0])
    362     self.assertEqual(4, error_intro.font_attr_segs[8][0][1])
    363     self.assertEqual("lt", error_intro.font_attr_segs[8][0][2][0].content)
    364     self.assertEqual("bold", error_intro.font_attr_segs[8][0][2][1])
    365 
    366     self.assertStartsWith(error_intro.lines[11], "Op name:")
    367     self.assertTrue(error_intro.lines[11].endswith("a/Assign"))
    368 
    369     self.assertStartsWith(error_intro.lines[12], "Error type:")
    370     self.assertTrue(error_intro.lines[12].endswith(str(type(tf_error))))
    371 
    372     self.assertEqual("Details:", error_intro.lines[14])
    373     self.assertStartsWith(error_intro.lines[15], "foo description")
    374 
    375 
    376 if __name__ == "__main__":
    377   googletest.main()
    378