Home | History | Annotate | Download | only in tests
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for FFT via the XLA JIT."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import itertools
     22 
     23 import numpy as np
     24 import scipy.signal as sps
     25 
     26 from tensorflow.compiler.tests.xla_test import XLATestCase
     27 from tensorflow.contrib.signal.python.ops import spectral_ops as signal
     28 from tensorflow.python.framework import dtypes
     29 from tensorflow.python.ops import array_ops
     30 from tensorflow.python.ops import spectral_ops
     31 from tensorflow.python.platform import googletest
     32 
# Leading batch dimensions prepended to every input shape built by the tests.
BATCH_DIMS = (3, 5)
# Relative/absolute tolerances handed to assertAllClose when comparing the op
# output with the reference result.
RTOL = 0.02  # Eigen/cuFFT differ widely from np, especially for FFT3D
ATOL = 1e-3
     36 
     37 
     38 def pick_10(x):
     39   x = list(x)
     40   np.random.seed(123)
     41   np.random.shuffle(x)
     42   return x[:10]
     43 
     44 
     45 def to_32bit(x):
     46   if x.dtype == np.complex128:
     47     return x.astype(np.complex64)
     48   if x.dtype == np.float64:
     49     return x.astype(np.float32)
     50   return x
     51 
     52 
     53 POWS_OF_2 = 2**np.arange(3, 12)
     54 INNER_DIMS_1D = list((x,) for x in POWS_OF_2)
     55 POWS_OF_2 = 2**np.arange(3, 8)  # To avoid OOM on GPU.
     56 INNER_DIMS_2D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2))
     57 INNER_DIMS_3D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2, POWS_OF_2))
     58 
     59 
     60 class FFTTest(XLATestCase):
     61 
     62   def _VerifyFftMethod(self, inner_dims, complex_to_input, input_to_expected,
     63                        tf_method):
     64     for indims in inner_dims:
     65       print("nfft =", indims)
     66       shape = BATCH_DIMS + indims
     67       data = np.arange(np.prod(shape) * 2) / np.prod(indims)
     68       np.random.seed(123)
     69       np.random.shuffle(data)
     70       data = np.reshape(data.astype(np.float32).view(np.complex64), shape)
     71       data = to_32bit(complex_to_input(data))
     72       expected = to_32bit(input_to_expected(data))
     73       with self.test_session() as sess:
     74         with self.test_scope():
     75           ph = array_ops.placeholder(
     76               dtypes.as_dtype(data.dtype), shape=data.shape)
     77           out = tf_method(ph)
     78         value = sess.run(out, {ph: data})
     79         self.assertAllClose(expected, value, rtol=RTOL, atol=ATOL)
     80 
     81   def testContribSignalSTFT(self):
     82     ws = 512
     83     hs = 128
     84     dims = (ws * 20,)
     85     shape = BATCH_DIMS + dims
     86     data = np.arange(np.prod(shape)) / np.prod(dims)
     87     np.random.seed(123)
     88     np.random.shuffle(data)
     89     data = np.reshape(data.astype(np.float32), shape)
     90     window = sps.get_window("hann", ws)
     91     expected = sps.stft(
     92         data, nperseg=ws, noverlap=ws - hs, boundary=None, window=window)[2]
     93     expected = np.swapaxes(expected, -1, -2)
     94     expected *= window.sum()  # scipy divides by window sum
     95     with self.test_session() as sess:
     96       with self.test_scope():
     97         ph = array_ops.placeholder(
     98             dtypes.as_dtype(data.dtype), shape=data.shape)
     99         out = signal.stft(ph, ws, hs)
    100 
    101       value = sess.run(out, {ph: data})
    102       self.assertAllClose(expected, value, rtol=RTOL, atol=ATOL)
    103 
    104   def testFFT(self):
    105     self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.fft,
    106                           spectral_ops.fft)
    107 
    108   def testFFT2D(self):
    109     self._VerifyFftMethod(INNER_DIMS_2D, lambda x: x, np.fft.fft2,
    110                           spectral_ops.fft2d)
    111 
    112   def testFFT3D(self):
    113     self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x,
    114                           lambda x: np.fft.fftn(x, axes=(-3, -2, -1)),
    115                           spectral_ops.fft3d)
    116 
    117   def testIFFT(self):
    118     self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.ifft,
    119                           spectral_ops.ifft)
    120 
    121   def testIFFT2D(self):
    122     self._VerifyFftMethod(INNER_DIMS_2D, lambda x: x, np.fft.ifft2,
    123                           spectral_ops.ifft2d)
    124 
    125   def testIFFT3D(self):
    126     self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x,
    127                           lambda x: np.fft.ifftn(x, axes=(-3, -2, -1)),
    128                           spectral_ops.ifft3d)
    129 
    130   def testRFFT(self):
    131     self._VerifyFftMethod(
    132         INNER_DIMS_1D, np.real, lambda x: np.fft.rfft(x, n=x.shape[-1]),
    133         lambda x: spectral_ops.rfft(x, fft_length=[x.shape[-1].value]))
    134 
    135   def testRFFT2D(self):
    136 
    137     def _tf_fn(x):
    138       return spectral_ops.rfft2d(
    139           x, fft_length=[x.shape[-2].value, x.shape[-1].value])
    140 
    141     self._VerifyFftMethod(
    142         INNER_DIMS_2D, np.real,
    143         lambda x: np.fft.rfft2(x, s=[x.shape[-2], x.shape[-1]]), _tf_fn)
    144 
    145   def testRFFT3D(self):
    146 
    147     def _to_expected(x):
    148       return np.fft.rfftn(
    149           x, axes=(-3, -2, -1), s=[x.shape[-3], x.shape[-2], x.shape[-1]])
    150 
    151     def _tf_fn(x):
    152       return spectral_ops.rfft3d(
    153           x,
    154           fft_length=[x.shape[-3].value, x.shape[-2].value, x.shape[-1].value])
    155 
    156     self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn)
    157 
    158   def testIRFFT(self):
    159 
    160     def _tf_fn(x):
    161       return spectral_ops.irfft(x, fft_length=[2 * (x.shape[-1].value - 1)])
    162 
    163     self._VerifyFftMethod(
    164         INNER_DIMS_1D, lambda x: np.fft.rfft(np.real(x), n=x.shape[-1]),
    165         lambda x: np.fft.irfft(x, n=2 * (x.shape[-1] - 1)), _tf_fn)
    166 
    167   def testIRFFT2D(self):
    168 
    169     def _tf_fn(x):
    170       return spectral_ops.irfft2d(
    171           x, fft_length=[x.shape[-2].value, 2 * (x.shape[-1].value - 1)])
    172 
    173     self._VerifyFftMethod(
    174         INNER_DIMS_2D,
    175         lambda x: np.fft.rfft2(np.real(x), s=[x.shape[-2], x.shape[-1]]),
    176         lambda x: np.fft.irfft2(x, s=[x.shape[-2], 2 * (x.shape[-1] - 1)]),
    177         _tf_fn)
    178 
    179   def testIRFFT3D(self):
    180 
    181     def _to_input(x):
    182       return np.fft.rfftn(
    183           np.real(x),
    184           axes=(-3, -2, -1),
    185           s=[x.shape[-3], x.shape[-2], x.shape[-1]])
    186 
    187     def _to_expected(x):
    188       return np.fft.irfftn(
    189           x,
    190           axes=(-3, -2, -1),
    191           s=[x.shape[-3], x.shape[-2], 2 * (x.shape[-1] - 1)])
    192 
    193     def _tf_fn(x):
    194       return spectral_ops.irfft3d(
    195           x,
    196           fft_length=[
    197               x.shape[-3].value, x.shape[-2].value, 2 * (x.shape[-1].value - 1)
    198           ])
    199 
    200     self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn)
    201 
    202 
# Hand control to googletest's runner when this file is executed as a script.
if __name__ == "__main__":
  googletest.main()
    205