# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for the basic data structures and algorithms for profiling."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import step_stats_pb2
from tensorflow.python.debug.lib import profiling
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest


def _make_profile_datum(device, node_name, op_start, op_end, all_end,
                        line_number, op_type):
  """Builds a `profiling.ProfileDatum` around a fresh `NodeExecStats`.

  Every datum produced here points at the same source file ("/foo/bar.py")
  and function ("func1"); only the arguments below vary between fixtures.

  Args:
    device: Device name string handed to `ProfileDatum`, e.g. "cpu:0".
    node_name: Value for `NodeExecStats.node_name`.
    op_start: Value for `NodeExecStats.op_start_rel_micros`.
    op_end: Value for `NodeExecStats.op_end_rel_micros`.
    all_end: Value for `NodeExecStats.all_end_rel_micros`.
    line_number: Line number handed to `ProfileDatum`.
    op_type: Op type string handed to `ProfileDatum`, e.g. "Add".

  Returns:
    The constructed `profiling.ProfileDatum`.
  """
  stats = step_stats_pb2.NodeExecStats(
      node_name=node_name,
      op_start_rel_micros=op_start,
      op_end_rel_micros=op_end,
      all_end_rel_micros=all_end)
  return profiling.ProfileDatum(
      device, stats, "/foo/bar.py", line_number, "func1", op_type)


class AggregateProfile(test_util.TensorFlowTestCase):
  """Tests construction of `profiling.AggregateProfile` and its `add()`."""

  def setUp(self):
    """Creates four profile data points shared by the tests.

    Data points 1 and 3 use the same device and node name ("cpu:0",
    "Add/123"), data point 2 is a different node on the same device, and
    data point 4 reuses the node name "Add/123" on a different device
    ("gpu:0").
    """
    self.profile_datum_1 = _make_profile_datum(
        "cpu:0", "Add/123", 3, 5, 4, 10, "Add")
    self.profile_datum_2 = _make_profile_datum(
        "cpu:0", "Mul/456", 13, 16, 17, 11, "Mul")
    self.profile_datum_3 = _make_profile_datum(
        "cpu:0", "Add/123", 103, 105, 4, 12, "Add")
    self.profile_datum_4 = _make_profile_datum(
        "gpu:0", "Add/123", 203, 205, 4, 13, "Add")

  def _check_aggregate(self, aggregate, op_time, exec_time, node_count,
                       node_exec_count):
    """Asserts the four summary attributes of an aggregate profile.

    NOTE(review): the expected values used by the tests agree with op time
    being `op_end_rel_micros - op_start_rel_micros` and exec time being
    `all_end_rel_micros`; confirm against the `profiling` module.

    Args:
      aggregate: The `profiling.AggregateProfile` under test.
      op_time: Expected `total_op_time`.
      exec_time: Expected `total_exec_time`.
      node_count: Expected `node_count`.
      node_exec_count: Expected `node_exec_count`.
    """
    self.assertEqual(op_time, aggregate.total_op_time)
    self.assertEqual(exec_time, aggregate.total_exec_time)
    self.assertEqual(node_count, aggregate.node_count)
    self.assertEqual(node_exec_count, aggregate.node_exec_count)

  def testAggregateProfileConstructorWorks(self):
    """An aggregate built from a single datum reports that datum alone."""
    aggregate = profiling.AggregateProfile(self.profile_datum_1)

    self._check_aggregate(
        aggregate, op_time=2, exec_time=4, node_count=1, node_exec_count=1)

  def testAddToAggregateProfileWithDifferentNodeWorks(self):
    """Adding a datum for another node raises both node counters."""
    aggregate = profiling.AggregateProfile(self.profile_datum_1)
    aggregate.add(self.profile_datum_2)

    self._check_aggregate(
        aggregate, op_time=5, exec_time=21, node_count=2, node_exec_count=2)

  def testAddToAggregateProfileWithSameNodeWorks(self):
    """Re-adding an already-seen node bumps only the execution count."""
    aggregate = profiling.AggregateProfile(self.profile_datum_1)
    for datum in (self.profile_datum_2, self.profile_datum_3):
      aggregate.add(datum)

    self._check_aggregate(
        aggregate, op_time=7, exec_time=25, node_count=2, node_exec_count=3)

  def testAddToAggregateProfileWithDifferentDeviceSameNodeWorks(self):
    """The same node name on another device is counted as a distinct node."""
    aggregate = profiling.AggregateProfile(self.profile_datum_1)
    aggregate.add(self.profile_datum_4)

    self._check_aggregate(
        aggregate, op_time=4, exec_time=8, node_count=2, node_exec_count=2)


if __name__ == "__main__":
  googletest.main()