# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for FFT via the XLA JIT."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

import numpy as np
import scipy.signal as sps

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.contrib.signal.python.ops import spectral_ops as signal
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import spectral_ops
from tensorflow.python.platform import googletest

# Leading (batch) dimensions prepended to every tested inner shape.
BATCH_DIMS = (3, 5)
RTOL = 0.02  # Eigen/cuFFT differ widely from np, especially for FFT3D
ATOL = 1e-3


def pick_10(x):
  """Returns up to 10 elements of `x`, chosen by a fixed-seed shuffle.

  A private `RandomState(123)` is used so the selection is reproducible
  without reseeding NumPy's global random generator as a side effect of
  importing this module. The permutation is identical to the one produced by
  `np.random.seed(123)` followed by `np.random.shuffle`.

  Args:
    x: Any iterable; it is materialized into a list before shuffling.

  Returns:
    A list holding the first (at most) 10 elements of the shuffled `x`.
  """
  x = list(x)
  np.random.RandomState(123).shuffle(x)
  return x[:10]


def to_32bit(x):
  """Downcasts a 64-bit NumPy array to the matching 32-bit dtype.

  complex128 becomes complex64 and float64 becomes float32. Arrays of any
  other dtype are returned unchanged.
  """
  if x.dtype == np.complex128:
    return x.astype(np.complex64)
  if x.dtype == np.float64:
    return x.astype(np.float32)
  return x


# 1-D transform lengths: powers of two from 2**3 to 2**11.
_POWS_OF_2_1D = 2**np.arange(3, 12)
INNER_DIMS_1D = [(x,) for x in _POWS_OF_2_1D]
# Smaller powers of two (2**3 to 2**7) for the 2-D/3-D shapes, to avoid OOM on
# GPU.
POWS_OF_2 = 2**np.arange(3, 8)
INNER_DIMS_2D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2))
INNER_DIMS_3D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2, POWS_OF_2))


def _shuffled_ramp(num_values, scale):
  """Returns `np.arange(num_values) / scale`, shuffled with a fixed seed.

  A private `RandomState(123)` is used instead of reseeding NumPy's global
  generator. The permutation therefore matches `np.random.seed(123)` followed
  by `np.random.shuffle`, without leaking RNG state into other tests.
  """
  data = np.arange(num_values) / scale
  np.random.RandomState(123).shuffle(data)
  return data


