Home | History | Annotate | Download | only in Tensor
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog (at) gmail.com>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H
     11 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H
     12 
     13 
     14 namespace Eigen {
     15 namespace internal {
     16 
     // Sharding modes accepted as the ShardingType template argument of
// TensorContractionBlocking (which defaults to ShardByCol).
enum {
  ShardByRow = 0,  // blocking sizes computed with the m and n extents swapped
  ShardByCol = 1   // blocking sizes computed with (k, m, n) in natural order
};
     21 
     22 
     23 // Default Blocking Strategy
     24 template <typename LhsMapper, typename RhsMapper, typename Index, int ShardingType=ShardByCol>
     25 class TensorContractionBlocking {
     26  public:
     27 
     28   typedef typename LhsMapper::Scalar LhsScalar;
     29   typedef typename RhsMapper::Scalar RhsScalar;
     30 
     31   EIGEN_DEVICE_FUNC TensorContractionBlocking(Index k, Index m, Index n, Index num_threads = 1) :
     32       kc_(k), mc_(m), nc_(n)
     33   {
     34     if (ShardingType == ShardByCol) {
     35       computeProductBlockingSizes<LhsScalar, RhsScalar, 1>(kc_, mc_, nc_, num_threads);
     36     }
     37     else {
     38       computeProductBlockingSizes<LhsScalar, RhsScalar, 1>(kc_, nc_, mc_, num_threads);
     39     }
     40   }
     41 
     42   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index kc() const { return kc_; }
     43   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index mc() const { return mc_; }
     44   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index nc() const { return nc_; }
     45 
     46  private:
     47   Index kc_;
     48   Index mc_;
     49   Index nc_;
     50 };
     51 
     52 
     53 } // end namespace internal
     54 } // end namespace Eigen
     55 
     56 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H
     57