Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for summary sound op."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 from six.moves import xrange  # pylint: disable=redefined-builtin
     23 
     24 from tensorflow.core.framework import summary_pb2
     25 from tensorflow.python.framework import ops
     26 from tensorflow.python.platform import test
     27 from tensorflow.python.summary import summary
     28 
     29 
     30 class SummaryAudioOpTest(test.TestCase):
     31 
     32   def _AsSummary(self, s):
     33     summ = summary_pb2.Summary()
     34     summ.ParseFromString(s)
     35     return summ
     36 
     37   def _CheckProto(self, audio_summ, sample_rate, num_channels, length_frames):
     38     """Verify that the non-audio parts of the audio_summ proto match shape."""
     39     # Only the first 3 sounds are returned.
     40     for v in audio_summ.value:
     41       v.audio.ClearField("encoded_audio_string")
     42     expected = "\n".join("""
     43         value {
     44           tag: "snd/audio/%d"
     45           audio { content_type: "audio/wav" sample_rate: %d
     46                   num_channels: %d length_frames: %d }
     47         }""" % (i, sample_rate, num_channels, length_frames) for i in xrange(3))
     48     self.assertProtoEquals(expected, audio_summ)
     49 
     50   def testAudioSummary(self):
     51     np.random.seed(7)
     52     for channels in (1, 2, 5, 8):
     53       with self.test_session(graph=ops.Graph()) as sess:
     54         num_frames = 7
     55         shape = (4, num_frames, channels)
     56         # Generate random audio in the range [-1.0, 1.0).
     57         const = 2.0 * np.random.random(shape) - 1.0
     58 
     59         # Summarize
     60         sample_rate = 8000
     61         summ = summary.audio(
     62             "snd", const, max_outputs=3, sample_rate=sample_rate)
     63         value = sess.run(summ)
     64         self.assertEqual([], summ.get_shape())
     65         audio_summ = self._AsSummary(value)
     66 
     67         # Check the rest of the proto
     68         self._CheckProto(audio_summ, sample_rate, channels, num_frames)
     69 
     70 
if __name__ == "__main__":
  # Run this module's tests through TensorFlow's test entry point when the
  # file is executed directly as a script.
  test.main()
     73