# File location: tensorflow/contrib/data/python/kernel_tests
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for the experimental input pipeline statistics gathering ops."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import numpy as np
     21 
     22 from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
     23 from tensorflow.contrib.data.python.ops import stats_ops
     24 from tensorflow.core.framework import summary_pb2
     25 from tensorflow.python.data.ops import dataset_ops
     26 from tensorflow.python.framework import errors
     27 from tensorflow.python.framework import ops
     28 from tensorflow.python.ops import array_ops
     29 from tensorflow.python.platform import test
     30 
     31 
     32 class StatsDatasetTest(test.TestCase):
     33 
     34   def _assertSummaryHasCount(self, summary_str, tag, expected_value):
     35     summary_proto = summary_pb2.Summary()
     36     summary_proto.ParseFromString(summary_str)
     37     for value in summary_proto.value:
     38       if tag == value.tag:
     39         self.assertEqual(expected_value, value.histo.num)
     40         return
     41     self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
     42 
     43   def _assertSummaryHasSum(self, summary_str, tag, expected_value):
     44     summary_proto = summary_pb2.Summary()
     45     summary_proto.ParseFromString(summary_str)
     46     for value in summary_proto.value:
     47       if tag == value.tag:
     48         self.assertEqual(expected_value, value.histo.sum)
     49         return
     50     self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
     51 
     52   def testBytesProduced(self):
     53     dataset = dataset_ops.Dataset.range(100).map(
     54         lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply(
     55             stats_ops.bytes_produced_stats("bytes_produced"))
     56     iterator = dataset.make_initializable_iterator()
     57     stats_aggregator = stats_ops.StatsAggregator()
     58     stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
     59     next_element = iterator.get_next()
     60     summary_t = stats_aggregator.get_summary()
     61 
     62     with self.test_session() as sess:
     63       sess.run([iterator.initializer, stats_aggregator_subscriber])
     64       expected_sum = 0.0
     65       for i in range(100):
     66         self.assertAllEqual(
     67             np.array([i] * i, dtype=np.int64), sess.run(next_element))
     68         summary_str = sess.run(summary_t)
     69         self._assertSummaryHasCount(summary_str, "bytes_produced", float(i + 1))
     70         expected_sum += i * 8.0
     71         self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
     72       with self.assertRaises(errors.OutOfRangeError):
     73         sess.run(next_element)
     74       summary_str = sess.run(summary_t)
     75       self._assertSummaryHasCount(summary_str, "bytes_produced", 100.0)
     76       self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
     77 
     78   def testLatencyStats(self):
     79     dataset = dataset_ops.Dataset.range(100).apply(
     80         stats_ops.latency_stats("record_latency"))
     81     iterator = dataset.make_initializable_iterator()
     82     stats_aggregator = stats_ops.StatsAggregator()
     83     stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
     84     next_element = iterator.get_next()
     85     summary_t = stats_aggregator.get_summary()
     86 
     87     with self.test_session() as sess:
     88       sess.run([iterator.initializer, stats_aggregator_subscriber])
     89       for i in range(100):
     90         self.assertEqual(i, sess.run(next_element))
     91         self._assertSummaryHasCount(
     92             sess.run(summary_t), "record_latency", float(i + 1))
     93       with self.assertRaises(errors.OutOfRangeError):
     94         sess.run(next_element)
     95       self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
     96 
     97   def testReinitialize(self):
     98     dataset = dataset_ops.Dataset.range(100).apply(
     99         stats_ops.latency_stats("record_latency"))
    100     iterator = dataset.make_initializable_iterator()
    101     stats_aggregator = stats_ops.StatsAggregator()
    102     stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
    103     next_element = iterator.get_next()
    104     summary_t = stats_aggregator.get_summary()
    105 
    106     with self.test_session() as sess:
    107       sess.run(stats_aggregator_subscriber)
    108       for j in range(5):
    109         sess.run(iterator.initializer)
    110         for i in range(100):
    111           self.assertEqual(i, sess.run(next_element))
    112           self._assertSummaryHasCount(
    113               sess.run(summary_t), "record_latency", float((j * 100) + i + 1))
    114         with self.assertRaises(errors.OutOfRangeError):
    115           sess.run(next_element)
    116         self._assertSummaryHasCount(
    117             sess.run(summary_t), "record_latency", (j + 1) * 100.0)
    118 
    119   def testNoAggregatorRegistered(self):
    120     dataset = dataset_ops.Dataset.range(100).apply(
    121         stats_ops.latency_stats("record_latency"))
    122     iterator = dataset.make_initializable_iterator()
    123     next_element = iterator.get_next()
    124 
    125     with self.test_session() as sess:
    126       sess.run(iterator.initializer)
    127       for i in range(100):
    128         self.assertEqual(i, sess.run(next_element))
    129       with self.assertRaises(errors.OutOfRangeError):
    130         sess.run(next_element)
    131 
    132   def testMultipleTags(self):
    133     dataset = dataset_ops.Dataset.range(100).apply(
    134         stats_ops.latency_stats("record_latency")).apply(
    135             stats_ops.latency_stats("record_latency_2"))
    136     iterator = dataset.make_initializable_iterator()
    137     stats_aggregator = stats_ops.StatsAggregator()
    138     stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
    139     next_element = iterator.get_next()
    140     summary_t = stats_aggregator.get_summary()
    141 
    142     with self.test_session() as sess:
    143       sess.run([iterator.initializer, stats_aggregator_subscriber])
    144       for i in range(100):
    145         self.assertEqual(i, sess.run(next_element))
    146         self._assertSummaryHasCount(
    147             sess.run(summary_t), "record_latency", float(i + 1))
    148         self._assertSummaryHasCount(
    149             sess.run(summary_t), "record_latency_2", float(i + 1))
    150       with self.assertRaises(errors.OutOfRangeError):
    151         sess.run(next_element)
    152       self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
    153       self._assertSummaryHasCount(
    154           sess.run(summary_t), "record_latency_2", 100.0)
    155 
    156   def testRepeatedTags(self):
    157     dataset = dataset_ops.Dataset.range(100).apply(
    158         stats_ops.latency_stats("record_latency")).apply(
    159             stats_ops.latency_stats("record_latency"))
    160     iterator = dataset.make_initializable_iterator()
    161     stats_aggregator = stats_ops.StatsAggregator()
    162     stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
    163     next_element = iterator.get_next()
    164     summary_t = stats_aggregator.get_summary()
    165 
    166     with self.test_session() as sess:
    167       sess.run([iterator.initializer, stats_aggregator_subscriber])
    168       for i in range(100):
    169         self.assertEqual(i, sess.run(next_element))
    170         self._assertSummaryHasCount(
    171             sess.run(summary_t), "record_latency", float(2 * (i + 1)))
    172       with self.assertRaises(errors.OutOfRangeError):
    173         sess.run(next_element)
    174       self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
    175 
    176   def testMultipleIteratorsSameAggregator(self):
    177     dataset = dataset_ops.Dataset.range(100).apply(
    178         stats_ops.latency_stats("record_latency"))
    179     iterator_0 = dataset.make_initializable_iterator()
    180     iterator_1 = dataset.make_initializable_iterator()
    181     stats_aggregator = stats_ops.StatsAggregator()
    182     stats_aggregator_subscribers = [stats_aggregator.subscribe(iterator_0),
    183                                     stats_aggregator.subscribe(iterator_1)]
    184     next_element = iterator_0.get_next() + iterator_1.get_next()
    185     summary_t = stats_aggregator.get_summary()
    186 
    187     with self.test_session() as sess:
    188       sess.run([iterator_0.initializer, iterator_1.initializer,
    189                 stats_aggregator_subscribers])
    190       for i in range(100):
    191         self.assertEqual(i * 2, sess.run(next_element))
    192         self._assertSummaryHasCount(
    193             sess.run(summary_t), "record_latency", float(2 * (i + 1)))
    194       with self.assertRaises(errors.OutOfRangeError):
    195         sess.run(next_element)
    196       self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
    197 
    198   def testMultipleStatsAggregatorsSameIteratorFail(self):
    199     dataset = dataset_ops.Dataset.range(100).apply(
    200         stats_ops.latency_stats("record_latency"))
    201     iterator = dataset.make_initializable_iterator()
    202     stats_aggregator_0 = stats_ops.StatsAggregator()
    203     stats_aggregator_1 = stats_ops.StatsAggregator()
    204 
    205     with self.test_session() as sess:
    206       sess.run(stats_aggregator_0.subscribe(iterator))
    207       # TODO(mrry): Consider making this allowable (and also allowing
    208       # aggregators to unsubscribe).
    209       with self.assertRaises(errors.FailedPreconditionError):
    210         sess.run(stats_aggregator_1.subscribe(iterator))
    211 
    212 
    213 class StatsDatasetSerializationTest(
    214     dataset_serialization_test_base.DatasetSerializationTestBase):
    215 
    216   def _build_dataset_bytes_stats(self, num_elements):
    217     return dataset_ops.Dataset.range(num_elements).map(
    218         lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply(
    219             stats_ops.bytes_produced_stats("bytes_produced"))
    220 
    221   def testBytesStatsDatasetSaveableCore(self):
    222     num_outputs = 100
    223     self.run_core_tests(
    224         lambda: self._build_dataset_bytes_stats(num_outputs),
    225         lambda: self._build_dataset_bytes_stats(num_outputs // 10), num_outputs)
    226 
    227   def _build_dataset_latency_stats(self, num_elements, tag="record_latency"):
    228     return dataset_ops.Dataset.range(num_elements).apply(
    229         stats_ops.latency_stats(tag))
    230 
    231   def _build_dataset_multiple_tags(self,
    232                                    num_elements,
    233                                    tag1="record_latency",
    234                                    tag2="record_latency_2"):
    235     return dataset_ops.Dataset.range(num_elements).apply(
    236         stats_ops.latency_stats(tag1)).apply(stats_ops.latency_stats(tag2))
    237 
    238   def testLatencyStatsDatasetSaveableCore(self):
    239     num_outputs = 100
    240 
    241     self.run_core_tests(
    242         lambda: self._build_dataset_latency_stats(num_outputs),
    243         lambda: self._build_dataset_latency_stats(num_outputs // 10),
    244         num_outputs)
    245 
    246     self.run_core_tests(lambda: self._build_dataset_multiple_tags(num_outputs),
    247                         None, num_outputs)
    248 
    249     tag1 = "record_latency"
    250     tag2 = "record_latency"
    251     self.run_core_tests(
    252         lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2),
    253         None, num_outputs)
    254 
    255 
# Run all test cases in this module when executed directly.
if __name__ == "__main__":
  test.main()
    258