Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for spectral_ops."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.contrib.signal.python.ops import spectral_ops
     24 from tensorflow.contrib.signal.python.ops import window_ops
     25 from tensorflow.python.framework import dtypes
     26 from tensorflow.python.ops import array_ops
     27 from tensorflow.python.ops import gradients_impl
     28 from tensorflow.python.ops import math_ops
     29 from tensorflow.python.ops import random_ops
     30 from tensorflow.python.ops import spectral_ops_test_util
     31 from tensorflow.python.platform import test
     32 
     33 
     34 class SpectralOpsTest(test.TestCase):
     35 
     36   @staticmethod
     37   def _np_hann_periodic_window(length):
     38     if length == 1:
     39       return np.ones(1)
     40     odd = length % 2
     41     if not odd:
     42       length += 1
     43     window = 0.5 - 0.5 * np.cos(2.0 * np.pi * np.arange(length) / (length - 1))
     44     if not odd:
     45       window = window[:-1]
     46     return window
     47 
     48   @staticmethod
     49   def _np_frame(data, window_length, hop_length):
     50     num_frames = 1 + int(np.floor((len(data) - window_length) // hop_length))
     51     shape = (num_frames, window_length)
     52     strides = (data.strides[0] * hop_length, data.strides[0])
     53     return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides)
     54 
     55   @staticmethod
     56   def _np_stft(data, fft_length, hop_length, window_length):
     57     frames = SpectralOpsTest._np_frame(data, window_length, hop_length)
     58     window = SpectralOpsTest._np_hann_periodic_window(window_length)
     59     return np.fft.rfft(frames * window, fft_length)
     60 
     61   @staticmethod
     62   def _np_inverse_stft(stft, fft_length, hop_length, window_length):
     63     frames = np.fft.irfft(stft, fft_length)
     64     # Pad or truncate frames's inner dimension to window_length.
     65     frames = frames[..., :window_length]
     66     frames = np.pad(frames, [[0, 0]] * (frames.ndim - 1) +
     67                     [[0, max(0, window_length - frames.shape[-1])]], "constant")
     68     window = SpectralOpsTest._np_hann_periodic_window(window_length)
     69     return SpectralOpsTest._np_overlap_add(frames * window, hop_length)
     70 
     71   @staticmethod
     72   def _np_overlap_add(stft, hop_length):
     73     num_frames, window_length = np.shape(stft)
     74     # Output length will be one complete window, plus another hop_length's
     75     # worth of points for each additional window.
     76     output_length = window_length + (num_frames - 1) * hop_length
     77     output = np.zeros(output_length)
     78     for i in range(num_frames):
     79       output[i * hop_length:i * hop_length + window_length] += stft[i,]
     80     return output
     81 
     82   def _compare(self, signal, frame_length, frame_step, fft_length):
     83     with spectral_ops_test_util.fft_kernel_label_map(), (
     84         self.test_session(use_gpu=True)) as sess:
     85       actual_stft = spectral_ops.stft(
     86           signal, frame_length, frame_step, fft_length, pad_end=False)
     87       signal_ph = array_ops.placeholder(dtype=dtypes.as_dtype(signal.dtype))
     88       actual_stft_from_ph = spectral_ops.stft(
     89           signal_ph, frame_length, frame_step, fft_length, pad_end=False)
     90 
     91       actual_inverse_stft = spectral_ops.inverse_stft(
     92           actual_stft, frame_length, frame_step, fft_length)
     93 
     94       actual_stft, actual_stft_from_ph, actual_inverse_stft = sess.run(
     95           [actual_stft, actual_stft_from_ph, actual_inverse_stft],
     96           feed_dict={signal_ph: signal})
     97 
     98       actual_stft_ph = array_ops.placeholder(dtype=actual_stft.dtype)
     99       actual_inverse_stft_from_ph = sess.run(
    100           spectral_ops.inverse_stft(
    101               actual_stft_ph, frame_length, frame_step, fft_length),
    102           feed_dict={actual_stft_ph: actual_stft})
    103 
    104       # Confirm that there is no difference in output when shape/rank is fully
    105       # unknown or known.
    106       self.assertAllClose(actual_stft, actual_stft_from_ph)
    107       self.assertAllClose(actual_inverse_stft, actual_inverse_stft_from_ph)
    108 
    109       expected_stft = SpectralOpsTest._np_stft(
    110           signal, fft_length, frame_step, frame_length)
    111       self.assertAllClose(expected_stft, actual_stft, 1e-4, 1e-4)
    112 
    113       expected_inverse_stft = SpectralOpsTest._np_inverse_stft(
    114           expected_stft, fft_length, frame_step, frame_length)
    115       self.assertAllClose(
    116           expected_inverse_stft, actual_inverse_stft, 1e-4, 1e-4)
    117 
    118   def test_shapes(self):
    119     with spectral_ops_test_util.fft_kernel_label_map(), (
    120         self.test_session(use_gpu=True)):
    121       signal = np.zeros((512,)).astype(np.float32)
    122 
    123       # If fft_length is not provided, the smallest enclosing power of 2 of
    124       # frame_length (8) is used.
    125       stft = spectral_ops.stft(signal, frame_length=7, frame_step=8,
    126                                pad_end=True)
    127       self.assertAllEqual([64, 5], stft.shape.as_list())
    128       self.assertAllEqual([64, 5], stft.eval().shape)
    129 
    130       stft = spectral_ops.stft(signal, frame_length=8, frame_step=8,
    131                                pad_end=True)
    132       self.assertAllEqual([64, 5], stft.shape.as_list())
    133       self.assertAllEqual([64, 5], stft.eval().shape)
    134 
    135       stft = spectral_ops.stft(signal, frame_length=8, frame_step=8,
    136                                fft_length=16, pad_end=True)
    137       self.assertAllEqual([64, 9], stft.shape.as_list())
    138       self.assertAllEqual([64, 9], stft.eval().shape)
    139 
    140       stft = spectral_ops.stft(signal, frame_length=16, frame_step=8,
    141                                fft_length=8, pad_end=True)
    142       self.assertAllEqual([64, 5], stft.shape.as_list())
    143       self.assertAllEqual([64, 5], stft.eval().shape)
    144 
    145       stft = np.zeros((32, 9)).astype(np.complex64)
    146 
    147       inverse_stft = spectral_ops.inverse_stft(stft, frame_length=8,
    148                                                fft_length=16, frame_step=8)
    149       expected_length = (stft.shape[0] - 1) * 8 + 8
    150       self.assertAllEqual([None], inverse_stft.shape.as_list())
    151       self.assertAllEqual([expected_length], inverse_stft.eval().shape)
    152 
    153   def test_stft_and_inverse_stft(self):
    154     """Test that spectral_ops.stft/inverse_stft match a NumPy implementation."""
    155     # Tuples of (signal_length, frame_length, frame_step, fft_length).
    156     test_configs = [
    157         (512, 64, 32, 64),
    158         (512, 64, 64, 64),
    159         (512, 72, 64, 64),
    160         (512, 64, 25, 64),
    161         (512, 25, 15, 36),
    162         (123, 23, 5, 42),
    163     ]
    164 
    165     for signal_length, frame_length, frame_step, fft_length in test_configs:
    166       signal = np.random.random(signal_length).astype(np.float32)
    167       self._compare(signal, frame_length, frame_step, fft_length)
    168 
    169   def test_stft_round_trip(self):
    170     # Tuples of (signal_length, frame_length, frame_step, fft_length,
    171     # threshold, corrected_threshold).
    172     test_configs = [
    173         # 87.5% overlap.
    174         (4096, 256, 32, 256, 1e-5, 1e-6),
    175         # 75% overlap.
    176         (4096, 256, 64, 256, 1e-5, 1e-6),
    177         # Odd frame hop.
    178         (4096, 128, 25, 128, 1e-3, 1e-6),
    179         # Odd frame length.
    180         (4096, 127, 32, 128, 1e-3, 1e-6),
    181         # 50% overlap.
    182         (4096, 128, 64, 128, 0.40, 1e-6),
    183     ]
    184 
    185     for (signal_length, frame_length, frame_step, fft_length, threshold,
    186          corrected_threshold) in test_configs:
    187       # Generate a random white Gaussian signal.
    188       signal = random_ops.random_normal([signal_length])
    189 
    190       with spectral_ops_test_util.fft_kernel_label_map(), (
    191           self.test_session(use_gpu=True)) as sess:
    192         stft = spectral_ops.stft(signal, frame_length, frame_step, fft_length,
    193                                  pad_end=False)
    194         inverse_stft = spectral_ops.inverse_stft(stft, frame_length, frame_step,
    195                                                  fft_length)
    196         inverse_stft_corrected = spectral_ops.inverse_stft(
    197             stft, frame_length, frame_step, fft_length,
    198             window_fn=spectral_ops.inverse_stft_window_fn(frame_step))
    199         signal, inverse_stft, inverse_stft_corrected = sess.run(
    200             [signal, inverse_stft, inverse_stft_corrected])
    201 
    202         # Truncate signal to the size of inverse stft.
    203         signal = signal[:inverse_stft.shape[0]]
    204 
    205         # Ignore the frame_length samples at either edge.
    206         signal = signal[frame_length:-frame_length]
    207         inverse_stft = inverse_stft[frame_length:-frame_length]
    208         inverse_stft_corrected = inverse_stft_corrected[
    209             frame_length:-frame_length]
    210 
    211         # Check that the inverse and original signal are close up to a scale
    212         # factor.
    213         inverse_stft_scaled = inverse_stft / np.mean(np.abs(inverse_stft))
    214         signal_scaled = signal / np.mean(np.abs(signal))
    215         self.assertLess(np.std(inverse_stft_scaled - signal_scaled), threshold)
    216 
    217         # Check that the inverse with correction and original signal are close.
    218         self.assertLess(np.std(inverse_stft_corrected - signal),
    219                         corrected_threshold)
    220 
    221   def test_inverse_stft_window_fn(self):
    222     """Test that inverse_stft_window_fn has unit gain at each window phase."""
    223     # Tuples of (frame_length, frame_step).
    224     test_configs = [
    225         (256, 32),
    226         (256, 64),
    227         (128, 25),
    228         (127, 32),
    229         (128, 64),
    230     ]
    231 
    232     for (frame_length, frame_step) in test_configs:
    233       hann_window = window_ops.hann_window(frame_length, dtype=dtypes.float32)
    234       inverse_window_fn = spectral_ops.inverse_stft_window_fn(frame_step)
    235       inverse_window = inverse_window_fn(frame_length, dtype=dtypes.float32)
    236 
    237       with self.test_session(use_gpu=True) as sess:
    238         hann_window, inverse_window = sess.run([hann_window, inverse_window])
    239 
    240       # Expect unit gain at each phase of the window.
    241       product_window = hann_window * inverse_window
    242       for i in range(frame_step):
    243         self.assertAllClose(1.0, np.sum(product_window[i::frame_step]))
    244 
    245   def test_inverse_stft_window_fn_special_case(self):
    246     """Test inverse_stft_window_fn in special overlap = 3/4 case."""
    247     # Cases in which frame_length is an integer multiple of 4 * frame_step are
    248     # special because they allow exact reproduction of the waveform with a
    249     # squared Hann window (Hann window in both forward and reverse transforms).
    250     # In the case where frame_length = 4 * frame_step, that combination
    251     # produces a constant gain of 1.5, and so the corrected window will be the
    252     # Hann window / 1.5.
    253 
    254     # Tuples of (frame_length, frame_step).
    255     test_configs = [
    256         (256, 64),
    257         (128, 32),
    258     ]
    259 
    260     for (frame_length, frame_step) in test_configs:
    261       hann_window = window_ops.hann_window(frame_length, dtype=dtypes.float32)
    262       inverse_window_fn = spectral_ops.inverse_stft_window_fn(frame_step)
    263       inverse_window = inverse_window_fn(frame_length, dtype=dtypes.float32)
    264 
    265       with self.test_session(use_gpu=True) as sess:
    266         hann_window, inverse_window = sess.run([hann_window, inverse_window])
    267 
    268       self.assertAllClose(hann_window, inverse_window * 1.5)
    269 
    270   @staticmethod
    271   def _compute_stft_gradient(signal, frame_length=32, frame_step=16,
    272                              fft_length=32):
    273     """Computes the gradient of the STFT with respect to `signal`."""
    274     stft = spectral_ops.stft(signal, frame_length, frame_step, fft_length)
    275     magnitude_stft = math_ops.abs(stft)
    276     loss = math_ops.reduce_sum(magnitude_stft)
    277     return gradients_impl.gradients([loss], [signal])[0]
    278 
    279   def test_gradients(self):
    280     """Test that spectral_ops.stft has a working gradient."""
    281     with spectral_ops_test_util.fft_kernel_label_map(), (
    282         self.test_session(use_gpu=True)) as sess:
    283       signal_length = 512
    284 
    285       # An all-zero signal has all zero gradients with respect to the sum of the
    286       # magnitude STFT.
    287       empty_signal = array_ops.zeros([signal_length], dtype=dtypes.float32)
    288       empty_signal_gradient = sess.run(
    289           self._compute_stft_gradient(empty_signal))
    290       self.assertTrue((empty_signal_gradient == 0.0).all())
    291 
    292       # A sinusoid will have non-zero components of its gradient with respect to
    293       # the sum of the magnitude STFT.
    294       sinusoid = math_ops.sin(
    295           2 * np.pi * math_ops.linspace(0.0, 1.0, signal_length))
    296       sinusoid_gradient = sess.run(self._compute_stft_gradient(sinusoid))
    297       self.assertFalse((sinusoid_gradient == 0.0).all())
    298 
    299   def test_gradients_numerical(self):
    300     with spectral_ops_test_util.fft_kernel_label_map(), (
    301         self.test_session(use_gpu=True)):
    302       # Tuples of (signal_length, frame_length, frame_step, fft_length,
    303       # stft_bound, inverse_stft_bound).
    304       # TODO(rjryan): Investigate why STFT gradient error is so high.
    305       test_configs = [
    306           (64, 16, 8, 16),
    307           (64, 16, 16, 16),
    308           (64, 16, 7, 16),
    309           (64, 7, 4, 9),
    310           (29, 5, 1, 10),
    311       ]
    312 
    313       for (signal_length, frame_length, frame_step, fft_length) in test_configs:
    314         signal_shape = [signal_length]
    315         signal = random_ops.random_uniform(signal_shape)
    316         stft_shape = [max(0, 1 + (signal_length - frame_length) // frame_step),
    317                       fft_length // 2 + 1]
    318         stft = spectral_ops.stft(signal, frame_length, frame_step, fft_length,
    319                                  pad_end=False)
    320         inverse_stft_shape = [(stft_shape[0] - 1) * frame_step + frame_length]
    321         inverse_stft = spectral_ops.inverse_stft(stft, frame_length, frame_step,
    322                                                  fft_length)
    323         stft_error = test.compute_gradient_error(signal, [signal_length],
    324                                                  stft, stft_shape)
    325         inverse_stft_error = test.compute_gradient_error(
    326             stft, stft_shape, inverse_stft, inverse_stft_shape)
    327         self.assertLess(stft_error, 2e-3)
    328         self.assertLess(inverse_stft_error, 5e-4)
    329 
    330 
# Invoke `test.main()` only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
  test.main()
    333