# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """Tests for the experimental input pipeline statistics gathering ops."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function

from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
from tensorflow.python.data.experimental.kernel_tests import stats_dataset_test_base
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.experimental.ops import stats_aggregator
from tensorflow.python.data.experimental.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


def function_set_stats_aggregator(dataset,
                                  aggregator,
                                  prefix="",
                                  counter_prefix=""):
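  """Attaches `aggregator` via the `set_stats_aggregator` transformation."""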
  return dataset.apply(
      stats_ops.set_stats_aggregator(aggregator, prefix, counter_prefix))


def function_apply_options(dataset, aggregator, prefix="", counter_prefix=""):
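  """Attaches `aggregator` to `dataset` through `Options`, with
  `latency_all_edges` disabled."""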
  options = dataset_ops.Options()
  options.experimental_stats.aggregator = aggregator
  options.experimental_stats.prefix = prefix
  options.experimental_stats.counter_prefix = counter_prefix
  options.experimental_stats.latency_all_edges = False
  return dataset.with_options(options)


@test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters(
    ("SetStatsAggregator", function_set_stats_aggregator),
    ("StatsOptions", function_apply_options),
)
class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
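  """Tests stats gathering, parameterized over both ways of attaching a
  `StatsAggregator`: the `set_stats_aggregator` transformation and `tf.data`
  options.
  """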

  def testBytesProduced(self, dataset_transformation):
    aggregator = stats_aggregator.StatsAggregator()
    dataset = dataset_ops.Dataset.range(100).map(
        lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply(
            stats_ops.bytes_produced_stats("bytes_produced"))
    dataset = dataset_transformation(dataset, aggregator)
    next_element = self.getNext(dataset, requires_initialization=True)

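    # Element i is a vector of i int64 values, so it contributes 8 * i bytes
    # to the running sum.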
    expected_sum = 0.0
    for i in range(100):
      self.assertAllEqual(
          np.array([i] * i, dtype=np.int64), self.evaluate(next_element()))
      summary_str = self.evaluate(aggregator.get_summary())
      self._assertSummaryHasCount(summary_str, "bytes_produced", float(i + 1))
      expected_sum += i * 8.0
      self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())
    summary_str = self.evaluate(aggregator.get_summary())
    self._assertSummaryHasCount(summary_str, "bytes_produced", 100.0)
    self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)

  def testLatencyStats(self, dataset_transformation):
    aggregator = stats_aggregator.StatsAggregator()
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency"))
    dataset = dataset_transformation(dataset, aggregator)
    next_element = self.getNext(dataset, requires_initialization=True)

    for i in range(100):
      self.assertEqual(i, self.evaluate(next_element()))
      self._assertSummaryHasCount(
          self.evaluate(aggregator.get_summary()), "record_latency",
          float(i + 1))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())
    self._assertSummaryHasCount(
        self.evaluate(aggregator.get_summary()), "record_latency", 100.0)

  def testPrefetchBufferUtilization(self, dataset_transformation):
    aggregator = stats_aggregator.StatsAggregator()
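    # prefetch(-1) auto-tunes the buffer size, so the test only checks that
    # buffer capacity and size are reported, not their exact values.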
    dataset = dataset_ops.Dataset.range(100).map(
        lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(-1)
    dataset = dataset_transformation(dataset, aggregator)
    next_element = self.getNext(dataset, requires_initialization=True)
    for i in range(100):
      self.assertAllEqual(
          np.array([i] * i, dtype=np.int64), self.evaluate(next_element()))
      summary_str = self.evaluate(aggregator.get_summary())
      self._assertSummaryHasCount(
          summary_str,
          self.regexForNodeName("PrefetchDataset", "buffer_utilization"),
          float(i + 1))
      self._assertSummaryContains(
          summary_str,
          self.regexForNodeName("PrefetchDataset", "buffer_capacity"))
      self._assertSummaryContains(
          summary_str, self.regexForNodeName("PrefetchDataset", "buffer_size"))
      self._assertSummaryHasRange(
          summary_str,
          self.regexForNodeName("PrefetchDataset", "buffer_utilization"), 0, 1)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())
    summary_str = self.evaluate(aggregator.get_summary())
    self._assertSummaryHasCount(
        summary_str,
        self.regexForNodeName("PrefetchDataset", "buffer_utilization"), 100)

  def testPrefetchBufferScalars(self, dataset_transformation):
    aggregator = stats_aggregator.StatsAggregator()
    dataset = dataset_ops.Dataset.range(10).map(
        lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(1)
    dataset = dataset_transformation(dataset, aggregator)
    next_element = self.getNext(dataset, requires_initialization=True)

    for i in range(10):
      self.assertAllEqual(
          np.array([i] * i, dtype=np.int64), self.evaluate(next_element()))
      summary_str = self.evaluate(aggregator.get_summary())
      self._assertSummaryHasScalarValue(
          summary_str,
          self.regexForNodeName("PrefetchDataset", "buffer_capacity"), 1)
      self._assertSummaryHasScalarValue(
          summary_str, self.regexForNodeName("PrefetchDataset", "buffer_size"),
          1)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())

  def testFilteredElementsStats(self, dataset_transformation):
    aggregator = stats_aggregator.StatsAggregator()
    dataset = dataset_ops.Dataset.range(101).filter(
        lambda x: math_ops.equal(math_ops.mod(x, 3), 0))
    dataset = dataset_transformation(dataset, aggregator)
    next_element = self.getNext(dataset, requires_initialization=True)

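    # range(101) contains 34 multiples of 3; between consecutive kept
    # elements the filter drops two values, for 67 dropped in total.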
    for i in range(34):
      self.assertEqual(i * 3, self.evaluate(next_element()))
      summary_str = self.evaluate(aggregator.get_summary())
      if i != 0:
        self._assertSummaryHasScalarValue(
            summary_str,
            self.regexForNodeName("FilterDataset", "dropped_elements"),
            float(i * 2))
      self._assertSummaryHasScalarValue(
          summary_str,
          self.regexForNodeName("FilterDataset", "filtered_elements"),
          float(i + 1))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())
    summary_str = self.evaluate(aggregator.get_summary())
    self._assertSummaryHasScalarValue(
        summary_str, self.regexForNodeName("FilterDataset", "dropped_elements"),
        67.0)
    self._assertSummaryHasScalarValue(
        summary_str, self.regexForNodeName("FilterDataset",
                                           "filtered_elements"), 34.0)

  def testMapBufferUtilization(self, dataset_transformation):

    def dataset_fn():
      return dataset_ops.Dataset.range(10).map(
          lambda x: array_ops.tile([x], ops.convert_to_tensor([x])),
          num_parallel_calls=4)

    self._testParallelCallsStats(
        dataset_fn, {self.regexForNodeName("ParallelMapDataset")},
        10,
        dataset_transformation,
        function_processing_time=True)

  def testMapAutoTuneBufferUtilization(self, dataset_transformation):

    def dataset_fn():
      return dataset_ops.Dataset.range(10).map(
          lambda x: array_ops.tile([x], ops.convert_to_tensor([x])),
          num_parallel_calls=optimization.AUTOTUNE)

    self._testParallelCallsStats(
        dataset_fn, {self.regexForNodeName("ParallelMapDataset")},
        10,
        dataset_transformation,
        function_processing_time=True)

  def testInterleaveAutoTuneBufferUtilization(self, dataset_transformation):

    def dataset_fn():

      def interleave_fn(_):
        return dataset_ops.Dataset.range(
            10).map(lambda x: array_ops.tile([x], ops.convert_to_tensor([x])))

      return dataset_ops.Dataset.range(1).interleave(
          interleave_fn,
          cycle_length=1,
          num_parallel_calls=optimization.AUTOTUNE)

    self._testParallelCallsStats(
        dataset_fn, {self.regexForNodeName("ParallelInterleaveDatasetV2")}, 10,
        dataset_transformation)

  def testMapAndBatchAutoTuneBufferUtilization(self, dataset_transformation):

    def dataset_fn():
      return dataset_ops.Dataset.range(100).apply(
          batching.map_and_batch(
              lambda x: array_ops.tile([x], ops.convert_to_tensor([2])),
              num_parallel_calls=optimization.AUTOTUNE,
              batch_size=16))

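    # 100 elements batched by 16 yield six full batches plus one partial
    # batch.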
    num_output = 100 // 16 + 1
    self._testParallelCallsStats(
        dataset_fn, {self.regexForNodeName("ExperimentalMapAndBatchDataset")},
        num_output,
        dataset_transformation,
        check_elements=False,
        function_processing_time=True)

  def testReinitialize(self, dataset_transformation):
    aggregator = stats_aggregator.StatsAggregator()
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency"))
    dataset = dataset_transformation(dataset, aggregator)

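    # The aggregator outlives each iterator, so the "record_latency" count
    # keeps accumulating across the five reinitializations.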
    for j in range(5):
      next_element = self.getNext(dataset, requires_initialization=True)
      for i in range(100):
        self.assertEqual(i, self.evaluate(next_element()))
        self._assertSummaryHasCount(
            self.evaluate(aggregator.get_summary()), "record_latency",
            float((j * 100) + i + 1))
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_element())
      self._assertSummaryHasCount(
          self.evaluate(aggregator.get_summary()), "record_latency",
          (j + 1) * 100.0)

  def testNoAggregatorRegistered(self, dataset_transformation):
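    # With no aggregator attached, the stats op acts as a no-op and the
    # pipeline's output is unaffected.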
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency"))

    next_element = self.getNext(dataset, requires_initialization=True)

    for i in range(100):
      self.assertEqual(i, self.evaluate(next_element()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())

  def testMultipleTags(self, dataset_transformation):
    aggregator = stats_aggregator.StatsAggregator()
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency")).apply(
            stats_ops.latency_stats("record_latency_2"))
    dataset = dataset_transformation(dataset, aggregator)

    next_element = self.getNext(dataset, requires_initialization=True)

    for i in range(100):
      self.assertEqual(i, self.evaluate(next_element()))
      self._assertSummaryHasCount(
          self.evaluate(aggregator.get_summary()), "record_latency",
          float(i + 1))
      self._assertSummaryHasCount(
          self.evaluate(aggregator.get_summary()), "record_latency_2",
          float(i + 1))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())
    self._assertSummaryHasCount(
        self.evaluate(aggregator.get_summary()), "record_latency", 100.0)
    self._assertSummaryHasCount(
        self.evaluate(aggregator.get_summary()), "record_latency_2", 100.0)

  def testRepeatedTags(self, dataset_transformation):
    aggregator = stats_aggregator.StatsAggregator()
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency")).apply(
            stats_ops.latency_stats("record_latency"))
    dataset = dataset_transformation(dataset, aggregator)
    next_element = self.getNext(dataset, requires_initialization=True)

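    # The same tag is registered twice, so each element increments the
    # "record_latency" count by two.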
    for i in range(100):
      self.assertEqual(i, self.evaluate(next_element()))
      self._assertSummaryHasCount(
          self.evaluate(aggregator.get_summary()), "record_latency",
          float(2 * (i + 1)))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())
    self._assertSummaryHasCount(
        self.evaluate(aggregator.get_summary()), "record_latency", 200.0)

  def testMultipleIteratorsSameAggregator(self, dataset_transformation):
    aggregator = stats_aggregator.StatsAggregator()
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency"))
    dataset = dataset_transformation(dataset, aggregator)
    next_element1 = self.getNext(dataset, requires_initialization=True)
    next_element2 = self.getNext(dataset, requires_initialization=True)

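    # Both iterators feed the same aggregator, so the count advances by two
    # per loop iteration.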
    for i in range(100):
      self.assertEqual(i * 2, self.evaluate(next_element1() + next_element2()))
      self._assertSummaryHasCount(
          self.evaluate(aggregator.get_summary()), "record_latency",
          float(2 * (i + 1)))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element1())
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element2())
    self._assertSummaryHasCount(
        self.evaluate(aggregator.get_summary()), "record_latency", 200.0)

  def testMultipleDatasetWithPrefixes(self, dataset_transformation):
    aggregator = stats_aggregator.StatsAggregator()
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency"))
    dataset = dataset_transformation(dataset, aggregator, prefix="dataset1")
    dataset2 = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency"))
    dataset2 = dataset_transformation(dataset2, aggregator, prefix="dataset2")
    next_element1 = self.getNext(dataset, requires_initialization=True)
    next_element2 = self.getNext(dataset2, requires_initialization=True)

    for i in range(100):
      self.assertEqual(i * 2, self.evaluate(next_element1() + next_element2()))
      self._assertSummaryHasCount(
          self.evaluate(aggregator.get_summary()), "dataset1_record_latency",
          float(i + 1))
      self._assertSummaryHasCount(
          self.evaluate(aggregator.get_summary()), "dataset2_record_latency",
          float(i + 1))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element1())
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element2())
    self._assertSummaryHasCount(
        self.evaluate(aggregator.get_summary()), "dataset1_record_latency",
        100.0)
    self._assertSummaryHasCount(
        self.evaluate(aggregator.get_summary()), "dataset2_record_latency",
        100.0)

  def testMultiplePrefetchStats(self, dataset_transformation):

    aggregator = stats_aggregator.StatsAggregator()
    dataset = dataset_ops.Dataset.range(10).prefetch(
        2).map(lambda x: math_ops.add(x, 2)).prefetch(1)

    dataset = dataset_transformation(dataset, aggregator)
    next_element = self.getNext(dataset, requires_initialization=True)

    for i in range(10):
      self.assertEqual(i + 2, self.evaluate(next_element()))
      summary_str = self.evaluate(aggregator.get_summary())
      # TODO(shivaniagarwal): we use the exact names of the prefetch nodes
      # rather than a regex to differentiate between the two prefetches; this
      # might break in the future, at which point it would be best to disable
      # this test.
      self._assertSummaryHasScalarValue(
          summary_str, "PrefetchDataset/_5::buffer_capacity", 2)
      self._assertSummaryContains(summary_str,
                                  "PrefetchDataset/_5::buffer_size")
      self._assertSummaryHasScalarValue(
          summary_str, "PrefetchDataset/_8::buffer_capacity", 1)
      self._assertSummaryContains(summary_str,
                                  "PrefetchDataset/_8::buffer_size")
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())


@test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters(
    ("SetStatsAggregator", function_set_stats_aggregator),
    ("StatsOptions", function_apply_options)
)
class FeatureStatsDatasetTest(
    stats_dataset_test_base.StatsDatasetTestBase,
    reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase):
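  """Tests feature statistics gathered while parsing batched Example
  features.
  """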

  def testFeaturesStats(self, dataset_transformation):
    num_epochs = 5
    total_records = num_epochs * self._num_records
    batch_size = 2
    aggregator = stats_aggregator.StatsAggregator()

    def dataset_fn():
      return self.make_batch_feature(
          filenames=self.test_filenames[0],
          num_epochs=num_epochs,
          batch_size=batch_size,
          shuffle=True,
          shuffle_seed=5,
          drop_final_batch=False)

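    # Ceiling division: a final partial batch counts as one more output
    # batch.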
    num_output = total_records // batch_size
    if total_records % batch_size:
      num_output = total_records // batch_size + 1

    self._testParallelCallsStats(
        dataset_fn, {self.regexForNodeName("ExperimentalParseExampleDataset")},
        num_output,
        dataset_transformation,
        check_elements=False)

    dataset = dataset_transformation(
        dataset_fn(), aggregator, prefix="record_stats")

    next_element = self.getNext(dataset, requires_initialization=True)

    for _ in range(num_output):
      self.evaluate(next_element())

    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())
    self._assertSummaryHasCount(
        self.evaluate(aggregator.get_summary()),
        self.regexForNodeName("record_stats_ExperimentalParseExampleDataset",
                              "features_count"), total_records)
    self._assertSummaryHasCount(
        self.evaluate(aggregator.get_summary()),
        self.regexForNodeName("record_stats_ExperimentalParseExampleDataset",
                              "feature_values_count"), total_records)
    self._assertSummaryHasSum(
        self.evaluate(aggregator.get_summary()),
        self.regexForNodeName("record_stats_ExperimentalParseExampleDataset",
                              "features_count"), total_records * 4)
    self._assertSummaryHasSum(
        self.evaluate(aggregator.get_summary()),
        self.regexForNodeName("record_stats_ExperimentalParseExampleDataset",
                              "feature_values_count"),
        self._sum_keywords(1) * num_epochs + 3 * total_records)


if __name__ == "__main__":
  test.main()