1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for DCT operations.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import importlib 22 23 import numpy as np 24 25 from tensorflow.python.ops import spectral_ops 26 from tensorflow.python.ops import spectral_ops_test_util 27 from tensorflow.python.platform import test 28 from tensorflow.python.platform import tf_logging 29 30 31 def try_import(name): # pylint: disable=invalid-name 32 module = None 33 try: 34 module = importlib.import_module(name) 35 except ImportError as e: 36 tf_logging.warning("Could not import %s: %s" % (name, str(e))) 37 return module 38 39 40 fftpack = try_import("scipy.fftpack") 41 42 43 class DCTOpsTest(test.TestCase): 44 45 def _np_dct2(self, signals, norm=None): 46 """Computes the DCT-II manually with NumPy.""" 47 # X_k = sum_{n=0}^{N-1} x_n * cos(\frac{pi}{N} * (n + 0.5) * k) k=0,...,N-1 48 dct_size = signals.shape[-1] 49 dct = np.zeros_like(signals) 50 for k in range(dct_size): 51 phi = np.cos(np.pi * (np.arange(dct_size) + 0.5) * k / dct_size) 52 dct[..., k] = np.sum(signals * phi, axis=-1) 53 # SciPy's `dct` has a scaling factor of 2.0 which we follow. 54 # https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src 55 if norm == "ortho": 56 # The orthonormal scaling includes a factor of 0.5 which we combine with 57 # the overall scaling of 2.0 to cancel. 58 dct[..., 0] *= np.sqrt(1.0 / dct_size) 59 dct[..., 1:] *= np.sqrt(2.0 / dct_size) 60 else: 61 dct *= 2.0 62 return dct 63 64 def _compare(self, signals, norm, atol=5e-4, rtol=5e-4): 65 """Compares the DCT to SciPy (if available) and a NumPy implementation.""" 66 np_dct = self._np_dct2(signals, norm) 67 tf_dct = spectral_ops.dct(signals, type=2, norm=norm).eval() 68 self.assertAllClose(np_dct, tf_dct, atol=atol, rtol=rtol) 69 if fftpack: 70 scipy_dct = fftpack.dct(signals, type=2, norm=norm) 71 self.assertAllClose(scipy_dct, tf_dct, atol=atol, rtol=rtol) 72 73 def test_random(self): 74 """Test randomly generated batches of data.""" 75 with spectral_ops_test_util.fft_kernel_label_map(): 76 with self.test_session(use_gpu=True): 77 for shape in ([2, 20], [1], [2], [3], [10], [2, 20], [2, 3, 25]): 78 signals = np.random.rand(*shape).astype(np.float32) 79 for norm in (None, "ortho"): 80 self._compare(signals, norm) 81 82 def test_error(self): 83 signals = np.random.rand(10) 84 # Unsupported type. 85 with self.assertRaises(ValueError): 86 spectral_ops.dct(signals, type=3) 87 # Unknown normalization. 88 with self.assertRaises(ValueError): 89 spectral_ops.dct(signals, norm="bad") 90 with self.assertRaises(NotImplementedError): 91 spectral_ops.dct(signals, n=10) 92 with self.assertRaises(NotImplementedError): 93 spectral_ops.dct(signals, axis=0) 94 95 96 if __name__ == "__main__": 97 test.main() 98