1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for spectral_ops.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.contrib.signal.python.ops import spectral_ops 24 from tensorflow.contrib.signal.python.ops import window_ops 25 from tensorflow.python.framework import dtypes 26 from tensorflow.python.ops import array_ops 27 from tensorflow.python.ops import gradients_impl 28 from tensorflow.python.ops import math_ops 29 from tensorflow.python.ops import random_ops 30 from tensorflow.python.ops import spectral_ops_test_util 31 from tensorflow.python.platform import test 32 33 34 class SpectralOpsTest(test.TestCase): 35 36 @staticmethod 37 def _np_hann_periodic_window(length): 38 if length == 1: 39 return np.ones(1) 40 odd = length % 2 41 if not odd: 42 length += 1 43 window = 0.5 - 0.5 * np.cos(2.0 * np.pi * np.arange(length) / (length - 1)) 44 if not odd: 45 window = window[:-1] 46 return window 47 48 @staticmethod 49 def _np_frame(data, window_length, hop_length): 50 num_frames = 1 + int(np.floor((len(data) - window_length) // hop_length)) 51 shape = (num_frames, window_length) 52 strides = (data.strides[0] * hop_length, data.strides[0]) 53 return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides) 54 55 @staticmethod 56 def _np_stft(data, fft_length, hop_length, window_length): 57 frames = SpectralOpsTest._np_frame(data, window_length, hop_length) 58 window = SpectralOpsTest._np_hann_periodic_window(window_length) 59 return np.fft.rfft(frames * window, fft_length) 60 61 @staticmethod 62 def _np_inverse_stft(stft, fft_length, hop_length, window_length): 63 frames = np.fft.irfft(stft, fft_length) 64 # Pad or truncate frames's inner dimension to window_length. 65 frames = frames[..., :window_length] 66 frames = np.pad(frames, [[0, 0]] * (frames.ndim - 1) + 67 [[0, max(0, window_length - frames.shape[-1])]], "constant") 68 window = SpectralOpsTest._np_hann_periodic_window(window_length) 69 return SpectralOpsTest._np_overlap_add(frames * window, hop_length) 70 71 @staticmethod 72 def _np_overlap_add(stft, hop_length): 73 num_frames, window_length = np.shape(stft) 74 # Output length will be one complete window, plus another hop_length's 75 # worth of points for each additional window. 76 output_length = window_length + (num_frames - 1) * hop_length 77 output = np.zeros(output_length) 78 for i in range(num_frames): 79 output[i * hop_length:i * hop_length + window_length] += stft[i,] 80 return output 81 82 def _compare(self, signal, frame_length, frame_step, fft_length): 83 with spectral_ops_test_util.fft_kernel_label_map(), ( 84 self.test_session(use_gpu=True)) as sess: 85 actual_stft = spectral_ops.stft( 86 signal, frame_length, frame_step, fft_length, pad_end=False) 87 signal_ph = array_ops.placeholder(dtype=dtypes.as_dtype(signal.dtype)) 88 actual_stft_from_ph = spectral_ops.stft( 89 signal_ph, frame_length, frame_step, fft_length, pad_end=False) 90 91 actual_inverse_stft = spectral_ops.inverse_stft( 92 actual_stft, frame_length, frame_step, fft_length) 93 94 actual_stft, actual_stft_from_ph, actual_inverse_stft = sess.run( 95 [actual_stft, actual_stft_from_ph, actual_inverse_stft], 96 feed_dict={signal_ph: signal}) 97 98 actual_stft_ph = array_ops.placeholder(dtype=actual_stft.dtype) 99 actual_inverse_stft_from_ph = sess.run( 100 spectral_ops.inverse_stft( 101 actual_stft_ph, frame_length, frame_step, fft_length), 102 feed_dict={actual_stft_ph: actual_stft}) 103 104 # Confirm that there is no difference in output when shape/rank is fully 105 # unknown or known. 106 self.assertAllClose(actual_stft, actual_stft_from_ph) 107 self.assertAllClose(actual_inverse_stft, actual_inverse_stft_from_ph) 108 109 expected_stft = SpectralOpsTest._np_stft( 110 signal, fft_length, frame_step, frame_length) 111 self.assertAllClose(expected_stft, actual_stft, 1e-4, 1e-4) 112 113 expected_inverse_stft = SpectralOpsTest._np_inverse_stft( 114 expected_stft, fft_length, frame_step, frame_length) 115 self.assertAllClose( 116 expected_inverse_stft, actual_inverse_stft, 1e-4, 1e-4) 117 118 def test_shapes(self): 119 with spectral_ops_test_util.fft_kernel_label_map(), ( 120 self.test_session(use_gpu=True)): 121 signal = np.zeros((512,)).astype(np.float32) 122 123 # If fft_length is not provided, the smallest enclosing power of 2 of 124 # frame_length (8) is used. 125 stft = spectral_ops.stft(signal, frame_length=7, frame_step=8, 126 pad_end=True) 127 self.assertAllEqual([64, 5], stft.shape.as_list()) 128 self.assertAllEqual([64, 5], stft.eval().shape) 129 130 stft = spectral_ops.stft(signal, frame_length=8, frame_step=8, 131 pad_end=True) 132 self.assertAllEqual([64, 5], stft.shape.as_list()) 133 self.assertAllEqual([64, 5], stft.eval().shape) 134 135 stft = spectral_ops.stft(signal, frame_length=8, frame_step=8, 136 fft_length=16, pad_end=True) 137 self.assertAllEqual([64, 9], stft.shape.as_list()) 138 self.assertAllEqual([64, 9], stft.eval().shape) 139 140 stft = spectral_ops.stft(signal, frame_length=16, frame_step=8, 141 fft_length=8, pad_end=True) 142 self.assertAllEqual([64, 5], stft.shape.as_list()) 143 self.assertAllEqual([64, 5], stft.eval().shape) 144 145 stft = np.zeros((32, 9)).astype(np.complex64) 146 147 inverse_stft = spectral_ops.inverse_stft(stft, frame_length=8, 148 fft_length=16, frame_step=8) 149 expected_length = (stft.shape[0] - 1) * 8 + 8 150 self.assertAllEqual([None], inverse_stft.shape.as_list()) 151 self.assertAllEqual([expected_length], inverse_stft.eval().shape) 152 153 def test_stft_and_inverse_stft(self): 154 """Test that spectral_ops.stft/inverse_stft match a NumPy implementation.""" 155 # Tuples of (signal_length, frame_length, frame_step, fft_length). 156 test_configs = [ 157 (512, 64, 32, 64), 158 (512, 64, 64, 64), 159 (512, 72, 64, 64), 160 (512, 64, 25, 64), 161 (512, 25, 15, 36), 162 (123, 23, 5, 42), 163 ] 164 165 for signal_length, frame_length, frame_step, fft_length in test_configs: 166 signal = np.random.random(signal_length).astype(np.float32) 167 self._compare(signal, frame_length, frame_step, fft_length) 168 169 def test_stft_round_trip(self): 170 # Tuples of (signal_length, frame_length, frame_step, fft_length, 171 # threshold, corrected_threshold). 172 test_configs = [ 173 # 87.5% overlap. 174 (4096, 256, 32, 256, 1e-5, 1e-6), 175 # 75% overlap. 176 (4096, 256, 64, 256, 1e-5, 1e-6), 177 # Odd frame hop. 178 (4096, 128, 25, 128, 1e-3, 1e-6), 179 # Odd frame length. 180 (4096, 127, 32, 128, 1e-3, 1e-6), 181 # 50% overlap. 182 (4096, 128, 64, 128, 0.40, 1e-6), 183 ] 184 185 for (signal_length, frame_length, frame_step, fft_length, threshold, 186 corrected_threshold) in test_configs: 187 # Generate a random white Gaussian signal. 188 signal = random_ops.random_normal([signal_length]) 189 190 with spectral_ops_test_util.fft_kernel_label_map(), ( 191 self.test_session(use_gpu=True)) as sess: 192 stft = spectral_ops.stft(signal, frame_length, frame_step, fft_length, 193 pad_end=False) 194 inverse_stft = spectral_ops.inverse_stft(stft, frame_length, frame_step, 195 fft_length) 196 inverse_stft_corrected = spectral_ops.inverse_stft( 197 stft, frame_length, frame_step, fft_length, 198 window_fn=spectral_ops.inverse_stft_window_fn(frame_step)) 199 signal, inverse_stft, inverse_stft_corrected = sess.run( 200 [signal, inverse_stft, inverse_stft_corrected]) 201 202 # Truncate signal to the size of inverse stft. 203 signal = signal[:inverse_stft.shape[0]] 204 205 # Ignore the frame_length samples at either edge. 206 signal = signal[frame_length:-frame_length] 207 inverse_stft = inverse_stft[frame_length:-frame_length] 208 inverse_stft_corrected = inverse_stft_corrected[ 209 frame_length:-frame_length] 210 211 # Check that the inverse and original signal are close up to a scale 212 # factor. 213 inverse_stft_scaled = inverse_stft / np.mean(np.abs(inverse_stft)) 214 signal_scaled = signal / np.mean(np.abs(signal)) 215 self.assertLess(np.std(inverse_stft_scaled - signal_scaled), threshold) 216 217 # Check that the inverse with correction and original signal are close. 218 self.assertLess(np.std(inverse_stft_corrected - signal), 219 corrected_threshold) 220 221 def test_inverse_stft_window_fn(self): 222 """Test that inverse_stft_window_fn has unit gain at each window phase.""" 223 # Tuples of (frame_length, frame_step). 224 test_configs = [ 225 (256, 32), 226 (256, 64), 227 (128, 25), 228 (127, 32), 229 (128, 64), 230 ] 231 232 for (frame_length, frame_step) in test_configs: 233 hann_window = window_ops.hann_window(frame_length, dtype=dtypes.float32) 234 inverse_window_fn = spectral_ops.inverse_stft_window_fn(frame_step) 235 inverse_window = inverse_window_fn(frame_length, dtype=dtypes.float32) 236 237 with self.test_session(use_gpu=True) as sess: 238 hann_window, inverse_window = sess.run([hann_window, inverse_window]) 239 240 # Expect unit gain at each phase of the window. 241 product_window = hann_window * inverse_window 242 for i in range(frame_step): 243 self.assertAllClose(1.0, np.sum(product_window[i::frame_step])) 244 245 def test_inverse_stft_window_fn_special_case(self): 246 """Test inverse_stft_window_fn in special overlap = 3/4 case.""" 247 # Cases in which frame_length is an integer multiple of 4 * frame_step are 248 # special because they allow exact reproduction of the waveform with a 249 # squared Hann window (Hann window in both forward and reverse transforms). 250 # In the case where frame_length = 4 * frame_step, that combination 251 # produces a constant gain of 1.5, and so the corrected window will be the 252 # Hann window / 1.5. 253 254 # Tuples of (frame_length, frame_step). 255 test_configs = [ 256 (256, 64), 257 (128, 32), 258 ] 259 260 for (frame_length, frame_step) in test_configs: 261 hann_window = window_ops.hann_window(frame_length, dtype=dtypes.float32) 262 inverse_window_fn = spectral_ops.inverse_stft_window_fn(frame_step) 263 inverse_window = inverse_window_fn(frame_length, dtype=dtypes.float32) 264 265 with self.test_session(use_gpu=True) as sess: 266 hann_window, inverse_window = sess.run([hann_window, inverse_window]) 267 268 self.assertAllClose(hann_window, inverse_window * 1.5) 269 270 @staticmethod 271 def _compute_stft_gradient(signal, frame_length=32, frame_step=16, 272 fft_length=32): 273 """Computes the gradient of the STFT with respect to `signal`.""" 274 stft = spectral_ops.stft(signal, frame_length, frame_step, fft_length) 275 magnitude_stft = math_ops.abs(stft) 276 loss = math_ops.reduce_sum(magnitude_stft) 277 return gradients_impl.gradients([loss], [signal])[0] 278 279 def test_gradients(self): 280 """Test that spectral_ops.stft has a working gradient.""" 281 with spectral_ops_test_util.fft_kernel_label_map(), ( 282 self.test_session(use_gpu=True)) as sess: 283 signal_length = 512 284 285 # An all-zero signal has all zero gradients with respect to the sum of the 286 # magnitude STFT. 287 empty_signal = array_ops.zeros([signal_length], dtype=dtypes.float32) 288 empty_signal_gradient = sess.run( 289 self._compute_stft_gradient(empty_signal)) 290 self.assertTrue((empty_signal_gradient == 0.0).all()) 291 292 # A sinusoid will have non-zero components of its gradient with respect to 293 # the sum of the magnitude STFT. 294 sinusoid = math_ops.sin( 295 2 * np.pi * math_ops.linspace(0.0, 1.0, signal_length)) 296 sinusoid_gradient = sess.run(self._compute_stft_gradient(sinusoid)) 297 self.assertFalse((sinusoid_gradient == 0.0).all()) 298 299 def test_gradients_numerical(self): 300 with spectral_ops_test_util.fft_kernel_label_map(), ( 301 self.test_session(use_gpu=True)): 302 # Tuples of (signal_length, frame_length, frame_step, fft_length, 303 # stft_bound, inverse_stft_bound). 304 # TODO(rjryan): Investigate why STFT gradient error is so high. 305 test_configs = [ 306 (64, 16, 8, 16), 307 (64, 16, 16, 16), 308 (64, 16, 7, 16), 309 (64, 7, 4, 9), 310 (29, 5, 1, 10), 311 ] 312 313 for (signal_length, frame_length, frame_step, fft_length) in test_configs: 314 signal_shape = [signal_length] 315 signal = random_ops.random_uniform(signal_shape) 316 stft_shape = [max(0, 1 + (signal_length - frame_length) // frame_step), 317 fft_length // 2 + 1] 318 stft = spectral_ops.stft(signal, frame_length, frame_step, fft_length, 319 pad_end=False) 320 inverse_stft_shape = [(stft_shape[0] - 1) * frame_step + frame_length] 321 inverse_stft = spectral_ops.inverse_stft(stft, frame_length, frame_step, 322 fft_length) 323 stft_error = test.compute_gradient_error(signal, [signal_length], 324 stft, stft_shape) 325 inverse_stft_error = test.compute_gradient_error( 326 stft, stft_shape, inverse_stft, inverse_stft_shape) 327 self.assertLess(stft_error, 2e-3) 328 self.assertLess(inverse_stft_error, 5e-4) 329 330 331 if __name__ == "__main__": 332 test.main() 333