# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for writing tests involving spectral_ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.platform import test

# Names of the spectral ops whose kernel label is overridden to "eigen" when
# `_use_eigen_kernels()` returns True.
_FFT_OPS = (
    "FFT", "FFT2D", "FFT3D",
    "IFFT", "IFFT2D", "IFFT3D",
    "IRFFT", "IRFFT2D", "IRFFT3D",
    "RFFT", "RFFT2D", "RFFT3D",
)


def _use_eigen_kernels():
  """Returns whether the FFT ops should be forced onto the "eigen" kernels.

  Returns:
    A bool. As written, both branches produce False.
  """
  # NOTE(review): both the default and the CUDA-GPU branch set False, so this
  # always returns False and `fft_kernel_label_map` never installs the "eigen"
  # override. The `is_gpu_available` probe is kept so behavior (including any
  # device-query side effects) is unchanged; confirm whether one branch was
  # meant to return True.
  use_eigen_kernels = False  # Eigen kernels are default
  if test.is_gpu_available(cuda_only=True):
    use_eigen_kernels = False
  return use_eigen_kernels


def fft_kernel_label_map():
  """Returns a kernel-label-map scope overriding kernel selection for FFT ops.

  This is used to force testing of the eigen kernels, even when they are not
  the default registered kernels. When `_use_eigen_kernels()` is true, every op
  named in `_FFT_OPS` is mapped to the "eigen" kernel label. Otherwise an empty
  label map is installed, which leaves kernel selection as-is.

  Returns:
    The object returned by the default graph's `_kernel_label_map`, intended to
    be entered with a `with` statement around every test.
  """
  if _use_eigen_kernels():
    label_map = {op: "eigen" for op in _FFT_OPS}
  else:
    label_map = {}
  return ops.get_default_graph()._kernel_label_map(label_map)  # pylint: disable=protected-access