Home | History | Annotate | Download | only in python
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 # http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for StatSummarizer Python wrapper."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.core.protobuf import config_pb2
     22 from tensorflow.python import pywrap_tensorflow
     23 from tensorflow.python.framework import constant_op
     24 from tensorflow.python.framework import ops
     25 from tensorflow.python.ops import math_ops
     26 from tensorflow.python.ops import variables
     27 from tensorflow.python.platform import test
     28 
     29 
     30 class StatSummarizerTest(test.TestCase):
     31 
     32   def testStatSummarizer(self):
     33     with ops.Graph().as_default() as graph:
     34       matrix1 = constant_op.constant([[3., 3.]], name=r"m1")
     35       matrix2 = constant_op.constant([[2.], [2.]], name=r"m2")
     36       product = math_ops.matmul(matrix1, matrix2, name=r"product")
     37 
     38       graph_def = graph.as_graph_def()
     39       ss = pywrap_tensorflow.NewStatSummarizer(graph_def.SerializeToString())
     40 
     41       with self.test_session() as sess:
     42         sess.run(variables.global_variables_initializer())
     43 
     44         for _ in range(20):
     45           run_metadata = config_pb2.RunMetadata()
     46           run_options = config_pb2.RunOptions(
     47               trace_level=config_pb2.RunOptions.FULL_TRACE)
     48           sess.run(product, options=run_options, run_metadata=run_metadata)
     49 
     50           ss.ProcessStepStatsStr(run_metadata.step_stats.SerializeToString())
     51 
     52       output_string = ss.GetOutputString()
     53 
     54       print(output_string)
     55 
     56       # Test it recorded running the expected number of times.
     57       self.assertRegexpMatches(output_string, r"count=20")
     58 
     59       # Test that a header line got printed.
     60       self.assertRegexpMatches(output_string, r"====== .* ======")
     61 
     62       # Test that the nodes we added were analyzed.
     63       # The line for the op should contain both the op type (MatMul)
     64       # and the name of the node (product)
     65       self.assertRegexpMatches(output_string, r"MatMul.*product")
     66       self.assertRegexpMatches(output_string, r"Const.*m1")
     67       self.assertRegexpMatches(output_string, r"Const.*m2")
     68 
     69       # Test that a CDF summed to 100%
     70       self.assertRegexpMatches(output_string, r"100\.")
     71 
     72       pywrap_tensorflow.DeleteStatSummarizer(ss)
     73 
     74 
     75 if __name__ == "__main__":
     76   test.main()
     77