Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Base class for testing the input pipeline statistics gathering ops."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import re
     21 import numpy as np
     22 
     23 from tensorflow.core.framework import summary_pb2
     24 from tensorflow.python.data.experimental.ops import stats_aggregator
     25 from tensorflow.python.data.kernel_tests import test_base
     26 from tensorflow.python.framework import errors
     27 
     28 
     29 class StatsDatasetTestBase(test_base.DatasetTestBase):
     30   """Base class for testing statistics gathered in `StatsAggregator`."""
     31 
     32   def regexForNodeName(self, op_name, stats_type=""):
     33     return "".join([op_name, r"/_\d+::", stats_type])
     34 
     35   def _assertSummaryContains(self, summary_str, tag):
     36     summary_proto = summary_pb2.Summary()
     37     summary_proto.ParseFromString(summary_str)
     38     for value in summary_proto.value:
     39       if re.match(tag, value.tag):
     40         return
     41     self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
     42 
     43   def _assertSummaryHasCount(self,
     44                              summary_str,
     45                              tag,
     46                              expected_value,
     47                              greater_than=False):
     48     summary_proto = summary_pb2.Summary()
     49     summary_proto.ParseFromString(summary_str)
     50     for value in summary_proto.value:
     51       if re.match(tag, value.tag):
     52         if greater_than:
     53           self.assertGreaterEqual(value.histo.num, expected_value)
     54         else:
     55           self.assertEqual(expected_value, value.histo.num)
     56         return
     57     self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
     58 
     59   def _assertSummaryHasRange(self, summary_str, tag, min_value, max_value):
     60     summary_proto = summary_pb2.Summary()
     61     summary_proto.ParseFromString(summary_str)
     62     for value in summary_proto.value:
     63       if re.match(tag, value.tag):
     64         self.assertLessEqual(min_value, value.histo.min)
     65         self.assertGreaterEqual(max_value, value.histo.max)
     66         return
     67     self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
     68 
     69   def _assertSummaryHasSum(self, summary_str, tag, expected_value):
     70     summary_proto = summary_pb2.Summary()
     71     summary_proto.ParseFromString(summary_str)
     72     for value in summary_proto.value:
     73       if re.match(tag, value.tag):
     74         self.assertEqual(expected_value, value.histo.sum)
     75         return
     76     self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
     77 
     78   def _assertSummaryHasScalarValue(self, summary_str, tag, expected_value):
     79     summary_proto = summary_pb2.Summary()
     80     summary_proto.ParseFromString(summary_str)
     81     for value in summary_proto.value:
     82       if re.match(tag, value.tag):
     83         self.assertEqual(expected_value, value.simple_value)
     84         return
     85     self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
     86 
     87   def _testParallelCallsStats(self,
     88                               dataset_fn,
     89                               dataset_names,
     90                               num_output,
     91                               dataset_transformation,
     92                               function_processing_time=False,
     93                               check_elements=True):
     94     aggregator = stats_aggregator.StatsAggregator()
     95     dataset = dataset_fn()
     96     dataset = dataset_transformation(dataset, aggregator)
     97     next_element = self.getNext(dataset, requires_initialization=True)
     98 
     99     for i in range(num_output):
    100       next_ = self.evaluate(next_element())
    101       if check_elements:
    102         self.assertAllEqual(np.array([i] * i, dtype=np.int64), next_)
    103       summary_str = self.evaluate(aggregator.get_summary())
    104       for dataset_name in dataset_names:
    105         if function_processing_time:
    106           self._assertSummaryHasCount(
    107               summary_str,
    108               r"(.*)::execution_time$",
    109               float(i + 1),
    110               greater_than=True)
    111         self._assertSummaryContains(summary_str,
    112                                     dataset_name + "thread_utilization")
    113     with self.assertRaises(errors.OutOfRangeError):
    114       self.evaluate(next_element())
    115     if function_processing_time:
    116       summary_str = self.evaluate(aggregator.get_summary())
    117       for dataset_name in dataset_names:
    118         self._assertSummaryHasCount(
    119             summary_str,
    120             r"(.*)::execution_time$",
    121             float(num_output),
    122             greater_than=True)
    123