1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Base class for testing the input pipeline statistics gathering ops.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import re 21 import numpy as np 22 23 from tensorflow.core.framework import summary_pb2 24 from tensorflow.python.data.experimental.ops import stats_aggregator 25 from tensorflow.python.data.kernel_tests import test_base 26 from tensorflow.python.framework import errors 27 28 29 class StatsDatasetTestBase(test_base.DatasetTestBase): 30 """Base class for testing statistics gathered in `StatsAggregator`.""" 31 32 def regexForNodeName(self, op_name, stats_type=""): 33 return "".join([op_name, r"/_\d+::", stats_type]) 34 35 def _assertSummaryContains(self, summary_str, tag): 36 summary_proto = summary_pb2.Summary() 37 summary_proto.ParseFromString(summary_str) 38 for value in summary_proto.value: 39 if re.match(tag, value.tag): 40 return 41 self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) 42 43 def _assertSummaryHasCount(self, 44 summary_str, 45 tag, 46 expected_value, 47 greater_than=False): 48 summary_proto = summary_pb2.Summary() 49 summary_proto.ParseFromString(summary_str) 50 for value in summary_proto.value: 51 if re.match(tag, value.tag): 52 if greater_than: 53 self.assertGreaterEqual(value.histo.num, expected_value) 54 else: 55 self.assertEqual(expected_value, value.histo.num) 56 return 57 self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) 58 59 def _assertSummaryHasRange(self, summary_str, tag, min_value, max_value): 60 summary_proto = summary_pb2.Summary() 61 summary_proto.ParseFromString(summary_str) 62 for value in summary_proto.value: 63 if re.match(tag, value.tag): 64 self.assertLessEqual(min_value, value.histo.min) 65 self.assertGreaterEqual(max_value, value.histo.max) 66 return 67 self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) 68 69 def _assertSummaryHasSum(self, summary_str, tag, expected_value): 70 summary_proto = summary_pb2.Summary() 71 summary_proto.ParseFromString(summary_str) 72 for value in summary_proto.value: 73 if re.match(tag, value.tag): 74 self.assertEqual(expected_value, value.histo.sum) 75 return 76 self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) 77 78 def _assertSummaryHasScalarValue(self, summary_str, tag, expected_value): 79 summary_proto = summary_pb2.Summary() 80 summary_proto.ParseFromString(summary_str) 81 for value in summary_proto.value: 82 if re.match(tag, value.tag): 83 self.assertEqual(expected_value, value.simple_value) 84 return 85 self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) 86 87 def _testParallelCallsStats(self, 88 dataset_fn, 89 dataset_names, 90 num_output, 91 dataset_transformation, 92 function_processing_time=False, 93 check_elements=True): 94 aggregator = stats_aggregator.StatsAggregator() 95 dataset = dataset_fn() 96 dataset = dataset_transformation(dataset, aggregator) 97 next_element = self.getNext(dataset, requires_initialization=True) 98 99 for i in range(num_output): 100 next_ = self.evaluate(next_element()) 101 if check_elements: 102 self.assertAllEqual(np.array([i] * i, dtype=np.int64), next_) 103 summary_str = self.evaluate(aggregator.get_summary()) 104 for dataset_name in dataset_names: 105 if function_processing_time: 106 self._assertSummaryHasCount( 107 summary_str, 108 r"(.*)::execution_time$", 109 float(i + 1), 110 greater_than=True) 111 self._assertSummaryContains(summary_str, 112 dataset_name + "thread_utilization") 113 with self.assertRaises(errors.OutOfRangeError): 114 self.evaluate(next_element()) 115 if function_processing_time: 116 summary_str = self.evaluate(aggregator.get_summary()) 117 for dataset_name in dataset_names: 118 self._assertSummaryHasCount( 119 summary_str, 120 r"(.*)::execution_time$", 121 float(num_output), 122 greater_than=True) 123