# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for FFT operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_spectral_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import spectral_ops
from tensorflow.python.ops import spectral_ops_test_util
from tensorflow.python.platform import test

VALID_FFT_RANKS = (1, 2, 3)


class BaseFFTOpsTest(test.TestCase):
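  """Shared NumPy comparisons and gradient checks for the FFT op tests."""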

  def _compare(self, x, rank, fft_length=None, use_placeholder=False):
    self._compareForward(x, rank, fft_length, use_placeholder)
    self._compareBackward(x, rank, fft_length, use_placeholder)

  def _compareForward(self, x, rank, fft_length=None, use_placeholder=False):
    x_np = self._npFFT(x, rank, fft_length)
    if use_placeholder:
      x_ph = array_ops.placeholder(dtype=dtypes.as_dtype(x.dtype))
      x_tf = self._tfFFT(x_ph, rank, fft_length, feed_dict={x_ph: x})
    else:
      x_tf = self._tfFFT(x, rank, fft_length)

    self.assertAllClose(x_np, x_tf, rtol=1e-4, atol=1e-4)

  def _compareBackward(self, x, rank, fft_length=None, use_placeholder=False):
    x_np = self._npIFFT(x, rank, fft_length)
    if use_placeholder:
      x_ph = array_ops.placeholder(dtype=dtypes.as_dtype(x.dtype))
      x_tf = self._tfIFFT(x_ph, rank, fft_length, feed_dict={x_ph: x})
    else:
      x_tf = self._tfIFFT(x, rank, fft_length)

    self.assertAllClose(x_np, x_tf, rtol=1e-4, atol=1e-4)

  def _checkMemoryFail(self, x, rank):
    config = config_pb2.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 1e-2
    with self.test_session(config=config, force_gpu=True):
      self._tfFFT(x, rank, fft_length=None)

  def _checkGradComplex(self, func, x, y, result_is_complex=True):
    with self.test_session(use_gpu=True):
      inx = ops.convert_to_tensor(x)
      iny = ops.convert_to_tensor(y)
      # func is a forward or inverse, real or complex, batched or unbatched FFT
      # function with a complex input.
      z = func(math_ops.complex(inx, iny))
      # loss = sum(|z|^2), a real-valued scalar, so the gradient checker can
      # compare theoretical and numeric Jacobians w.r.t. the real inputs.
      loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z)))

      ((x_jacob_t, x_jacob_n),
       (y_jacob_t, y_jacob_n)) = gradient_checker.compute_gradient(
           [inx, iny], [list(x.shape), list(y.shape)],
           loss, [1],
           x_init_value=[x, y],
           delta=1e-2)

    self.assertAllClose(x_jacob_t, x_jacob_n, rtol=1e-2, atol=1e-2)
    self.assertAllClose(y_jacob_t, y_jacob_n, rtol=1e-2, atol=1e-2)

  def _checkGradReal(self, func, x):
    with self.test_session(use_gpu=True):
      inx = ops.convert_to_tensor(x)
      # func is a forward RFFT function (batched or unbatched).
      z = func(inx)
      # loss = sum(|z|^2)
      loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z)))
      x_jacob_t, x_jacob_n = test.compute_gradient(
          inx, list(x.shape), loss, [1], x_init_value=x, delta=1e-2)

    self.assertAllClose(x_jacob_t, x_jacob_n, rtol=1e-2, atol=1e-2)


class FFTOpsTest(BaseFFTOpsTest):
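  """Tests fft/ifft, fft2d/ifft2d and fft3d/ifft3d against np.fft."""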

  def _tfFFT(self, x, rank, fft_length=None, feed_dict=None):
    # fft_length unused for complex FFTs.
    with self.test_session(use_gpu=True):
      return self._tfFFTForRank(rank)(x).eval(feed_dict=feed_dict)

  def _tfIFFT(self, x, rank, fft_length=None, feed_dict=None):
    # fft_length unused for complex FFTs.
    with self.test_session(use_gpu=True):
      return self._tfIFFTForRank(rank)(x).eval(feed_dict=feed_dict)

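  # Note: np.fft.fft2/ifft2 forward their `axes` argument to the n-dimensional
  # transform, so they serve as the NumPy reference for ranks 1 through 3.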
  def _npFFT(self, x, rank, fft_length=None):
    if rank == 1:
      return np.fft.fft2(x, s=fft_length, axes=(-1,))
    elif rank == 2:
      return np.fft.fft2(x, s=fft_length, axes=(-2, -1))
    elif rank == 3:
      return np.fft.fft2(x, s=fft_length, axes=(-3, -2, -1))
    else:
      raise ValueError("invalid rank")

  def _npIFFT(self, x, rank, fft_length=None):
    if rank == 1:
      return np.fft.ifft2(x, s=fft_length, axes=(-1,))
    elif rank == 2:
      return np.fft.ifft2(x, s=fft_length, axes=(-2, -1))
    elif rank == 3:
      return np.fft.ifft2(x, s=fft_length, axes=(-3, -2, -1))
    else:
      raise ValueError("invalid rank")

  def _tfFFTForRank(self, rank):
    if rank == 1:
      return spectral_ops.fft
    elif rank == 2:
      return spectral_ops.fft2d
    elif rank == 3:
      return spectral_ops.fft3d
    else:
      raise ValueError("invalid rank")

  def _tfIFFTForRank(self, rank):
    if rank == 1:
      return spectral_ops.ifft
    elif rank == 2:
      return spectral_ops.ifft2d
    elif rank == 3:
      return spectral_ops.ifft3d
    else:
      raise ValueError("invalid rank")

  def testEmpty(self):
    with spectral_ops_test_util.fft_kernel_label_map():
      for rank in VALID_FFT_RANKS:
        for dims in xrange(rank, rank + 3):
          x = np.zeros((0,) * dims).astype(np.complex64)
          self.assertEqual(x.shape, self._tfFFT(x, rank).shape)
          self.assertEqual(x.shape, self._tfIFFT(x, rank).shape)

  def testBasic(self):
    with spectral_ops_test_util.fft_kernel_label_map():
      for rank in VALID_FFT_RANKS:
        for dims in xrange(rank, rank + 3):
          self._compare(
              np.mod(np.arange(np.power(4, dims)), 10).reshape(
                  (4,) * dims).astype(np.complex64), rank)

  def testLargeBatch(self):
    if test.is_gpu_available(cuda_only=True):
      rank = 1
      for dims in xrange(rank, rank + 3):
        self._compare(
            np.mod(np.arange(np.power(128, dims)), 10).reshape(
                (128,) * dims).astype(np.complex64), rank)

  # TODO(yangzihao): Disabled until we can figure out a way to properly test
  # memory failures for large batch FFTs.
  # def testLargeBatchMemoryFail(self):
  #   if test.is_gpu_available(cuda_only=True):
  #     rank = 1
  #     for dims in xrange(rank, rank + 3):
  #       self._checkMemoryFail(
  #           np.mod(np.arange(np.power(128, dims)), 64).reshape(
  #               (128,) * dims).astype(np.complex64), rank)

  def testBasicPlaceholder(self):
    with spectral_ops_test_util.fft_kernel_label_map():
      for rank in VALID_FFT_RANKS:
        for dims in xrange(rank, rank + 3):
          self._compare(
              np.mod(np.arange(np.power(4, dims)), 10).reshape(
                  (4,) * dims).astype(np.complex64),
              rank,
              use_placeholder=True)

  def testRandom(self):
    with spectral_ops_test_util.fft_kernel_label_map():
      np.random.seed(12345)

      def gen(shape):
        n = np.prod(shape)
        re = np.random.uniform(size=n)
        im = np.random.uniform(size=n)
        return (re + im * 1j).reshape(shape)

      for rank in VALID_FFT_RANKS:
        for dims in xrange(rank, rank + 3):
          self._compare(gen((4,) * dims), rank)

  def testError(self):
    for rank in VALID_FFT_RANKS:
      for dims in xrange(0, rank):
        x = np.zeros((1,) * dims).astype(np.complex64)
        with self.assertRaisesWithPredicateMatch(
            ValueError, "Shape must be .*rank {}.*".format(rank)):
          self._tfFFT(x, rank)
        with self.assertRaisesWithPredicateMatch(
            ValueError, "Shape must be .*rank {}.*".format(rank)):
          self._tfIFFT(x, rank)

  def testGrad_Simple(self):
    with spectral_ops_test_util.fft_kernel_label_map():
      for rank in VALID_FFT_RANKS:
        for dims in xrange(rank, rank + 2):
          re = np.ones(shape=(4,) * dims, dtype=np.float32) / 10.0
          im = np.zeros(shape=(4,) * dims, dtype=np.float32)
          self._checkGradComplex(self._tfFFTForRank(rank), re, im)
          self._checkGradComplex(self._tfIFFTForRank(rank), re, im)

  def testGrad_Random(self):
    with spectral_ops_test_util.fft_kernel_label_map():
      np.random.seed(54321)
      for rank in VALID_FFT_RANKS:
        for dims in xrange(rank, rank + 2):
          re = np.random.rand(*((3,) * dims)).astype(np.float32) * 2 - 1
          im = np.random.rand(*((3,) * dims)).astype(np.float32) * 2 - 1
          self._checkGradComplex(self._tfFFTForRank(rank), re, im)
          self._checkGradComplex(self._tfIFFTForRank(rank), re, im)


class RFFTOpsTest(BaseFFTOpsTest):
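  """Tests rfft/irfft, rfft2d/irfft2d and rfft3d/irfft3d against np.fft."""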

  def _compareBackward(self, x, rank, fft_length=None, use_placeholder=False):
    super(RFFTOpsTest, self)._compareBackward(x, rank, fft_length,
                                              use_placeholder)

  def _tfFFT(self, x, rank, fft_length=None, feed_dict=None):
    with self.test_session(use_gpu=True):
      return self._tfFFTForRank(rank)(x, fft_length).eval(feed_dict=feed_dict)

  def _tfIFFT(self, x, rank, fft_length=None, feed_dict=None):
    with self.test_session(use_gpu=True):
      return self._tfIFFTForRank(rank)(x, fft_length).eval(feed_dict=feed_dict)

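  # As above, np.fft.rfft2/irfft2 forward their `axes` argument to the
  # n-dimensional transform, so they act as the reference for ranks 1 to 3.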
  def _npFFT(self, x, rank, fft_length=None):
    if rank == 1:
      return np.fft.rfft2(x, s=fft_length, axes=(-1,))
    elif rank == 2:
      return np.fft.rfft2(x, s=fft_length, axes=(-2, -1))
    elif rank == 3:
      return np.fft.rfft2(x, s=fft_length, axes=(-3, -2, -1))
    else:
      raise ValueError("invalid rank")

  def _npIFFT(self, x, rank, fft_length=None):
    if rank == 1:
      return np.fft.irfft2(x, s=fft_length, axes=(-1,))
    elif rank == 2:
      return np.fft.irfft2(x, s=fft_length, axes=(-2, -1))
    elif rank == 3:
      return np.fft.irfft2(x, s=fft_length, axes=(-3, -2, -1))
    else:
      raise ValueError("invalid rank")

  def _tfFFTForRank(self, rank):
    if rank == 1:
      return spectral_ops.rfft
    elif rank == 2:
      return spectral_ops.rfft2d
    elif rank == 3:
      return spectral_ops.rfft3d
    else:
      raise ValueError("invalid rank")

  def _tfIFFTForRank(self, rank):
    if rank == 1:
      return spectral_ops.irfft
    elif rank == 2:
      return spectral_ops.irfft2d
    elif rank == 3:
      return spectral_ops.irfft3d
    else:
      raise ValueError("invalid rank")

  def testEmpty(self):
    with spectral_ops_test_util.fft_kernel_label_map():
      for rank in VALID_FFT_RANKS:
        for dims in xrange(rank, rank + 3):
          x = np.zeros((0,) * dims).astype(np.float32)
          self.assertEqual(x.shape, self._tfFFT(x, rank).shape)
          x = np.zeros((0,) * dims).astype(np.complex64)
          self.assertEqual(x.shape, self._tfIFFT(x, rank).shape)

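  # A real FFT of length `size` keeps size // 2 + 1 unique coefficients along
  # the innermost axis, so the complex-to-real inputs below use that inner
  # dimension.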
  def testBasic(self):
    with spectral_ops_test_util.fft_kernel_label_map():
      for rank in VALID_FFT_RANKS:
        for dims in xrange(rank, rank + 3):
          for size in (5, 6):
            inner_dim = size // 2 + 1
            r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
                (size,) * dims)
            self._compareForward(r2c.astype(np.float32), rank, (size,) * rank)
            c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
                         10).reshape((size,) * (dims - 1) + (inner_dim,))
            self._compareBackward(
                c2r.astype(np.complex64), rank, (size,) * rank)

  def testLargeBatch(self):
    if test.is_gpu_available(cuda_only=True):
      rank = 1
      for dims in xrange(rank, rank + 3):
        for size in (64, 128):
          inner_dim = size // 2 + 1
          r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
              (size,) * dims)
          self._compareForward(r2c.astype(np.float32), rank, (size,) * rank)
          c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
                       10).reshape((size,) * (dims - 1) + (inner_dim,))
          self._compareBackward(c2r.astype(np.complex64), rank, (size,) * rank)

  def testBasicPlaceholder(self):
    with spectral_ops_test_util.fft_kernel_label_map():
      for rank in VALID_FFT_RANKS:
        for dims in xrange(rank, rank + 3):
          for size in (5, 6):
            inner_dim = size // 2 + 1
            r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
                (size,) * dims)
            self._compareForward(
                r2c.astype(np.float32),
                rank, (size,) * rank,
                use_placeholder=True)
            c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
                         10).reshape((size,) * (dims - 1) + (inner_dim,))
            self._compareBackward(
                c2r.astype(np.complex64),
                rank, (size,) * rank,
                use_placeholder=True)

  def testFftLength(self):
    if test.is_gpu_available(cuda_only=True):
      with spectral_ops_test_util.fft_kernel_label_map():
        for rank in VALID_FFT_RANKS:
          for dims in xrange(rank, rank + 3):
            for size in (5, 6):
              inner_dim = size // 2 + 1
              r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
                  (size,) * dims)
              c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
                           10).reshape((size,) * (dims - 1) + (inner_dim,))
              # Test truncation (FFT size < dimensions).
              fft_length = (size - 2,) * rank
              self._compareForward(r2c.astype(np.float32), rank, fft_length)
              self._compareBackward(c2r.astype(np.complex64), rank, fft_length)
              # Confirm it works with unknown shapes as well.
              self._compareForward(
                  r2c.astype(np.float32),
                  rank,
                  fft_length,
                  use_placeholder=True)
              self._compareBackward(
                  c2r.astype(np.complex64),
                  rank,
                  fft_length,
                  use_placeholder=True)
              # Test padding (FFT size > dimensions).
              fft_length = (size + 2,) * rank
              self._compareForward(r2c.astype(np.float32), rank, fft_length)
              self._compareBackward(c2r.astype(np.complex64), rank, fft_length)
              # Confirm it works with unknown shapes as well.
              self._compareForward(
                  r2c.astype(np.float32),
                  rank,
                  fft_length,
                  use_placeholder=True)
              self._compareBackward(
                  c2r.astype(np.complex64),
                  rank,
                  fft_length,
                  use_placeholder=True)

  def testRandom(self):
    with spectral_ops_test_util.fft_kernel_label_map():
      np.random.seed(12345)

      def gen_real(shape):
        n = np.prod(shape)
        re = np.random.uniform(size=n)
        ret = re.reshape(shape)
        return ret

      def gen_complex(shape):
        n = np.prod(shape)
        re = np.random.uniform(size=n)
        im = np.random.uniform(size=n)
        ret = (re + im * 1j).reshape(shape)
        return ret

      for rank in VALID_FFT_RANKS:
        for dims in xrange(rank, rank + 3):
          for size in (5, 6):
            inner_dim = size // 2 + 1
            self._compareForward(gen_real((size,) * dims), rank, (size,) * rank)
            complex_dims = (size,) * (dims - 1) + (inner_dim,)
            self._compareBackward(
                gen_complex(complex_dims), rank, (size,) * rank)

  def testError(self):
    with spectral_ops_test_util.fft_kernel_label_map():
      for rank in VALID_FFT_RANKS:
        for dims in xrange(0, rank):
          x = np.zeros((1,) * dims).astype(np.complex64)
          with self.assertRaisesWithPredicateMatch(
              ValueError, "Shape .* must have rank at least {}".format(rank)):
            self._tfFFT(x, rank)
          with self.assertRaisesWithPredicateMatch(
              ValueError, "Shape .* must have rank at least {}".format(rank)):
            self._tfIFFT(x, rank)
        for dims in xrange(rank, rank + 2):
          x = np.zeros((1,) * rank)

          # Test non-rank-1 fft_length produces an error.
          fft_length = np.zeros((1, 1)).astype(np.int32)
          with self.assertRaisesWithPredicateMatch(ValueError,
                                                   "Shape .* must have rank 1"):
            self._tfFFT(x, rank, fft_length)
          with self.assertRaisesWithPredicateMatch(ValueError,
                                                   "Shape .* must have rank 1"):
            self._tfIFFT(x, rank, fft_length)

          # Test wrong fft_length length.
          fft_length = np.zeros((rank + 1,)).astype(np.int32)
          with self.assertRaisesWithPredicateMatch(
              ValueError, "Dimension must be .*but is {}.*".format(rank + 1)):
            self._tfFFT(x, rank, fft_length)
          with self.assertRaisesWithPredicateMatch(
              ValueError, "Dimension must be .*but is {}.*".format(rank + 1)):
            self._tfIFFT(x, rank, fft_length)

        # Test that calling the kernel directly without padding to fft_length
        # produces an error.
        rffts_for_rank = {
            1: [gen_spectral_ops.rfft, gen_spectral_ops.irfft],
            2: [gen_spectral_ops.rfft2d, gen_spectral_ops.irfft2d],
            3: [gen_spectral_ops.rfft3d, gen_spectral_ops.irfft3d]
        }
        rfft_fn, irfft_fn = rffts_for_rank[rank]
        with self.assertRaisesWithPredicateMatch(
            errors.InvalidArgumentError,
            "Input dimension .* must have length of at least 6 but got: 5"):
          x = np.zeros((5,) * rank).astype(np.float32)
          fft_length = [6] * rank
          with self.test_session():
            rfft_fn(x, fft_length).eval()

        with self.assertRaisesWithPredicateMatch(
            errors.InvalidArgumentError,
            "Input dimension .* must have length of at least .* but got: 3"):
          x = np.zeros((3,) * rank).astype(np.complex64)
          fft_length = [6] * rank
          with self.test_session():
            irfft_fn(x, fft_length).eval()

  def testGrad_Simple(self):
    with spectral_ops_test_util.fft_kernel_label_map():
      for rank in VALID_FFT_RANKS:
        # rfft3d/irfft3d do not have gradients yet.
        if rank == 3:
          continue
        for dims in xrange(rank, rank + 2):
          for size in (5, 6):
            re = np.ones(shape=(size,) * dims, dtype=np.float32)
            im = -np.ones(shape=(size,) * dims, dtype=np.float32)
            self._checkGradReal(self._tfFFTForRank(rank), re)
            self._checkGradComplex(
                self._tfIFFTForRank(rank), re, im, result_is_complex=False)

  def testGrad_Random(self):
    with spectral_ops_test_util.fft_kernel_label_map():
      np.random.seed(54321)
      for rank in VALID_FFT_RANKS:
        # rfft3d/irfft3d do not have gradients yet.
        if rank == 3:
          continue
        for dims in xrange(rank, rank + 2):
          for size in (5, 6):
            re = np.random.rand(*((size,) * dims)).astype(np.float32) * 2 - 1
            im = np.random.rand(*((size,) * dims)).astype(np.float32) * 2 - 1
            self._checkGradReal(self._tfFFTForRank(rank), re)
            self._checkGradComplex(
                self._tfIFFTForRank(rank), re, im, result_is_complex=False)


if __name__ == "__main__":
  test.main()