Home | History | Annotate | Download | only in lib
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Unit tests for the basic data structures and algorithms for profiling."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.core.framework import step_stats_pb2
     22 from tensorflow.python.debug.lib import profiling
     23 from tensorflow.python.framework import test_util
     24 from tensorflow.python.platform import googletest
     25 
     26 
     27 class AggregateProfile(test_util.TensorFlowTestCase):
     28 
     29   def setUp(self):
     30     node_1 = step_stats_pb2.NodeExecStats(
     31         node_name="Add/123",
     32         op_start_rel_micros=3,
     33         op_end_rel_micros=5,
     34         all_end_rel_micros=4)
     35     self.profile_datum_1 = profiling.ProfileDatum(
     36         "cpu:0", node_1, "/foo/bar.py", 10, "func1", "Add")
     37 
     38     node_2 = step_stats_pb2.NodeExecStats(
     39         node_name="Mul/456",
     40         op_start_rel_micros=13,
     41         op_end_rel_micros=16,
     42         all_end_rel_micros=17)
     43     self.profile_datum_2 = profiling.ProfileDatum(
     44         "cpu:0", node_2, "/foo/bar.py", 11, "func1", "Mul")
     45 
     46     node_3 = step_stats_pb2.NodeExecStats(
     47         node_name="Add/123",
     48         op_start_rel_micros=103,
     49         op_end_rel_micros=105,
     50         all_end_rel_micros=4)
     51     self.profile_datum_3 = profiling.ProfileDatum(
     52         "cpu:0", node_3, "/foo/bar.py", 12, "func1", "Add")
     53 
     54     node_4 = step_stats_pb2.NodeExecStats(
     55         node_name="Add/123",
     56         op_start_rel_micros=203,
     57         op_end_rel_micros=205,
     58         all_end_rel_micros=4)
     59     self.profile_datum_4 = profiling.ProfileDatum(
     60         "gpu:0", node_4, "/foo/bar.py", 13, "func1", "Add")
     61 
     62   def testAggregateProfileConstructorWorks(self):
     63     aggregate_data = profiling.AggregateProfile(self.profile_datum_1)
     64 
     65     self.assertEqual(2, aggregate_data.total_op_time)
     66     self.assertEqual(4, aggregate_data.total_exec_time)
     67     self.assertEqual(1, aggregate_data.node_count)
     68     self.assertEqual(1, aggregate_data.node_exec_count)
     69 
     70   def testAddToAggregateProfileWithDifferentNodeWorks(self):
     71     aggregate_data = profiling.AggregateProfile(self.profile_datum_1)
     72     aggregate_data.add(self.profile_datum_2)
     73 
     74     self.assertEqual(5, aggregate_data.total_op_time)
     75     self.assertEqual(21, aggregate_data.total_exec_time)
     76     self.assertEqual(2, aggregate_data.node_count)
     77     self.assertEqual(2, aggregate_data.node_exec_count)
     78 
     79   def testAddToAggregateProfileWithSameNodeWorks(self):
     80     aggregate_data = profiling.AggregateProfile(self.profile_datum_1)
     81     aggregate_data.add(self.profile_datum_2)
     82     aggregate_data.add(self.profile_datum_3)
     83 
     84     self.assertEqual(7, aggregate_data.total_op_time)
     85     self.assertEqual(25, aggregate_data.total_exec_time)
     86     self.assertEqual(2, aggregate_data.node_count)
     87     self.assertEqual(3, aggregate_data.node_exec_count)
     88 
     89   def testAddToAggregateProfileWithDifferentDeviceSameNodeWorks(self):
     90     aggregate_data = profiling.AggregateProfile(self.profile_datum_1)
     91     aggregate_data.add(self.profile_datum_4)
     92 
     93     self.assertEqual(4, aggregate_data.total_op_time)
     94     self.assertEqual(8, aggregate_data.total_exec_time)
     95     self.assertEqual(2, aggregate_data.node_count)
     96     self.assertEqual(2, aggregate_data.node_exec_count)
     97 
     98 
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
  googletest.main()
    101