class FFTTest(XLATestCase):

  def _VerifyFftMethod(self, inner_dims, complex_to_input, input_to_expected,
                       tf_method):
    """Checks `tf_method` against a NumPy reference for each inner shape.

    For each shape in `inner_dims`, this method does the following:
      * Builds deterministic complex64 data of shape BATCH_DIMS + that shape.
      * Maps the data through `complex_to_input` and downcasts it to 32 bits.
      * Compares the op built by `tf_method` (under the XLA test scope) with
        `input_to_expected`, both applied to that same data.

    Args:
      inner_dims: Iterable of tuples giving the innermost (transformed) dims.
      complex_to_input: Maps the complex64 test data to the op's input, e.g.
        `np.real` for real-input transforms.
      input_to_expected: NumPy reference producing the expected output.
      tf_method: Callable that builds the TensorFlow op from a placeholder.
    """
    for indims in inner_dims:
      print("nfft =", indims)
      shape = BATCH_DIMS + indims
      # Generate twice as many reals as complex elements. The complex64 view
      # below reinterprets each adjacent float32 pair as (real, imag).
      data = _shuffled_ramp(np.prod(shape) * 2, np.prod(indims))
      data = np.reshape(data.astype(np.float32).view(np.complex64), shape)
      data = to_32bit(complex_to_input(data))
      expected = to_32bit(input_to_expected(data))
      with self.test_session() as sess:
        with self.test_scope():
          ph = array_ops.placeholder(
              dtypes.as_dtype(data.dtype), shape=data.shape)
          out = tf_method(ph)
        value = sess.run(out, {ph: data})
        self.assertAllClose(expected, value, rtol=RTOL, atol=ATOL)

  def testContribSignalSTFT(self):
    ws = 512
    hs = 128
    dims = (ws * 20,)
    shape = BATCH_DIMS + dims
    data = _shuffled_ramp(np.prod(shape), np.prod(dims))
    data = np.reshape(data.astype(np.float32), shape)
    window = sps.get_window("hann", ws)
    expected = sps.stft(
        data, nperseg=ws, noverlap=ws - hs, boundary=None, window=window)[2]
    expected = np.swapaxes(expected, -1, -2)
    expected *= window.sum()  # scipy divides by window sum
    with self.test_session() as sess:
      with self.test_scope():
        ph = array_ops.placeholder(
            dtypes.as_dtype(data.dtype), shape=data.shape)
        out = signal.stft(ph, ws, hs)

      value = sess.run(out, {ph: data})
      self.assertAllClose(expected, value, rtol=RTOL, atol=ATOL)

  def testFFT(self):
    self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.fft,
                          spectral_ops.fft)

  def testFFT2D(self):
    self._VerifyFftMethod(INNER_DIMS_2D, lambda x: x, np.fft.fft2,
                          spectral_ops.fft2d)

  def testFFT3D(self):
    self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x,
                          lambda x: np.fft.fftn(x, axes=(-3, -2, -1)),
                          spectral_ops.fft3d)

  def testIFFT(self):
    self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.ifft,
                          spectral_ops.ifft)

  def testIFFT2D(self):
    self._VerifyFftMethod(INNER_DIMS_2D, lambda x: x, np.fft.ifft2,
                          spectral_ops.ifft2d)

  def testIFFT3D(self):
    self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x,
                          lambda x: np.fft.ifftn(x, axes=(-3, -2, -1)),
                          spectral_ops.ifft3d)

  def testRFFT(self):
    self._VerifyFftMethod(
        INNER_DIMS_1D, np.real, lambda x: np.fft.rfft(x, n=x.shape[-1]),
        lambda x: spectral_ops.rfft(x, fft_length=[x.shape[-1].value]))

  def testRFFT2D(self):

    def _tf_fn(x):
      return spectral_ops.rfft2d(
          x, fft_length=[x.shape[-2].value, x.shape[-1].value])

    self._VerifyFftMethod(
        INNER_DIMS_2D, np.real,
        lambda x: np.fft.rfft2(x, s=[x.shape[-2], x.shape[-1]]), _tf_fn)

  def testRFFT3D(self):

    def _to_expected(x):
      return np.fft.rfftn(
          x, axes=(-3, -2, -1), s=[x.shape[-3], x.shape[-2], x.shape[-1]])

    def _tf_fn(x):
      return spectral_ops.rfft3d(
          x,
          fft_length=[x.shape[-3].value, x.shape[-2].value, x.shape[-1].value])

    self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn)

  # In the IRFFT tests, the input is an rfft half-spectrum of an even-length
  # signal (all tested lengths are powers of two). A last dim of n therefore
  # inverts to a signal of length 2 * (n - 1).

  def testIRFFT(self):

    def _tf_fn(x):
      return spectral_ops.irfft(x, fft_length=[2 * (x.shape[-1].value - 1)])

    self._VerifyFftMethod(
        INNER_DIMS_1D, lambda x: np.fft.rfft(np.real(x), n=x.shape[-1]),
        lambda x: np.fft.irfft(x, n=2 * (x.shape[-1] - 1)), _tf_fn)

  def testIRFFT2D(self):

    def _tf_fn(x):
      return spectral_ops.irfft2d(
          x, fft_length=[x.shape[-2].value, 2 * (x.shape[-1].value - 1)])

    self._VerifyFftMethod(
        INNER_DIMS_2D,
        lambda x: np.fft.rfft2(np.real(x), s=[x.shape[-2], x.shape[-1]]),
        lambda x: np.fft.irfft2(x, s=[x.shape[-2], 2 * (x.shape[-1] - 1)]),
        _tf_fn)

  def testIRFFT3D(self):

    def _to_input(x):
      return np.fft.rfftn(
          np.real(x),
          axes=(-3, -2, -1),
          s=[x.shape[-3], x.shape[-2], x.shape[-1]])

    def _to_expected(x):
      return np.fft.irfftn(
          x,
          axes=(-3, -2, -1),
          s=[x.shape[-3], x.shape[-2], 2 * (x.shape[-1] - 1)])

    def _tf_fn(x):
      return spectral_ops.irfft3d(
          x,
          fft_length=[
              x.shape[-3].value, x.shape[-2].value, 2 * (x.shape[-1].value - 1)
          ])

    self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn)


if __name__ == "__main__":
  googletest.main()