Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for DCT operations."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import importlib
     22 
     23 import numpy as np
     24 
     25 from tensorflow.python.ops import spectral_ops
     26 from tensorflow.python.ops import spectral_ops_test_util
     27 from tensorflow.python.platform import test
     28 from tensorflow.python.platform import tf_logging
     29 
     30 
     31 def try_import(name):  # pylint: disable=invalid-name
     32   module = None
     33   try:
     34     module = importlib.import_module(name)
     35   except ImportError as e:
     36     tf_logging.warning("Could not import %s: %s" % (name, str(e)))
     37   return module
     38 
     39 
# `fftpack` is None when SciPy cannot be imported; `DCTOpsTest._compare` then
# skips the SciPy comparison and relies on the NumPy reference alone.
fftpack = try_import("scipy.fftpack")
     41 
     42 
     43 class DCTOpsTest(test.TestCase):
     44 
     45   def _np_dct2(self, signals, norm=None):
     46     """Computes the DCT-II manually with NumPy."""
     47     # X_k = sum_{n=0}^{N-1} x_n * cos(\frac{pi}{N} * (n + 0.5) * k)  k=0,...,N-1
     48     dct_size = signals.shape[-1]
     49     dct = np.zeros_like(signals)
     50     for k in range(dct_size):
     51       phi = np.cos(np.pi * (np.arange(dct_size) + 0.5) * k / dct_size)
     52       dct[..., k] = np.sum(signals * phi, axis=-1)
     53     # SciPy's `dct` has a scaling factor of 2.0 which we follow.
     54     # https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src
     55     if norm == "ortho":
     56       # The orthonormal scaling includes a factor of 0.5 which we combine with
     57       # the overall scaling of 2.0 to cancel.
     58       dct[..., 0] *= np.sqrt(1.0 / dct_size)
     59       dct[..., 1:] *= np.sqrt(2.0 / dct_size)
     60     else:
     61       dct *= 2.0
     62     return dct
     63 
     64   def _compare(self, signals, norm, atol=5e-4, rtol=5e-4):
     65     """Compares the DCT to SciPy (if available) and a NumPy implementation."""
     66     np_dct = self._np_dct2(signals, norm)
     67     tf_dct = spectral_ops.dct(signals, type=2, norm=norm).eval()
     68     self.assertAllClose(np_dct, tf_dct, atol=atol, rtol=rtol)
     69     if fftpack:
     70       scipy_dct = fftpack.dct(signals, type=2, norm=norm)
     71       self.assertAllClose(scipy_dct, tf_dct, atol=atol, rtol=rtol)
     72 
     73   def test_random(self):
     74     """Test randomly generated batches of data."""
     75     with spectral_ops_test_util.fft_kernel_label_map():
     76       with self.test_session(use_gpu=True):
     77         for shape in ([2, 20], [1], [2], [3], [10], [2, 20], [2, 3, 25]):
     78           signals = np.random.rand(*shape).astype(np.float32)
     79           for norm in (None, "ortho"):
     80             self._compare(signals, norm)
     81 
     82   def test_error(self):
     83     signals = np.random.rand(10)
     84     # Unsupported type.
     85     with self.assertRaises(ValueError):
     86       spectral_ops.dct(signals, type=3)
     87     # Unknown normalization.
     88     with self.assertRaises(ValueError):
     89       spectral_ops.dct(signals, norm="bad")
     90     with self.assertRaises(NotImplementedError):
     91       spectral_ops.dct(signals, n=10)
     92     with self.assertRaises(NotImplementedError):
     93       spectral_ops.dct(signals, axis=0)
     94 
     95 
# Hand control to the TensorFlow test runner when executed as a script.
if __name__ == "__main__":
  test.main()
     98