Home | History | Annotate | Download | only in CUDA
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog (at) gmail.com>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_COMPLEX_CUDA_H
     11 #define EIGEN_COMPLEX_CUDA_H
     12 
     13 // clang-format off
     14 
     15 namespace Eigen {
     16 
     17 namespace internal {
     18 
     19 #if defined(__CUDACC__) && defined(EIGEN_USE_GPU)
     20 
     21 // Many std::complex methods such as operator+, operator-, operator* and
     22 // operator/ are not constexpr. Due to this, clang does not treat them as device
     23 // functions and thus Eigen functors making use of these operators fail to
     24 // compile. Here, we manually specialize these functors for complex types when
     25 // building for CUDA to avoid non-constexpr methods.
     26 
     27 // Sum
     28 template<typename T> struct scalar_sum_op<const std::complex<T>, const std::complex<T> > : binary_op_base<const std::complex<T>, const std::complex<T> > {
     29   typedef typename std::complex<T> result_type;
     30 
     31   EIGEN_EMPTY_STRUCT_CTOR(scalar_sum_op)
     32   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const {
     33     return std::complex<T>(numext::real(a) + numext::real(b),
     34                            numext::imag(a) + numext::imag(b));
     35   }
     36 };
     37 
     38 template<typename T> struct scalar_sum_op<std::complex<T>, std::complex<T> > : scalar_sum_op<const std::complex<T>, const std::complex<T> > {};
     39 
     40 
     41 // Difference
     42 template<typename T> struct scalar_difference_op<const std::complex<T>, const std::complex<T> >  : binary_op_base<const std::complex<T>, const std::complex<T> > {
     43   typedef typename std::complex<T> result_type;
     44 
     45   EIGEN_EMPTY_STRUCT_CTOR(scalar_difference_op)
     46   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const {
     47     return std::complex<T>(numext::real(a) - numext::real(b),
     48                            numext::imag(a) - numext::imag(b));
     49   }
     50 };
     51 
     52 template<typename T> struct scalar_difference_op<std::complex<T>, std::complex<T> > : scalar_difference_op<const std::complex<T>, const std::complex<T> > {};
     53 
     54 
     55 // Product
     56 template<typename T> struct scalar_product_op<const std::complex<T>, const std::complex<T> >  : binary_op_base<const std::complex<T>, const std::complex<T> > {
     57   enum {
     58     Vectorizable = packet_traits<std::complex<T>>::HasMul
     59   };
     60   typedef typename std::complex<T> result_type;
     61 
     62   EIGEN_EMPTY_STRUCT_CTOR(scalar_product_op)
     63   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const {
     64     const T a_real = numext::real(a);
     65     const T a_imag = numext::imag(a);
     66     const T b_real = numext::real(b);
     67     const T b_imag = numext::imag(b);
     68     return std::complex<T>(a_real * b_real - a_imag * b_imag,
     69                            a_real * b_imag + a_imag * b_real);
     70   }
     71 };
     72 
     73 template<typename T> struct scalar_product_op<std::complex<T>, std::complex<T> > : scalar_product_op<const std::complex<T>, const std::complex<T> > {};
     74 
     75 
     76 // Quotient
     77 template<typename T> struct scalar_quotient_op<const std::complex<T>, const std::complex<T> > : binary_op_base<const std::complex<T>, const std::complex<T> > {
     78   enum {
     79     Vectorizable = packet_traits<std::complex<T>>::HasDiv
     80   };
     81   typedef typename std::complex<T> result_type;
     82 
     83   EIGEN_EMPTY_STRUCT_CTOR(scalar_quotient_op)
     84   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const {
     85     const T a_real = numext::real(a);
     86     const T a_imag = numext::imag(a);
     87     const T b_real = numext::real(b);
     88     const T b_imag = numext::imag(b);
     89     const T norm = T(1) / (b_real * b_real + b_imag * b_imag);
     90     return std::complex<T>((a_real * b_real + a_imag * b_imag) * norm,
     91                            (a_imag * b_real - a_real * b_imag) * norm);
     92   }
     93 };
     94 
     95 template<typename T> struct scalar_quotient_op<std::complex<T>, std::complex<T> > : scalar_quotient_op<const std::complex<T>, const std::complex<T> > {};
     96 
     97 #endif
     98 
     99 } // end namespace internal
    100 
    101 } // end namespace Eigen
    102 
    103 #endif // EIGEN_COMPLEX_CUDA_H
    104