Home | History | Annotate | Download | only in Tensor
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Mehdi Goli    Codeplay Software Ltd.
      5 // Ralph Potter  Codeplay Software Ltd.
      6 // Luke Iwanski  Codeplay Software Ltd.
      7 // Contact: <eigen (at) codeplay.com>
      8 //
      9 // This Source Code Form is subject to the terms of the Mozilla
     10 // Public License v. 2.0. If a copy of the MPL was not distributed
     11 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
     12 
     13 /*****************************************************************
     14  * TensorSyclextractFunctors.h
     15  *
     16  * \brief:
 *  Used to extract all the functors allocated to each node of the expression
 *  tree.
     19  *
     20 *****************************************************************/
     21 
     22 #ifndef UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXTRACT_FUNCTORS_HPP
     23 #define UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXTRACT_FUNCTORS_HPP
     24 
     25 namespace Eigen {
     26 namespace TensorSycl {
     27 namespace internal {
     28 /// \struct FunctorExtractor:  This struct is used to extract the functors
     29 /// constructed on
     30 /// the host-side, to pack them and reuse them in reconstruction of the
     31 /// expression on the device.
     32 /// We have to do that as in Eigen the functors are not stateless so we cannot
     33 /// re-instantiate them on the device.
     34 /// We have to pass instantiated functors to the device.
     35 // This struct is used for leafNode (TensorMap) and nodes behaving like leafNode (TensorForcedEval).
     36 template <typename Evaluator> struct FunctorExtractor{
     37   typedef typename Evaluator::Dimensions Dimensions;
     38   const Dimensions m_dimensions;
     39   const Dimensions& dimensions() const { return m_dimensions; }
     40   FunctorExtractor(const Evaluator& expr)
     41   : m_dimensions(expr.dimensions()) {}
     42 
     43 };
     44 
     45 /// specialisation of the \ref FunctorExtractor struct when the node type is
     46 /// const TensorCwiseNullaryOp, const TensorCwiseUnaryOp, and const TensorBroadcastingOp
     47 template <template <class, class> class UnaryCategory, typename OP, typename RHSExpr, typename Dev>
     48 struct FunctorExtractor<TensorEvaluator<const UnaryCategory<OP, RHSExpr>, Dev> > {
     49   FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr;
     50   OP func;
     51   FunctorExtractor(const TensorEvaluator<const UnaryCategory<OP, RHSExpr>, Dev>& expr)
     52   : rhsExpr(expr.impl()), func(expr.functor()) {}
     53 };
     54 /// specialisation of the \ref FunctorExtractor struct when the node type is
     55 /// TensorCwiseNullaryOp, TensorCwiseUnaryOp, and TensorBroadcastingOp
     56 template <template <class, class> class UnaryCategory, typename OP, typename RHSExpr, typename Dev>
     57 struct FunctorExtractor<TensorEvaluator<UnaryCategory<OP, RHSExpr>, Dev> >
     58 : FunctorExtractor<TensorEvaluator<const UnaryCategory<OP, RHSExpr>, Dev> >{};
     59 
     60 /// specialisation of the \ref FunctorExtractor struct when the node type is
     61 /// const TensorCwiseBinaryOp
     62 template <template<class, class, class> class BinaryCategory, typename OP, typename LHSExpr, typename RHSExpr, typename Dev>
     63 struct FunctorExtractor<TensorEvaluator<const BinaryCategory<OP, LHSExpr, RHSExpr>, Dev> > {
     64   FunctorExtractor<TensorEvaluator<LHSExpr, Dev> > lhsExpr;
     65   FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr;
     66   OP func;
     67   FunctorExtractor(const TensorEvaluator<const BinaryCategory<OP, LHSExpr, RHSExpr>, Dev>& expr)
     68   : lhsExpr(expr.left_impl()),rhsExpr(expr.right_impl()),func(expr.functor()) {}
     69 };
     70 
     71 /// specialisation of the \ref FunctorExtractor struct when the node type is
     72 /// const TensorCwiseBinaryOp
     73 template <template <class, class, class> class BinaryCategory, typename OP, typename LHSExpr, typename RHSExpr, typename Dev>
     74 struct FunctorExtractor<TensorEvaluator<BinaryCategory<OP,  LHSExpr, RHSExpr>, Dev> >
     75 : FunctorExtractor<TensorEvaluator<const BinaryCategory<OP,  LHSExpr, RHSExpr>, Dev> >{};
     76 
     77 /// specialisation of the \ref FunctorExtractor struct when the node type is
     78 /// const TensorCwiseTernaryOp
     79 template <template <class, class, class, class> class TernaryCategory, typename OP, typename Arg1Expr, typename Arg2Expr, typename Arg3Expr,typename Dev>
     80 struct FunctorExtractor<TensorEvaluator<const TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev> > {
     81   FunctorExtractor<TensorEvaluator<Arg1Expr, Dev> > arg1Expr;
     82   FunctorExtractor<TensorEvaluator<Arg2Expr, Dev> > arg2Expr;
     83   FunctorExtractor<TensorEvaluator<Arg3Expr, Dev> > arg3Expr;
     84   OP func;
     85   FunctorExtractor(const TensorEvaluator<const TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev>& expr)
     86   : arg1Expr(expr.arg1Impl()), arg2Expr(expr.arg2Impl()), arg3Expr(expr.arg3Impl()), func(expr.functor()) {}
     87 };
     88 
     89 /// specialisation of the \ref FunctorExtractor struct when the node type is
     90 /// TensorCwiseTernaryOp
     91 template <template <class, class, class, class> class TernaryCategory, typename OP, typename Arg1Expr, typename Arg2Expr, typename Arg3Expr, typename Dev>
     92 struct FunctorExtractor<TensorEvaluator< TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev> >
     93 :FunctorExtractor<TensorEvaluator<const TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev> >{};
     94 
     95 /// specialisation of the \ref FunctorExtractor struct when the node type is
     96 /// const TensorCwiseSelectOp. This is an specialisation without OP so it has to be separated.
     97 template <typename IfExpr, typename ThenExpr, typename ElseExpr, typename Dev>
     98 struct FunctorExtractor< TensorEvaluator<const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev> > {
     99   FunctorExtractor<TensorEvaluator<IfExpr, Dev> > ifExpr;
    100   FunctorExtractor<TensorEvaluator<ThenExpr, Dev> > thenExpr;
    101   FunctorExtractor<TensorEvaluator<ElseExpr, Dev> > elseExpr;
    102   FunctorExtractor(const TensorEvaluator<const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev>& expr)
    103   : ifExpr(expr.cond_impl()), thenExpr(expr.then_impl()), elseExpr(expr.else_impl()) {}
    104 };
    105 
    106 /// specialisation of the \ref FunctorExtractor struct when the node type is
    107 /// TensorCwiseSelectOp. This is an specialisation without OP so it has to be separated
    108 template <typename IfExpr, typename ThenExpr, typename ElseExpr, typename Dev>
    109 struct FunctorExtractor<TensorEvaluator<TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev> >
    110 :FunctorExtractor< TensorEvaluator<const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev> > {};
    111 
    112 /// specialisation of the \ref FunctorExtractor struct when the node type is
    113 /// const TensorAssignOp. This is an specialisation without OP so it has to be separated.
    114 template <typename LHSExpr, typename RHSExpr, typename Dev>
    115 struct FunctorExtractor<TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev> > {
    116   FunctorExtractor<TensorEvaluator<LHSExpr, Dev> > lhsExpr;
    117   FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr;
    118   FunctorExtractor(const TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev>& expr)
    119   : lhsExpr(expr.left_impl()), rhsExpr(expr.right_impl()) {}
    120 };
    121 
    122 /// specialisation of the \ref FunctorExtractor struct when the node type is
    123 /// TensorAssignOp. This is an specialisation without OP so it has to be separated.
    124 template <typename LHSExpr, typename RHSExpr, typename Dev>
    125 struct FunctorExtractor<TensorEvaluator<TensorAssignOp<LHSExpr, RHSExpr>, Dev> >
    126 :FunctorExtractor<TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev> >{};
    127 
    128 
    129 /// specialisation of the \ref FunctorExtractor struct when the node type is
    130 /// const TensorEvalToOp, This is an specialisation without OP so it has to be separated.
    131 template <typename RHSExpr, typename Dev>
    132 struct FunctorExtractor<TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev> > {
    133   FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr;
    134   FunctorExtractor(const TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev>& expr)
    135   : rhsExpr(expr.impl()) {}
    136 };
    137 
    138 /// specialisation of the \ref FunctorExtractor struct when the node type is
    139 /// TensorEvalToOp. This is a specialisation without OP so it has to be separated.
    140 template <typename RHSExpr, typename Dev>
    141 struct FunctorExtractor<TensorEvaluator<TensorEvalToOp<RHSExpr>, Dev> >
    142 : FunctorExtractor<TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev> > {};
    143 
    144 template<typename Dim, size_t NumOutputDim> struct DimConstr {
    145 template<typename InDim>
    146   static inline Dim getDim(InDim dims ) {return dims;}
    147 };
    148 
    149 template<typename Dim> struct DimConstr<Dim, 0> {
    150   template<typename InDim>
    151     static inline Dim getDim(InDim dims ) {return Dim(dims.TotalSize());}
    152 };
    153 
    154 template<typename Op, typename Dims, typename ArgType, template <class> class MakePointer_, typename Device>
    155 struct FunctorExtractor<TensorEvaluator<const TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device>>{
    156   typedef TensorEvaluator<const TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device> Evaluator;
    157   typedef typename Eigen::internal::conditional<Evaluator::NumOutputDims==0, DSizes<typename Evaluator::Index, 1>, typename Evaluator::Dimensions >::type Dimensions;
    158   const Dimensions m_dimensions;
    159   const Dimensions& dimensions() const { return m_dimensions; }
    160   FunctorExtractor(const TensorEvaluator<const TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device>& expr)
    161   : m_dimensions(DimConstr<Dimensions, Evaluator::NumOutputDims>::getDim(expr.dimensions())) {}
    162 };
    163 
    164 
    165 template<typename Op, typename Dims, typename ArgType, template <class> class MakePointer_, typename Device>
    166 struct FunctorExtractor<TensorEvaluator<TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device>>
    167 : FunctorExtractor<TensorEvaluator<const TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device>>{};
    168 /// template deduction function for FunctorExtractor
    169 template <typename Evaluator>
    170 auto inline extractFunctors(const Evaluator& evaluator)-> FunctorExtractor<Evaluator> {
    171   return FunctorExtractor<Evaluator>(evaluator);
    172 }
    173 }  // namespace internal
    174 }  // namespace TensorSycl
    175 }  // namespace Eigen
    176 
    177 #endif  // UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXTRACT_FUNCTORS_HPP
    178