Home | History | Annotate | Download | only in Tensor
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog (at) gmail.com>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_DEVICE_H
     11 #define EIGEN_CXX11_TENSOR_TENSOR_DEVICE_H
     12 
     13 namespace Eigen {
     14 
     15 /** \class TensorDevice
     16   * \ingroup CXX11_Tensor_Module
     17   *
     18   * \brief Pseudo expression providing an operator = that will evaluate its argument
     19   * on the specified computing 'device' (GPU, thread pool, ...)
     20   *
     21   * Example:
     22   *    C.device(EIGEN_GPU) = A + B;
     23   *
  * TODO: add operator*= and operator/= (only =, += and -= are implemented so far).
     25   */
     26 
     27 template <typename ExpressionType, typename DeviceType> class TensorDevice {
     28   public:
     29     TensorDevice(const DeviceType& device, ExpressionType& expression) : m_device(device), m_expression(expression) {}
     30 
     31     template<typename OtherDerived>
     32     EIGEN_STRONG_INLINE TensorDevice& operator=(const OtherDerived& other) {
     33       typedef TensorAssignOp<ExpressionType, const OtherDerived> Assign;
     34       Assign assign(m_expression, other);
     35       internal::TensorExecutor<const Assign, DeviceType>::run(assign, m_device);
     36       return *this;
     37     }
     38 
     39     template<typename OtherDerived>
     40     EIGEN_STRONG_INLINE TensorDevice& operator+=(const OtherDerived& other) {
     41       typedef typename OtherDerived::Scalar Scalar;
     42       typedef TensorCwiseBinaryOp<internal::scalar_sum_op<Scalar>, const ExpressionType, const OtherDerived> Sum;
     43       Sum sum(m_expression, other);
     44       typedef TensorAssignOp<ExpressionType, const Sum> Assign;
     45       Assign assign(m_expression, sum);
     46       internal::TensorExecutor<const Assign, DeviceType>::run(assign, m_device);
     47       return *this;
     48     }
     49 
     50     template<typename OtherDerived>
     51     EIGEN_STRONG_INLINE TensorDevice& operator-=(const OtherDerived& other) {
     52       typedef typename OtherDerived::Scalar Scalar;
     53       typedef TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const ExpressionType, const OtherDerived> Difference;
     54       Difference difference(m_expression, other);
     55       typedef TensorAssignOp<ExpressionType, const Difference> Assign;
     56       Assign assign(m_expression, difference);
     57       internal::TensorExecutor<const Assign, DeviceType>::run(assign, m_device);
     58       return *this;
     59     }
     60 
     61   protected:
     62     const DeviceType& m_device;
     63     ExpressionType& m_expression;
     64 };
     65 
     66 } // end namespace Eigen
     67 
     68 #endif // EIGEN_CXX11_TENSOR_TENSOR_DEVICE_H
     69