// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog (at) gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_COMPLEX_CUDA_H
#define EIGEN_COMPLEX_CUDA_H

// clang-format off

namespace Eigen {

namespace internal {

#if defined(__CUDACC__) && defined(EIGEN_USE_GPU)

// Many std::complex methods such as operator+, operator-, operator* and
// operator/ are not constexpr. Due to this, clang does not treat them as device
// functions and thus Eigen functors making use of these operators fail to
// compile. Here, we manually specialize these functors for complex types when
// building for CUDA to avoid non-constexpr methods.
//
// Pattern used below: each functor is specialized for const-qualified
// std::complex<T> operands, and the unqualified specialization simply inherits
// from the const one so both spellings share a single implementation. All
// arithmetic is done on the real/imaginary parts (via numext::real and
// numext::imag) using T's own operators, never std::complex's operators.

// Sum: (a_r + b_r) + i (a_i + b_i).
template<typename T> struct scalar_sum_op<const std::complex<T>, const std::complex<T> > : binary_op_base<const std::complex<T>, const std::complex<T> > {
  typedef std::complex<T> result_type;

  EIGEN_EMPTY_STRUCT_CTOR(scalar_sum_op)
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const {
    return std::complex<T>(numext::real(a) + numext::real(b),
                           numext::imag(a) + numext::imag(b));
  }
};

template<typename T> struct scalar_sum_op<std::complex<T>, std::complex<T> > : scalar_sum_op<const std::complex<T>, const std::complex<T> > {};


// Difference: (a_r - b_r) + i (a_i - b_i).
template<typename T> struct scalar_difference_op<const std::complex<T>, const std::complex<T> > : binary_op_base<const std::complex<T>, const std::complex<T> > {
  typedef std::complex<T> result_type;

  EIGEN_EMPTY_STRUCT_CTOR(scalar_difference_op)
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const {
    return std::complex<T>(numext::real(a) - numext::real(b),
                           numext::imag(a) - numext::imag(b));
  }
};

template<typename T> struct scalar_difference_op<std::complex<T>, std::complex<T> > : scalar_difference_op<const std::complex<T>, const std::complex<T> > {};


// Product: (a_r b_r - a_i b_i) + i (a_r b_i + a_i b_r).
template<typename T> struct scalar_product_op<const std::complex<T>, const std::complex<T> > : binary_op_base<const std::complex<T>, const std::complex<T> > {
  enum {
    // Mirrors packet_traits<std::complex<T> >::HasMul.
    Vectorizable = packet_traits<std::complex<T> >::HasMul
  };
  typedef std::complex<T> result_type;

  EIGEN_EMPTY_STRUCT_CTOR(scalar_product_op)
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const {
    const T a_real = numext::real(a);
    const T a_imag = numext::imag(a);
    const T b_real = numext::real(b);
    const T b_imag = numext::imag(b);
    return std::complex<T>(a_real * b_real - a_imag * b_imag,
                           a_real * b_imag + a_imag * b_real);
  }
};

template<typename T> struct scalar_product_op<std::complex<T>, std::complex<T> > : scalar_product_op<const std::complex<T>, const std::complex<T> > {};


// Quotient a / b, computed with Smith's algorithm.
//
// The textbook formula (a * conj(b)) / (b_r^2 + b_i^2) squares the components
// of b: that overflows to inf when |b| exceeds roughly sqrt(max<T>) and
// underflows to 0 when |b| is below roughly sqrt(min<T>), turning perfectly
// representable quotients into NaN, inf or 0. Smith's method instead divides
// numerator and denominator through by the larger-magnitude component of b,
// so the scaling ratio stays within [-1, 1] and intermediates stay in range.
// Division by b == 0 still yields non-finite values, as before.
template<typename T> struct scalar_quotient_op<const std::complex<T>, const std::complex<T> > : binary_op_base<const std::complex<T>, const std::complex<T> > {
  enum {
    // Mirrors packet_traits<std::complex<T> >::HasDiv.
    Vectorizable = packet_traits<std::complex<T> >::HasDiv
  };
  typedef std::complex<T> result_type;

  EIGEN_EMPTY_STRUCT_CTOR(scalar_quotient_op)
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const {
    const T a_real = numext::real(a);
    const T a_imag = numext::imag(a);
    const T b_real = numext::real(b);
    const T b_imag = numext::imag(b);
    // Magnitudes computed with plain comparisons so no extra math function is
    // needed on the device.
    const T abs_b_real = b_real < T(0) ? -b_real : b_real;
    const T abs_b_imag = b_imag < T(0) ? -b_imag : b_imag;
    // If |b_i| <= |b_r| divide through by b_r, otherwise by b_i. Exactly one
    // of rscale/iscale is 1 and the other is the (bounded) ratio b_i/b_r or
    // b_r/b_i. Selecting via ternaries keeps the code branch-light for GPUs.
    const bool scale_imag = abs_b_imag <= abs_b_real;
    const T rscale = scale_imag ? T(1) : b_real / b_imag;
    const T iscale = scale_imag ? b_imag / b_real : T(1);
    const T denominator = b_real * rscale + b_imag * iscale;
    return std::complex<T>((a_real * rscale + a_imag * iscale) / denominator,
                           (a_imag * rscale - a_real * iscale) / denominator);
  }
};

template<typename T> struct scalar_quotient_op<std::complex<T>, std::complex<T> > : scalar_quotient_op<const std::complex<T>, const std::complex<T> > {};

#endif  // defined(__CUDACC__) && defined(EIGEN_USE_GPU)

} // end namespace internal

} // end namespace Eigen

#endif // EIGEN_COMPLEX_CUDA_H