# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline statistics gathering ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
from tensorflow.python.data.experimental.kernel_tests import stats_dataset_test_base
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.experimental.ops import stats_aggregator
from tensorflow.python.data.experimental.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


def function_set_stats_aggregator(dataset,
                                  aggregator,
                                  prefix="",
                                  counter_prefix=""):
  """Attaches `aggregator` to `dataset` via the `set_stats_aggregator` op."""
  return dataset.apply(
      stats_ops.set_stats_aggregator(aggregator, prefix, counter_prefix))


def function_apply_options(dataset, aggregator, prefix="", counter_prefix=""):
  """Attaches `aggregator` to `dataset` via `tf.data.Options`.

  Mirrors `function_set_stats_aggregator` so the same test body can be run
  against both attachment mechanisms; `latency_all_edges` is disabled so only
  the stats explicitly requested by each test are collected.
  """
  options = dataset_ops.Options()
  options.experimental_stats.aggregator = aggregator
  options.experimental_stats.prefix = prefix
  options.experimental_stats.counter_prefix = counter_prefix
  options.experimental_stats.latency_all_edges = False
  return dataset.with_options(options)


# Each test runs twice: once attaching the aggregator with the
# `set_stats_aggregator` transformation and once via dataset options.
@test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters(
    ("SetStatsAggregator", function_set_stats_aggregator),
    ("StatsOptions", function_apply_options),
)
class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
  """Tests for statistics collected from tf.data input pipelines."""

  def testBytesProduced(self, dataset_transformation):
    aggregator = stats_aggregator.StatsAggregator()
    dataset = dataset_ops.Dataset.range(100).map(
        lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply(
            stats_ops.bytes_produced_stats("bytes_produced"))
    dataset = dataset_transformation(dataset, aggregator)
    next_element = self.getNext(dataset, requires_initialization=True)

    expected_sum = 0.0
    for i in range(100):
      self.assertAllEqual(
          np.array([i] * i, dtype=np.int64), self.evaluate(next_element()))
      summary_str = self.evaluate(aggregator.get_summary())
      self._assertSummaryHasCount(summary_str, "bytes_produced", float(i + 1))
      # Element i is a length-i int64 vector, i.e. 8 * i bytes.
      expected_sum += i * 8.0
      self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())
    summary_str = self.evaluate(aggregator.get_summary())
    self._assertSummaryHasCount(summary_str, "bytes_produced", 100.0)
    self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)

  def testLatencyStats(self, dataset_transformation):
    aggregator = stats_aggregator.StatsAggregator()
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency"))
    dataset = dataset_transformation(dataset, aggregator)
    next_element = self.getNext(dataset, requires_initialization=True)

    for i in range(100):
      self.assertEqual(i, self.evaluate(next_element()))
      self._assertSummaryHasCount(
          self.evaluate(aggregator.get_summary()), "record_latency",
          float(i + 1))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())
    self._assertSummaryHasCount(
        self.evaluate(aggregator.get_summary()), "record_latency", 100.0)

  def testPrefetchBufferUtilization(self, dataset_transformation):
    aggregator = stats_aggregator.StatsAggregator()
    # prefetch(-1) requests an auto-tuned buffer size.
    dataset = dataset_ops.Dataset.range(100).map(
        lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(-1)
    dataset = dataset_transformation(dataset, aggregator)
    next_element = self.getNext(dataset, requires_initialization=True)
    for i in range(100):
      self.assertAllEqual(
          np.array([i] * i, dtype=np.int64), self.evaluate(next_element()))
      summary_str = self.evaluate(aggregator.get_summary())
      self._assertSummaryHasCount(
          summary_str,
          self.regexForNodeName("PrefetchDataset", "buffer_utilization"),
          float(i + 1))
      self._assertSummaryContains(
          summary_str,
          self.regexForNodeName("PrefetchDataset", "buffer_capacity"))
      self._assertSummaryContains(
          summary_str, self.regexForNodeName("PrefetchDataset", "buffer_size"))
      # Utilization is a ratio, so all recorded values must lie in [0, 1].
      self._assertSummaryHasRange(
          summary_str,
          self.regexForNodeName("PrefetchDataset", "buffer_utilization"), 0, 1)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())
    summary_str = self.evaluate(aggregator.get_summary())
    self._assertSummaryHasCount(
        summary_str,
        self.regexForNodeName("PrefetchDataset", "buffer_utilization"), 100)

  def testPrefetchBufferScalars(self, dataset_transformation):
    aggregator = stats_aggregator.StatsAggregator()
    dataset = dataset_ops.Dataset.range(10).map(
        lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(1)
    dataset = dataset_transformation(dataset, aggregator)
    next_element = self.getNext(dataset, requires_initialization=True)

    for i in range(10):
      self.assertAllEqual(
          np.array([i] * i, dtype=np.int64), self.evaluate(next_element()))
      summary_str = self.evaluate(aggregator.get_summary())
      # A fixed prefetch(1) buffer reports capacity 1 and size 1 as scalars.
      self._assertSummaryHasScalarValue(
          summary_str,
          self.regexForNodeName("PrefetchDataset", "buffer_capacity"), 1)
      self._assertSummaryHasScalarValue(
          summary_str, self.regexForNodeName("PrefetchDataset", "buffer_size"),
          1)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())

  def testFilteredElementsStats(self, dataset_transformation):
    aggregator = stats_aggregator.StatsAggregator()
    # Keeps multiples of 3 from [0, 100]: 34 elements pass, 67 are dropped.
    dataset = dataset_ops.Dataset.range(101).filter(
        lambda x: math_ops.equal(math_ops.mod(x, 3), 0))
    dataset = dataset_transformation(dataset, aggregator)
    next_element = self.getNext(dataset, requires_initialization=True)

    for i in range(34):
      self.assertEqual(i * 3, self.evaluate(next_element()))
      summary_str = self.evaluate(aggregator.get_summary())
      if i != 0:
        # Two elements are dropped for each element produced after the first.
        self._assertSummaryHasScalarValue(
            summary_str,
            self.regexForNodeName("FilterDataset", "dropped_elements"),
            float(i * 2))
      self._assertSummaryHasScalarValue(
          summary_str,
          self.regexForNodeName("FilterDataset", "filtered_elements"),
          float(i + 1))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())
    summary_str = self.evaluate(aggregator.get_summary())
    self._assertSummaryHasScalarValue(
        summary_str, self.regexForNodeName("FilterDataset", "dropped_elements"),
        67.0)
    self._assertSummaryHasScalarValue(
        summary_str, self.regexForNodeName("FilterDataset",
                                           "filtered_elements"), 34.0)

  def testMapBufferUtilization(self, dataset_transformation):

    def dataset_fn():
      return dataset_ops.Dataset.range(10).map(
          lambda x: array_ops.tile([x], ops.convert_to_tensor([x])),
          num_parallel_calls=4)

    self._testParallelCallsStats(
        dataset_fn, {self.regexForNodeName("ParallelMapDataset")},
        10,
        dataset_transformation,
        function_processing_time=True)

  def testMapAutoTuneBufferUtilization(self, dataset_transformation):

    def dataset_fn():
      return dataset_ops.Dataset.range(10).map(
          lambda x: array_ops.tile([x], ops.convert_to_tensor([x])),
          num_parallel_calls=optimization.AUTOTUNE)

    self._testParallelCallsStats(
        dataset_fn, {self.regexForNodeName("ParallelMapDataset")},
        10,
        dataset_transformation,
        function_processing_time=True)

  def testInterleaveAutoTuneBufferUtilization(self, dataset_transformation):

    def dataset_fn():

      def interleave_fn(_):
        return dataset_ops.Dataset.range(
            10).map(lambda x: array_ops.tile([x], ops.convert_to_tensor([x])))

      return dataset_ops.Dataset.range(1).interleave(
          interleave_fn,
          cycle_length=1,
          num_parallel_calls=optimization.AUTOTUNE)

    self._testParallelCallsStats(
        dataset_fn, {self.regexForNodeName("ParallelInterleaveDatasetV2")}, 10,
        dataset_transformation)

  def testMapAndBatchAutoTuneBufferUtilization(self, dataset_transformation):

    def dataset_fn():
      return dataset_ops.Dataset.range(100).apply(
          batching.map_and_batch(
              lambda x: array_ops.tile([x], ops.convert_to_tensor([2])),
              num_parallel_calls=optimization.AUTOTUNE,
              batch_size=16))

    # 100 elements in batches of 16 yields 6 full batches plus a partial one.
    num_output = 100 // 16 + 1
    self._testParallelCallsStats(
        dataset_fn, {self.regexForNodeName("ExperimentalMapAndBatchDataset")},
        num_output,
        dataset_transformation,
        check_elements=False,
        function_processing_time=True)

  def testReinitialize(self, dataset_transformation):
    aggregator = stats_aggregator.StatsAggregator()
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency"))
    dataset = dataset_transformation(dataset, aggregator)

    # Counts accumulate across re-created iterators over the same dataset.
    for j in range(5):
      next_element = self.getNext(dataset, requires_initialization=True)
      for i in range(100):
        self.assertEqual(i, self.evaluate(next_element()))
        self._assertSummaryHasCount(
            self.evaluate(aggregator.get_summary()), "record_latency",
            float((j * 100) + i + 1))
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_element())
      self._assertSummaryHasCount(
          self.evaluate(aggregator.get_summary()), "record_latency",
          (j + 1) * 100.0)

  def testNoAggregatorRegistered(self, dataset_transformation):
    # Stats ops without an attached aggregator must be a silent no-op.
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency"))

    next_element = self.getNext(dataset, requires_initialization=True)

    for i in range(100):
      self.assertEqual(i, self.evaluate(next_element()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())

  def testMultipleTags(self, dataset_transformation):
    aggregator = stats_aggregator.StatsAggregator()
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency")).apply(
            stats_ops.latency_stats("record_latency_2"))
    dataset = dataset_transformation(dataset, aggregator)

    next_element = self.getNext(dataset, requires_initialization=True)

    for i in range(100):
      self.assertEqual(i, self.evaluate(next_element()))
      self._assertSummaryHasCount(
          self.evaluate(aggregator.get_summary()), "record_latency",
          float(i + 1))
      self._assertSummaryHasCount(
          self.evaluate(aggregator.get_summary()), "record_latency_2",
          float(i + 1))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())
    self._assertSummaryHasCount(
        self.evaluate(aggregator.get_summary()), "record_latency", 100.0)
    self._assertSummaryHasCount(
        self.evaluate(aggregator.get_summary()), "record_latency_2", 100.0)

  def testRepeatedTags(self, dataset_transformation):
    aggregator = stats_aggregator.StatsAggregator()
    # Two stats ops with the same tag merge into one counter.
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency")).apply(
            stats_ops.latency_stats("record_latency"))
    dataset = dataset_transformation(dataset, aggregator)
    next_element = self.getNext(dataset, requires_initialization=True)

    for i in range(100):
      self.assertEqual(i, self.evaluate(next_element()))
      self._assertSummaryHasCount(
          self.evaluate(aggregator.get_summary()), "record_latency",
          float(2 * (i + 1)))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())
    self._assertSummaryHasCount(
        self.evaluate(aggregator.get_summary()), "record_latency", 200.0)

  def testMultipleIteratorsSameAggregator(self, dataset_transformation):
    aggregator = stats_aggregator.StatsAggregator()
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency"))
    dataset = dataset_transformation(dataset, aggregator)
    next_element1 = self.getNext(dataset, requires_initialization=True)
    next_element2 = self.getNext(dataset, requires_initialization=True)

    for i in range(100):
      self.assertEqual(i * 2, self.evaluate(next_element1() + next_element2()))
      self._assertSummaryHasCount(
          self.evaluate(aggregator.get_summary()), "record_latency",
          float(2 * (i + 1)))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element1())
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element2())
    self._assertSummaryHasCount(
        self.evaluate(aggregator.get_summary()), "record_latency", 200.0)

  def testMultipleDatasetWithPrefixes(self, dataset_transformation):
    aggregator = stats_aggregator.StatsAggregator()
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency"))
    dataset = dataset_transformation(dataset, aggregator, prefix="dataset1")
    dataset2 = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency"))
    dataset2 = dataset_transformation(dataset2, aggregator, prefix="dataset2")
    next_element1 = self.getNext(dataset, requires_initialization=True)
    next_element2 = self.getNext(dataset2, requires_initialization=True)

    for i in range(100):
      self.assertEqual(i * 2, self.evaluate(next_element1() + next_element2()))
      # Prefixes keep the two pipelines' identically-tagged stats separate.
      self._assertSummaryHasCount(
          self.evaluate(aggregator.get_summary()), "dataset1_record_latency",
          float(i + 1))
      self._assertSummaryHasCount(
          self.evaluate(aggregator.get_summary()), "dataset2_record_latency",
          float(i + 1))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element1())
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element2())
    self._assertSummaryHasCount(
        self.evaluate(aggregator.get_summary()), "dataset1_record_latency",
        100.0)
    self._assertSummaryHasCount(
        self.evaluate(aggregator.get_summary()), "dataset2_record_latency",
        100.0)

  def testMultiplePrefetchStats(self, dataset_transformation):

    aggregator = stats_aggregator.StatsAggregator()
    dataset = dataset_ops.Dataset.range(10).prefetch(
        2).map(lambda x: math_ops.add(x, 2)).prefetch(1)

    dataset = dataset_transformation(dataset, aggregator)
    next_element = self.getNext(dataset, requires_initialization=True)

    for i in range(10):
      self.assertEqual(i + 2, self.evaluate(next_element()))
      summary_str = self.evaluate(aggregator.get_summary())
      # TODO(shivaniagarwal): using exact name of prefetch node than the regex,
      # to differentiate between two prefetch. This might break in future, at
      # which point, it would be best to disable this test.
      self._assertSummaryHasScalarValue(
          summary_str, "PrefetchDataset/_5::buffer_capacity", 2)
      self._assertSummaryContains(summary_str,
                                  "PrefetchDataset/_5::buffer_size")
      self._assertSummaryHasScalarValue(
          summary_str, "PrefetchDataset/_8::buffer_capacity", 1)
      self._assertSummaryContains(summary_str,
                                  "PrefetchDataset/_8::buffer_size")
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())


