Home | History | Annotate | Download | only in lib
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for tfdbg module debug_data."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from tensorflow.python.debug.lib import debug_graphs
     21 from tensorflow.python.framework import test_util
     22 from tensorflow.python.platform import test
     23 
     24 
     25 class ParseNodeOrTensorNameTest(test_util.TensorFlowTestCase):
     26 
     27   def testParseNodeName(self):
     28     node_name, slot = debug_graphs.parse_node_or_tensor_name(
     29         "namespace1/node_1")
     30 
     31     self.assertEqual("namespace1/node_1", node_name)
     32     self.assertIsNone(slot)
     33 
     34   def testParseTensorName(self):
     35     node_name, slot = debug_graphs.parse_node_or_tensor_name(
     36         "namespace1/node_2:3")
     37 
     38     self.assertEqual("namespace1/node_2", node_name)
     39     self.assertEqual(3, slot)
     40 
     41 
     42 class GetNodeNameAndOutputSlotTest(test_util.TensorFlowTestCase):
     43 
     44   def testParseTensorNameInputWorks(self):
     45     self.assertEqual("a", debug_graphs.get_node_name("a:0"))
     46     self.assertEqual(0, debug_graphs.get_output_slot("a:0"))
     47 
     48     self.assertEqual("_b", debug_graphs.get_node_name("_b:1"))
     49     self.assertEqual(1, debug_graphs.get_output_slot("_b:1"))
     50 
     51   def testParseNodeNameInputWorks(self):
     52     self.assertEqual("a", debug_graphs.get_node_name("a"))
     53     self.assertEqual(0, debug_graphs.get_output_slot("a"))
     54 
     55 
     56 class NodeNameChecksTest(test_util.TensorFlowTestCase):
     57 
     58   def testIsCopyNode(self):
     59     self.assertTrue(debug_graphs.is_copy_node("__copy_ns1/ns2/node3_0"))
     60 
     61     self.assertFalse(debug_graphs.is_copy_node("copy_ns1/ns2/node3_0"))
     62     self.assertFalse(debug_graphs.is_copy_node("_copy_ns1/ns2/node3_0"))
     63     self.assertFalse(debug_graphs.is_copy_node("_copyns1/ns2/node3_0"))
     64     self.assertFalse(debug_graphs.is_copy_node("__dbg_ns1/ns2/node3_0"))
     65 
     66   def testIsDebugNode(self):
     67     self.assertTrue(
     68         debug_graphs.is_debug_node("__dbg_ns1/ns2/node3:0_0_DebugIdentity"))
     69 
     70     self.assertFalse(
     71         debug_graphs.is_debug_node("dbg_ns1/ns2/node3:0_0_DebugIdentity"))
     72     self.assertFalse(
     73         debug_graphs.is_debug_node("_dbg_ns1/ns2/node3:0_0_DebugIdentity"))
     74     self.assertFalse(
     75         debug_graphs.is_debug_node("_dbgns1/ns2/node3:0_0_DebugIdentity"))
     76     self.assertFalse(debug_graphs.is_debug_node("__copy_ns1/ns2/node3_0"))
     77 
     78 
     79 class ParseDebugNodeNameTest(test_util.TensorFlowTestCase):
     80 
     81   def testParseDebugNodeName_valid(self):
     82     debug_node_name_1 = "__dbg_ns_a/ns_b/node_c:1_0_DebugIdentity"
     83     (watched_node, watched_output_slot, debug_op_index,
     84      debug_op) = debug_graphs.parse_debug_node_name(debug_node_name_1)
     85 
     86     self.assertEqual("ns_a/ns_b/node_c", watched_node)
     87     self.assertEqual(1, watched_output_slot)
     88     self.assertEqual(0, debug_op_index)
     89     self.assertEqual("DebugIdentity", debug_op)
     90 
     91   def testParseDebugNodeName_invalidPrefix(self):
     92     invalid_debug_node_name_1 = "__copy_ns_a/ns_b/node_c:1_0_DebugIdentity"
     93 
     94     with self.assertRaisesRegexp(ValueError, "Invalid prefix"):
     95       debug_graphs.parse_debug_node_name(invalid_debug_node_name_1)
     96 
     97   def testParseDebugNodeName_missingDebugOpIndex(self):
     98     invalid_debug_node_name_1 = "__dbg_node1:0_DebugIdentity"
     99 
    100     with self.assertRaisesRegexp(ValueError, "Invalid debug node name"):
    101       debug_graphs.parse_debug_node_name(invalid_debug_node_name_1)
    102 
    103   def testParseDebugNodeName_invalidWatchedTensorName(self):
    104     invalid_debug_node_name_1 = "__dbg_node1_0_DebugIdentity"
    105 
    106     with self.assertRaisesRegexp(ValueError,
    107                                  "Invalid tensor name in debug node name"):
    108       debug_graphs.parse_debug_node_name(invalid_debug_node_name_1)
    109 
    110 
    111 if __name__ == "__main__":
    112   test.main()
    113