1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog (at) gmail.com> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H 11 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H 12 13 namespace Eigen { 14 15 /** \class TensorContraction 16 * \ingroup CXX11_Tensor_Module 17 * 18 * \brief Tensor contraction class. 19 * 20 * 21 */ 22 namespace internal { 23 24 template<typename Dimensions, typename LhsXprType, typename RhsXprType> 25 struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> > 26 { 27 // Type promotion to handle the case where the types of the lhs and the rhs are different. 28 typedef typename gebp_traits<typename remove_const<typename LhsXprType::Scalar>::type, 29 typename remove_const<typename RhsXprType::Scalar>::type>::ResScalar Scalar; 30 31 typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind, 32 typename traits<RhsXprType>::StorageKind>::ret StorageKind; 33 typedef typename promote_index_type<typename traits<LhsXprType>::Index, 34 typename traits<RhsXprType>::Index>::type Index; 35 typedef typename LhsXprType::Nested LhsNested; 36 typedef typename RhsXprType::Nested RhsNested; 37 typedef typename remove_reference<LhsNested>::type _LhsNested; 38 typedef typename remove_reference<RhsNested>::type _RhsNested; 39 40 // From NumDims below. 41 static const int NumDimensions = traits<RhsXprType>::NumDimensions + traits<RhsXprType>::NumDimensions - 2 * array_size<Dimensions>::value; 42 static const int Layout = traits<LhsXprType>::Layout; 43 44 enum { 45 Flags = 0 46 }; 47 }; 48 49 template<typename Dimensions, typename LhsXprType, typename RhsXprType> 50 struct eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, Eigen::Dense> 51 { 52 typedef const TensorContractionOp<Dimensions, LhsXprType, RhsXprType>& type; 53 }; 54 55 template<typename Dimensions, typename LhsXprType, typename RhsXprType> 56 struct nested<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, 1, typename eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >::type> 57 { 58 typedef TensorContractionOp<Dimensions, LhsXprType, RhsXprType> type; 59 }; 60 61 template<typename Indices_, typename LeftArgType_, typename RightArgType_, typename Device_> 62 struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, RightArgType_>, Device_> > { 63 typedef Indices_ Indices; 64 typedef LeftArgType_ LeftArgType; 65 typedef RightArgType_ RightArgType; 66 typedef Device_ Device; 67 68 // From NumDims below. 69 static const int NumDimensions = traits<LeftArgType_>::NumDimensions + traits<RightArgType_>::NumDimensions - 2 * array_size<Indices_>::value; 70 }; 71 72 } // end namespace internal 73 74 template<typename Indices, typename LhsXprType, typename RhsXprType> 75 class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType>, ReadOnlyAccessors> 76 { 77 public: 78 typedef typename Eigen::internal::traits<TensorContractionOp>::Scalar Scalar; 79 typedef typename internal::gebp_traits<typename LhsXprType::CoeffReturnType, 80 typename RhsXprType::CoeffReturnType>::ResScalar CoeffReturnType; 81 typedef typename Eigen::internal::nested<TensorContractionOp>::type Nested; 82 typedef typename Eigen::internal::traits<TensorContractionOp>::StorageKind StorageKind; 83 typedef typename Eigen::internal::traits<TensorContractionOp>::Index Index; 84 85 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp( 86 const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims) 87 : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims) {} 88 89 EIGEN_DEVICE_FUNC 90 const Indices& indices() const { return m_indices; } 91 92 /** \returns the nested expressions */ 93 EIGEN_DEVICE_FUNC 94 const typename internal::remove_all<typename LhsXprType::Nested>::type& 95 lhsExpression() const { return m_lhs_xpr; } 96 97 EIGEN_DEVICE_FUNC 98 const typename internal::remove_all<typename RhsXprType::Nested>::type& 99 rhsExpression() const { return m_rhs_xpr; } 100 101 protected: 102 typename LhsXprType::Nested m_lhs_xpr; 103 typename RhsXprType::Nested m_rhs_xpr; 104 const Indices m_indices; 105 }; 106 107 108 template<typename Derived> 109 struct TensorContractionEvaluatorBase 110 { 111 typedef typename internal::traits<Derived>::Indices Indices; 112 typedef typename internal::traits<Derived>::LeftArgType LeftArgType; 113 typedef typename internal::traits<Derived>::RightArgType RightArgType; 114 typedef typename internal::traits<Derived>::Device Device; 115 116 typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType; 117 typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar; 118 typedef typename XprType::Index Index; 119 typedef typename XprType::CoeffReturnType CoeffReturnType; 120 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType; 121 122 enum { 123 IsAligned = true, 124 PacketAccess = (internal::unpacket_traits<PacketReturnType>::size > 1), 125 Layout = TensorEvaluator<LeftArgType, Device>::Layout, 126 CoordAccess = false, // to be implemented 127 RawAccess = true 128 }; 129 130 // Most of the code is assuming that both input tensors are ColMajor. If the 131 // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS: 132 // If we want to compute A * B = C, where A is LHS and B is RHS, the code 133 // will pretend B is LHS and A is RHS. 134 typedef typename internal::conditional< 135 static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType; 136 typedef typename internal::conditional< 137 static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType; 138 139 static const int LDims = 140 internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value; 141 static const int RDims = 142 internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value; 143 static const int ContractDims = internal::array_size<Indices>::value; 144 static const int NumDims = LDims + RDims - 2 * ContractDims; 145 146 typedef array<Index, ContractDims> contract_t; 147 typedef array<Index, LDims - ContractDims> left_nocontract_t; 148 typedef array<Index, RDims - ContractDims> right_nocontract_t; 149 150 typedef DSizes<Index, NumDims> Dimensions; 151 152 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE 153 TensorContractionEvaluatorBase(const XprType& op, const Device& device) 154 : m_leftImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(), 155 op.lhsExpression(), op.rhsExpression()), device), 156 m_rightImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(), 157 op.rhsExpression(), op.lhsExpression()), device), 158 m_device(device), 159 m_result(NULL) { 160 EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) == 161 static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout)), 162 YOU_MADE_A_PROGRAMMING_MISTAKE); 163 164 165 DSizes<Index, LDims> eval_left_dims; 166 DSizes<Index, RDims> eval_right_dims; 167 array<IndexPair<Index>, ContractDims> eval_op_indices; 168 if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) { 169 // For ColMajor, we keep using the existing dimensions 170 for (int i = 0; i < LDims; i++) { 171 eval_left_dims[i] = m_leftImpl.dimensions()[i]; 172 } 173 for (int i = 0; i < RDims; i++) { 174 eval_right_dims[i] = m_rightImpl.dimensions()[i]; 175 } 176 // We keep the pairs of contracting indices. 177 for (int i = 0; i < ContractDims; i++) { 178 eval_op_indices[i].first = op.indices()[i].first; 179 eval_op_indices[i].second = op.indices()[i].second; 180 } 181 } else { 182 // For RowMajor, we need to reverse the existing dimensions 183 for (int i = 0; i < LDims; i++) { 184 eval_left_dims[i] = m_leftImpl.dimensions()[LDims - i - 1]; 185 } 186 for (int i = 0; i < RDims; i++) { 187 eval_right_dims[i] = m_rightImpl.dimensions()[RDims - i - 1]; 188 } 189 // We need to flip all the pairs of contracting indices as well as 190 // reversing the dimensions. 191 for (int i = 0; i < ContractDims; i++) { 192 eval_op_indices[i].first = LDims - 1 - op.indices()[ContractDims - 1 - i].second; 193 eval_op_indices[i].second = RDims - 1 - op.indices()[ContractDims - 1 - i].first; 194 } 195 } 196 197 // Check for duplicate axes and make sure the first index in eval_op_indices 198 // is increasing. Using O(n^2) sorting is OK since ContractDims is small 199 for (int i = 0; i < ContractDims; i++) { 200 for (int j = i + 1; j < ContractDims; j++) { 201 eigen_assert(eval_op_indices[j].first != eval_op_indices[i].first && 202 eval_op_indices[j].second != eval_op_indices[i].second && 203 "contraction axes should be unique"); 204 if (eval_op_indices[j].first < eval_op_indices[i].first) { 205 numext::swap(eval_op_indices[j], eval_op_indices[i]); 206 } 207 } 208 } 209 210 array<Index, LDims> lhs_strides; 211 lhs_strides[0] = 1; 212 for (int i = 0; i < LDims-1; ++i) { 213 lhs_strides[i+1] = lhs_strides[i] * eval_left_dims[i]; 214 } 215 216 array<Index, RDims> rhs_strides; 217 rhs_strides[0] = 1; 218 for (int i = 0; i < RDims-1; ++i) { 219 rhs_strides[i+1] = rhs_strides[i] * eval_right_dims[i]; 220 } 221 222 if (m_i_strides.size() > 0) m_i_strides[0] = 1; 223 if (m_j_strides.size() > 0) m_j_strides[0] = 1; 224 if (m_k_strides.size() > 0) m_k_strides[0] = 1; 225 226 m_i_size = 1; 227 m_j_size = 1; 228 m_k_size = 1; 229 230 // To compute the dimension, we simply concatenate the non-contracting 231 // dimensions of the left and then the right tensor. Additionally, we also 232 // compute the strides corresponding to the left non-contracting 233 // dimensions and right non-contracting dimensions. 234 m_lhs_inner_dim_contiguous = true; 235 int dim_idx = 0; 236 unsigned int nocontract_idx = 0; 237 238 for (int i = 0; i < LDims; i++) { 239 // find if we are contracting on index i of left tensor 240 bool contracting = false; 241 for (int j = 0; j < ContractDims; j++) { 242 if (eval_op_indices[j].first == i) { 243 contracting = true; 244 break; 245 } 246 } 247 if (!contracting) { 248 // add dimension size to output dimensions 249 m_dimensions[dim_idx] = eval_left_dims[i]; 250 m_left_nocontract_strides[nocontract_idx] = lhs_strides[i]; 251 if (dim_idx != i) { 252 m_lhs_inner_dim_contiguous = false; 253 } 254 if (nocontract_idx+1 < internal::array_size<left_nocontract_t>::value) { 255 m_i_strides[nocontract_idx+1] = 256 m_i_strides[nocontract_idx] * eval_left_dims[i]; 257 } else { 258 m_i_size = m_i_strides[nocontract_idx] * eval_left_dims[i]; 259 } 260 dim_idx++; 261 nocontract_idx++; 262 } 263 } 264 265 nocontract_idx = 0; 266 for (int i = 0; i < RDims; i++) { 267 bool contracting = false; 268 // find if we are contracting on index i of right tensor 269 for (int j = 0; j < ContractDims; j++) { 270 if (eval_op_indices[j].second == i) { 271 contracting = true; 272 break; 273 } 274 } 275 if (!contracting) { 276 m_dimensions[dim_idx] = eval_right_dims[i]; 277 if (nocontract_idx+1 < internal::array_size<right_nocontract_t>::value) { 278 m_j_strides[nocontract_idx+1] = 279 m_j_strides[nocontract_idx] * eval_right_dims[i]; 280 } else { 281 m_j_size = m_j_strides[nocontract_idx] * eval_right_dims[i]; 282 } 283 m_right_nocontract_strides[nocontract_idx] = rhs_strides[i]; 284 dim_idx++; 285 nocontract_idx++; 286 } 287 } 288 289 // Now compute the strides corresponding to the contracting dimensions. We 290 // assumed above that non-contracting axes are represented in the same order 291 // in the matrix as they are in the tensor. This is not the case for 292 // contracting axes. As the contracting axes must be of the same size in 293 // each tensor, we'll only look at the first tensor here. 294 m_rhs_inner_dim_contiguous = true; 295 m_rhs_inner_dim_reordered = false; 296 for (int i = 0; i < ContractDims; i++) { 297 Index left = eval_op_indices[i].first; 298 Index right = eval_op_indices[i].second; 299 300 Index size = eval_left_dims[left]; 301 eigen_assert(size == eval_right_dims[right] && 302 "Contraction axes must be same size"); 303 304 if (i+1 < static_cast<int>(internal::array_size<contract_t>::value)) { 305 m_k_strides[i+1] = m_k_strides[i] * size; 306 } else { 307 m_k_size = m_k_strides[i] * size; 308 } 309 m_left_contracting_strides[i] = lhs_strides[left]; 310 m_right_contracting_strides[i] = rhs_strides[right]; 311 312 if (i > 0 && right < eval_op_indices[i-1].second) { 313 m_rhs_inner_dim_reordered = true; 314 } 315 if (right != i) { 316 m_rhs_inner_dim_contiguous = false; 317 } 318 } 319 320 // If the layout is RowMajor, we need to reverse the m_dimensions 321 if (static_cast<int>(Layout) == static_cast<int>(RowMajor)) { 322 for (int i = 0, j = NumDims - 1; i < j; i++, j--) { 323 numext::swap(m_dimensions[i], m_dimensions[j]); 324 } 325 } 326 } 327 328 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } 329 330 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) { 331 m_leftImpl.evalSubExprsIfNeeded(NULL); 332 m_rightImpl.evalSubExprsIfNeeded(NULL); 333 if (data) { 334 evalTo(data); 335 return false; 336 } else { 337 m_result = static_cast<Scalar *>(m_device.allocate(dimensions().TotalSize() * sizeof(Scalar))); 338 evalTo(m_result); 339 return true; 340 } 341 } 342 343 EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const { 344 if (this->m_lhs_inner_dim_contiguous) { 345 if (this->m_rhs_inner_dim_contiguous) { 346 if (this->m_rhs_inner_dim_reordered) { 347 static_cast<const Derived*>(this)->template evalProduct<true, true, true, Unaligned>(buffer); 348 } 349 else { 350 static_cast<const Derived*>(this)->template evalProduct<true, true, false, Unaligned>(buffer); 351 } 352 } 353 else { 354 if (this->m_rhs_inner_dim_reordered) { 355 static_cast<const Derived*>(this)->template evalProduct<true, false, true, Unaligned>(buffer); 356 } 357 else { 358 static_cast<const Derived*>(this)->template evalProduct<true, false, false, Unaligned>(buffer); 359 } 360 } 361 } 362 else { 363 if (this->m_rhs_inner_dim_contiguous) { 364 if (this->m_rhs_inner_dim_reordered) { 365 static_cast<const Derived*>(this)->template evalProduct<false, true, true, Unaligned>(buffer); 366 } 367 else { 368 static_cast<const Derived*>(this)->template evalProduct<false, true, false, Unaligned>(buffer); 369 } 370 } 371 else { 372 if (this->m_rhs_inner_dim_reordered) { 373 static_cast<const Derived*>(this)->template evalProduct<false, false, true, Unaligned>(buffer); 374 } 375 else { 376 static_cast<const Derived*>(this)->template evalProduct<false, false, false, Unaligned>(buffer); 377 } 378 } 379 } 380 } 381 382 template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment> 383 EIGEN_DEVICE_FUNC void evalGemv(Scalar* buffer) const { 384 const Index rows = m_i_size; 385 const Index cols = m_k_size; 386 387 typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar; 388 typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar; 389 typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator; 390 typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator; 391 const Index lhs_packet_size = internal::unpacket_traits<typename LeftEvaluator::PacketReturnType>::size; 392 const Index rhs_packet_size = internal::unpacket_traits<typename RightEvaluator::PacketReturnType>::size; 393 const int lhs_alignment = LeftEvaluator::IsAligned ? Aligned : Unaligned; 394 const int rhs_alignment = RightEvaluator::IsAligned ? Aligned : Unaligned; 395 typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs, 396 LeftEvaluator, left_nocontract_t, 397 contract_t, lhs_packet_size, 398 lhs_inner_dim_contiguous, 399 false, lhs_alignment> LhsMapper; 400 401 typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs, 402 RightEvaluator, right_nocontract_t, 403 contract_t, rhs_packet_size, 404 rhs_inner_dim_contiguous, 405 rhs_inner_dim_reordered, rhs_alignment> RhsMapper; 406 407 LhsMapper lhs(m_leftImpl, m_left_nocontract_strides, m_i_strides, 408 m_left_contracting_strides, m_k_strides); 409 RhsMapper rhs(m_rightImpl, m_right_nocontract_strides, m_j_strides, 410 m_right_contracting_strides, m_k_strides); 411 412 const Scalar alpha(1); 413 const Index resIncr(1); 414 415 // zero out the result buffer (which must be of size at least rows * sizeof(Scalar) 416 m_device.memset(buffer, 0, rows * sizeof(Scalar)); 417 418 internal::general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,false,RhsScalar,RhsMapper,false>::run( 419 rows, cols, lhs, rhs, 420 buffer, resIncr, alpha); 421 } 422 423 template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment> 424 EIGEN_DEVICE_FUNC void evalGemm(Scalar* buffer) const { 425 // columns in left side, rows in right side 426 const Index k = this->m_k_size; 427 428 // rows in left side 429 const Index m = this->m_i_size; 430 431 // columns in right side 432 const Index n = this->m_j_size; 433 434 // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar) 435 this->m_device.memset(buffer, 0, m * n * sizeof(Scalar)); 436 437 // define mr, nr, and all of my data mapper types 438 typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar; 439 typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar; 440 typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits; 441 442 const Index nr = Traits::nr; 443 const Index mr = Traits::mr; 444 445 typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator; 446 typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator; 447 448 const Index lhs_packet_size = internal::unpacket_traits<typename LeftEvaluator::PacketReturnType>::size; 449 const Index rhs_packet_size = internal::unpacket_traits<typename RightEvaluator::PacketReturnType>::size; 450 451 typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs, 452 LeftEvaluator, left_nocontract_t, 453 contract_t, lhs_packet_size, 454 lhs_inner_dim_contiguous, 455 false, Unaligned> LhsMapper; 456 457 typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs, 458 RightEvaluator, right_nocontract_t, 459 contract_t, rhs_packet_size, 460 rhs_inner_dim_contiguous, 461 rhs_inner_dim_reordered, Unaligned> RhsMapper; 462 463 typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper; 464 465 // Declare GEBP packing and kernel structs 466 internal::gemm_pack_lhs<LhsScalar, Index, typename LhsMapper::SubMapper, mr, Traits::LhsProgress, ColMajor> pack_lhs; 467 internal::gemm_pack_rhs<RhsScalar, Index, typename RhsMapper::SubMapper, nr, ColMajor> pack_rhs; 468 469 internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper, mr, nr, false, false> gebp; 470 471 // initialize data mappers 472 LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides, 473 this->m_left_contracting_strides, this->m_k_strides); 474 475 RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides, 476 this->m_right_contracting_strides, this->m_k_strides); 477 478 OutputMapper output(buffer, m); 479 480 // Sizes of the blocks to load in cache. See the Goto paper for details. 481 internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index, internal::ShardByCol> blocking(k, m, n, 1); 482 const Index kc = blocking.kc(); 483 const Index mc = numext::mini(m, blocking.mc()); 484 const Index nc = numext::mini(n, blocking.nc()); 485 const Index sizeA = mc * kc; 486 const Index sizeB = kc * nc; 487 488 LhsScalar* blockA = static_cast<LhsScalar *>(this->m_device.allocate(sizeA * sizeof(LhsScalar))); 489 RhsScalar* blockB = static_cast<RhsScalar *>(this->m_device.allocate(sizeB * sizeof(RhsScalar))); 490 491 for(Index i2=0; i2<m; i2+=mc) 492 { 493 const Index actual_mc = numext::mini(i2+mc,m)-i2; 494 for (Index k2 = 0; k2 < k; k2 += kc) { 495 // make sure we don't overshoot right edge of left matrix, then pack vertical panel 496 const Index actual_kc = numext::mini(k2 + kc, k) - k2; 497 pack_lhs(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc, 0, 0); 498 499 // series of horizontal blocks 500 for (Index j2 = 0; j2 < n; j2 += nc) { 501 // make sure we don't overshoot right edge of right matrix, then pack block 502 const Index actual_nc = numext::mini(j2 + nc, n) - j2; 503 pack_rhs(blockB, rhs.getSubMapper(k2, j2), actual_kc, actual_nc, 0, 0); 504 505 // call gebp (matrix kernel) 506 // The parameters here are copied from Eigen's GEMM implementation 507 gebp(output.getSubMapper(i2, j2), blockA, blockB, actual_mc, actual_kc, actual_nc, Scalar(1), -1, -1, 0, 0); 508 } 509 } 510 } 511 512 this->m_device.deallocate(blockA); 513 this->m_device.deallocate(blockB); 514 } 515 516 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { 517 m_leftImpl.cleanup(); 518 m_rightImpl.cleanup(); 519 520 if (m_result != NULL) { 521 m_device.deallocate(m_result); 522 m_result = NULL; 523 } 524 } 525 526 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { 527 return m_result[index]; 528 } 529 530 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool) const { 531 return TensorOpCost(sizeof(CoeffReturnType), 0, 0); 532 } 533 534 template<int LoadMode> 535 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const { 536 return internal::ploadt<PacketReturnType, LoadMode>(m_result + index); 537 } 538 539 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar* data() const { return m_result; } 540 541 protected: 542 // Prevent assignment 543 TensorContractionEvaluatorBase& operator = (const TensorContractionEvaluatorBase&); 544 Dimensions m_dimensions; 545 546 contract_t m_k_strides; 547 contract_t m_left_contracting_strides; 548 contract_t m_right_contracting_strides; 549 550 bool m_lhs_inner_dim_contiguous; 551 bool m_rhs_inner_dim_contiguous; 552 bool m_rhs_inner_dim_reordered; 553 554 left_nocontract_t m_i_strides; 555 right_nocontract_t m_j_strides; 556 left_nocontract_t m_left_nocontract_strides; 557 right_nocontract_t m_right_nocontract_strides; 558 559 Index m_i_size; 560 Index m_j_size; 561 Index m_k_size; 562 563 TensorEvaluator<EvalLeftArgType, Device> m_leftImpl; 564 TensorEvaluator<EvalRightArgType, Device> m_rightImpl; 565 const Device& m_device; 566 Scalar* m_result; 567 }; 568 569 570 // evaluator for default device 571 template<typename Indices, typename LeftArgType, typename RightArgType, typename Device> 572 struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> : 573 public TensorContractionEvaluatorBase< 574 TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> > { 575 typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self; 576 typedef TensorContractionEvaluatorBase<Self> Base; 577 578 typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType; 579 typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar; 580 typedef typename XprType::Index Index; 581 typedef typename XprType::CoeffReturnType CoeffReturnType; 582 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType; 583 584 enum { 585 Layout = TensorEvaluator<LeftArgType, Device>::Layout 586 }; 587 588 // Most of the code is assuming that both input tensors are ColMajor. If the 589 // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS: 590 // If we want to compute A * B = C, where A is LHS and B is RHS, the code 591 // will pretend B is LHS and A is RHS. 592 typedef typename internal::conditional< 593 static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType; 594 typedef typename internal::conditional< 595 static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType; 596 597 static const int LDims = 598 internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value; 599 static const int RDims = 600 internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value; 601 static const int ContractDims = internal::array_size<Indices>::value; 602 603 typedef array<Index, ContractDims> contract_t; 604 typedef array<Index, LDims - ContractDims> left_nocontract_t; 605 typedef array<Index, RDims - ContractDims> right_nocontract_t; 606 607 static const int NumDims = LDims + RDims - 2 * ContractDims; 608 609 // Could we use NumDimensions here? 610 typedef DSizes<Index, NumDims> Dimensions; 611 612 EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) : 613 Base(op, device) { } 614 615 template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment> 616 EIGEN_DEVICE_FUNC void evalProduct(Scalar* buffer) const { 617 if (this->m_j_size == 1) { 618 this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer); 619 return; 620 } 621 622 this->template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer); 623 } 624 }; 625 626 } // end namespace Eigen 627 628 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H 629