@test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters(
    ("SetStatsAggregator", function_set_stats_aggregator),
    ("StatsOptions", function_apply_options)
)
class FeatureStatsDatasetTest(
    stats_dataset_test_base.StatsDatasetTestBase,
    reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase):
  """Tests for feature statistics collected by the parse-example dataset."""

  def testFeaturesStats(self, dataset_transformation):
    num_epochs = 5
    total_records = num_epochs * self._num_records
    batch_size = 2
    aggregator = stats_aggregator.StatsAggregator()

    def dataset_fn():
      return self.make_batch_feature(
          filenames=self.test_filenames[0],
          num_epochs=num_epochs,
          batch_size=batch_size,
          shuffle=True,
          shuffle_seed=5,
          drop_final_batch=False)

    # Account for a final partial batch when the record count is not a
    # multiple of the batch size.
    num_output = total_records // batch_size
    if total_records % batch_size:
      num_output = total_records // batch_size + 1

    self._testParallelCallsStats(
        dataset_fn, {self.regexForNodeName("ExperimentalParseExampleDataset")},
        num_output,
        dataset_transformation,
        check_elements=False)

    dataset = dataset_transformation(
        dataset_fn(), aggregator, prefix="record_stats")

    next_element = self.getNext(dataset, requires_initialization=True)

    for _ in range(num_output):
      self.evaluate(next_element())

    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())
    self._assertSummaryHasCount(
        self.evaluate(aggregator.get_summary()),
        self.regexForNodeName("record_stats_ExperimentalParseExampleDataset",
                              "features_count"), total_records)
    self._assertSummaryHasCount(
        self.evaluate(aggregator.get_summary()),
        self.regexForNodeName("record_stats_ExperimentalParseExampleDataset",
                              "feature_values_count"), total_records)
    self._assertSummaryHasSum(
        self.evaluate(aggregator.get_summary()),
        self.regexForNodeName("record_stats_ExperimentalParseExampleDataset",
                              "features_count"), total_records * 4)
    self._assertSummaryHasSum(
        self.evaluate(aggregator.get_summary()),
        self.regexForNodeName("record_stats_ExperimentalParseExampleDataset",
                              "feature_values_count"),
        self._sum_keywords(1) * num_epochs + 3 * total_records)


if __name__ == "__main__":
  test.main()