// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog (at) gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_EVALUATOR_H
#define EIGEN_CXX11_TENSOR_TENSOR_EVALUATOR_H

namespace Eigen {

/** \class TensorEvaluator
  * \ingroup CXX11_Tensor_Module
  *
  * \brief The tensor evaluator classes.
  *
  * These classes are responsible for the evaluation of the tensor expression.
  *
  * TODO: add support for more types of expressions, in particular expressions
  * leading to lvalues (slicing, reshaping, etc...)
  */

// Generic evaluator: used for expressions backed by a materialized buffer of
// coefficients (e.g. Tensor, TensorMap). Provides indexed read/write access
// as well as packet (vectorized) loads/stores directly on that buffer.
template<typename Derived, typename Device>
struct TensorEvaluator
{
  typedef typename Derived::Index Index;
  typedef typename Derived::Scalar Scalar;
  typedef typename Derived::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  typedef typename Derived::Dimensions Dimensions;

  // NumDimensions is -1 for variable dim tensors
  static const int NumCoords = internal::traits<Derived>::NumDimensions > 0 ?
                               internal::traits<Derived>::NumDimensions : 0;

  enum {
    IsAligned = Derived::IsAligned,
    // Vectorized access is possible whenever a device packet holds more than
    // one scalar.
    PacketAccess = (internal::unpacket_traits<PacketReturnType>::size > 1),
    Layout = Derived::Layout,
    // Coordinate-based access requires the rank to be known at compile time.
    CoordAccess = NumCoords > 0,
    // The coefficients live in a raw buffer exposed through data().
    RawAccess = true
  };

  // NOTE: the const_cast makes the underlying buffer writable through
  // coeffRef()/writePacket() even though the expression is taken by const ref.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device& device)
      : m_data(const_cast<typename internal::traits<Derived>::template MakePointer<Scalar>::Type>(m.data())), m_dims(m.dimensions()), m_device(device), m_impl(m)
  { }

  // Used for accessor extraction in SYCL Managed TensorMap:
  const Derived& derived() const { return m_impl; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dims; }

  // If a destination buffer is supplied, bulk-copy the whole tensor into it on
  // the device and return false (no per-coefficient evaluation is needed).
  // Otherwise return true so the caller evaluates through coeff()/packet().
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* dest) {
    if (dest) {
      m_device.memcpy((void*)dest, m_data, sizeof(Scalar) * m_dims.TotalSize());
      return false;
    }
    return true;
  }

  // Nothing to release: the buffer is owned by the underlying expression.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    eigen_assert(m_data);
    return m_data[index];
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
    eigen_assert(m_data);
    return m_data[index];
  }

  // Vectorized load of unpacket_traits<PacketReturnType>::size coefficients
  // starting at `index`. LoadMode encodes the alignment assumption.
  template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  PacketReturnType packet(Index index) const
  {
    return internal::ploadt<PacketReturnType, LoadMode>(m_data + index);
  }

  // Vectorized store of a packet starting at `index`.
  template <int StoreMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  void writePacket(Index index, const PacketReturnType& x)
  {
    return internal::pstoret<Scalar, PacketReturnType, StoreMode>(m_data + index, x);
  }

  // Coordinate-based read; only meaningful when CoordAccess (NumCoords > 0).
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(const array<DenseIndex, NumCoords>& coords) const {
    eigen_assert(m_data);
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      return m_data[m_dims.IndexOfColMajor(coords)];
    } else {
      return m_data[m_dims.IndexOfRowMajor(coords)];
    }
  }

  // Coordinate-based write access; mirrors coeff(coords) above.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(const array<DenseIndex, NumCoords>& coords) {
    eigen_assert(m_data);
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      return m_data[m_dims.IndexOfColMajor(coords)];
    } else {
      return m_data[m_dims.IndexOfRowMajor(coords)];
    }
  }

  // Cost of producing one coefficient: a single load of sizeof(CoeffReturnType)
  // bytes, no compute.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
    return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized,
                        internal::unpacket_traits<PacketReturnType>::size);
  }

  EIGEN_DEVICE_FUNC typename internal::traits<Derived>::template MakePointer<Scalar>::Type data() const { return m_data; }

  /// required by sycl in order to construct sycl buffer from raw pointer
  const Device& device() const{return m_device;}

 protected:
  typename internal::traits<Derived>::template MakePointer<Scalar>::Type m_data;
  Dimensions m_dims;
  const Device& m_device;
  const Derived& m_impl;
};

namespace {
// Plain read of a constant coefficient; specialized below for CUDA so that
// reads go through the read-only (texture) cache.
template <typename T> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T loadConstant(const T* address) {
  return *address;
}
// Use the texture cache on CUDA devices whenever possible
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
template <> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
float loadConstant(const float* address) {
  return __ldg(address);
}
template <> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
double loadConstant(const double* address) {
  return __ldg(address);
}
template <> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
Eigen::half loadConstant(const Eigen::half* address) {
  // __ldg has no half overload here: load the raw 16-bit payload and rebuild.
  return Eigen::half(half_impl::raw_uint16_to_half(__ldg(&address->x)));
}
#endif
}


// Default evaluator for rvalues: same as the generic evaluator but the buffer
// is read-only, and reads go through loadConstant() to benefit from the CUDA
// texture cache when available.
template<typename Derived, typename Device>
struct TensorEvaluator<const Derived, Device>
{
  typedef typename Derived::Index Index;
  typedef typename Derived::Scalar Scalar;
  typedef typename Derived::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  typedef typename Derived::Dimensions Dimensions;

  // NumDimensions is -1 for variable dim tensors
  static const int NumCoords = internal::traits<Derived>::NumDimensions > 0 ?
                               internal::traits<Derived>::NumDimensions : 0;

  enum {
    IsAligned = Derived::IsAligned,
    PacketAccess = (internal::unpacket_traits<PacketReturnType>::size > 1),
    Layout = Derived::Layout,
    CoordAccess = NumCoords > 0,
    RawAccess = true
  };

  // Used for accessor extraction in SYCL Managed TensorMap:
  const Derived& derived() const { return m_impl; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device& device)
      : m_data(m.data()), m_dims(m.dimensions()), m_device(device), m_impl(m)
  { }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dims; }

  // If a destination buffer is supplied and the scalar type is trivially
  // copyable (no initialization required), bulk-copy the tensor into it and
  // return false (nothing left to evaluate). Otherwise return true so the
  // caller evaluates coefficient by coefficient.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) {
    if (!NumTraits<typename internal::remove_const<Scalar>::type>::RequireInitialization && data) {
      m_device.memcpy((void*)data, m_data, m_dims.TotalSize() * sizeof(Scalar));
      return false;
    }
    return true;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    eigen_assert(m_data);
    // loadConstant routes the read through the CUDA texture cache when possible.
    return loadConstant(m_data+index);
  }

  template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  PacketReturnType packet(Index index) const
  {
    // ploadt_ro is the read-only counterpart of ploadt.
    return internal::ploadt_ro<PacketReturnType, LoadMode>(m_data + index);
  }

  // Coordinate-based read; only meaningful when CoordAccess (NumCoords > 0).
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(const array<DenseIndex, NumCoords>& coords) const {
    eigen_assert(m_data);
    const Index index = (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ? m_dims.IndexOfColMajor(coords)
                        : m_dims.IndexOfRowMajor(coords);
    return loadConstant(m_data+index);
  }

  // Cost of producing one coefficient: a single load, no compute.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
    return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized,
                        internal::unpacket_traits<PacketReturnType>::size);
  }

  EIGEN_DEVICE_FUNC typename internal::traits<Derived>::template MakePointer<const Scalar>::Type data() const { return m_data; }

  /// added for sycl in order to construct the buffer from the sycl device
  const Device& device() const{return m_device;}

 protected:
  typename internal::traits<Derived>::template MakePointer<const Scalar>::Type m_data;
  Dimensions m_dims;
  const Device& m_device;
  const Derived& m_impl;
};




// -------------------- CwiseNullaryOp --------------------

// Evaluator for nullary expressions (constants, random generators, ...):
// coefficients are produced by the functor from the flat index alone.
template<typename NullaryOp, typename ArgType, typename Device>
struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType>, Device>
{
  typedef TensorCwiseNullaryOp<NullaryOp, ArgType> XprType;

  enum {
    // Nothing is loaded from memory, so alignment is never an issue.
    IsAligned = true,
    PacketAccess = internal::functor_traits<NullaryOp>::PacketAccess,
    Layout = TensorEvaluator<ArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC
  TensorEvaluator(const XprType& op, const Device& device)
      : m_functor(op.functor()), m_argImpl(op.nestedExpression(), device), m_wrapper()
  { }

  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
  typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;

  // The nested expression only contributes its shape, never its values.
  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_argImpl.dimensions(); }
  // Nothing to materialize: coefficients are generated on the fly.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) { return true; }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { }

  // Produce the coefficient at `index` by invoking the functor through the wrapper.
  EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
  {
    return m_wrapper(m_functor, index);
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
  {
    return m_wrapper.template packetOp<PacketReturnType, Index>(m_functor, index);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized,
                        internal::unpacket_traits<PacketReturnType>::size);
  }

  // No backing buffer (RawAccess is false).
  EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return NULL; }

  /// required by sycl in order to extract the accessor
  const TensorEvaluator<ArgType, Device>& impl() const { return m_argImpl; }
  /// required by sycl in order to extract the accessor
  NullaryOp functor() const { return m_functor; }


 private:
  const NullaryOp m_functor;
  TensorEvaluator<ArgType, Device> m_argImpl;
  const internal::nullary_wrapper<CoeffReturnType,NullaryOp> m_wrapper;
};



// -------------------- CwiseUnaryOp --------------------

// Evaluator for unary expressions: applies the functor to each coefficient
// (or packet) produced by the argument evaluator.
template<typename UnaryOp, typename ArgType, typename Device>
struct TensorEvaluator<const TensorCwiseUnaryOp<UnaryOp, ArgType>, Device>
{
  typedef TensorCwiseUnaryOp<UnaryOp, ArgType> XprType;

  enum {
    IsAligned = TensorEvaluator<ArgType, Device>::IsAligned,
    // Vectorizable only if both the argument and the functor support packets.
    PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess & internal::functor_traits<UnaryOp>::PacketAccess,
    Layout = TensorEvaluator<ArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
    : m_functor(op.functor()),
      m_argImpl(op.nestedExpression(), device)
  { }

  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
  typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;

  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_argImpl.dimensions(); }

  // Let the argument materialize itself if it needs to; this expression itself
  // always evaluates lazily (returns true).
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) {
    m_argImpl.evalSubExprsIfNeeded(NULL);
    return true;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_argImpl.cleanup();
  }

  EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
  {
    return m_functor(m_argImpl.coeff(index));
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
  {
    return m_functor.packetOp(m_argImpl.template packet<LoadMode>(index));
  }

  // Cost = cost of the argument plus the functor's per-coefficient cost.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
    const double functor_cost = internal::functor_traits<UnaryOp>::Cost;
    return m_argImpl.costPerCoeff(vectorized) +
        TensorOpCost(0, 0, functor_cost, vectorized, PacketSize);
  }

  // No backing buffer (RawAccess is false).
  EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return NULL; }

  /// required by sycl in order to extract the accessor
  const TensorEvaluator<ArgType, Device> & impl() const { return m_argImpl; }
  /// added for sycl in order to construct the buffer from sycl device
  UnaryOp functor() const { return m_functor; }


 private:
  const UnaryOp m_functor;
  TensorEvaluator<ArgType, Device> m_argImpl;
};


// -------------------- CwiseBinaryOp --------------------

// Evaluator for binary expressions: combines the coefficients of the two
// argument evaluators through the functor.
template<typename BinaryOp, typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArgType>, Device>
{
  typedef TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArgType> XprType;

  enum {
    IsAligned = TensorEvaluator<LeftArgType, Device>::IsAligned & TensorEvaluator<RightArgType, Device>::IsAligned,
    // Vectorizable only if both arguments and the functor support packets.
    PacketAccess = TensorEvaluator<LeftArgType, Device>::PacketAccess & TensorEvaluator<RightArgType, Device>::PacketAccess &
                   internal::functor_traits<BinaryOp>::PacketAccess,
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
    : m_functor(op.functor()),
      m_leftImpl(op.lhsExpression(), device),
      m_rightImpl(op.rhsExpression(), device)
  {
    // Both sides must use the same storage order (irrelevant for rank <= 1),
    // and their runtime dimensions must agree.
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) == static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout) || internal::traits<XprType>::NumDimensions <= 1), YOU_MADE_A_PROGRAMMING_MISTAKE);
    eigen_assert(dimensions_match(m_leftImpl.dimensions(), m_rightImpl.dimensions()));
  }

  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
  typedef typename TensorEvaluator<LeftArgType, Device>::Dimensions Dimensions;

  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
  {
    // TODO: use right impl instead if right impl dimensions are known at compile time.
    return m_leftImpl.dimensions();
  }

  // Let both arguments materialize themselves if needed; this expression
  // itself always evaluates lazily (returns true).
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) {
    m_leftImpl.evalSubExprsIfNeeded(NULL);
    m_rightImpl.evalSubExprsIfNeeded(NULL);
    return true;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_leftImpl.cleanup();
    m_rightImpl.cleanup();
  }

  EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
  {
    return m_functor(m_leftImpl.coeff(index), m_rightImpl.coeff(index));
  }
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
  {
    return m_functor.packetOp(m_leftImpl.template packet<LoadMode>(index), m_rightImpl.template packet<LoadMode>(index));
  }

  // Cost = cost of both arguments plus the functor's per-coefficient cost.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    const double functor_cost = internal::functor_traits<BinaryOp>::Cost;
    return m_leftImpl.costPerCoeff(vectorized) +
           m_rightImpl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, functor_cost, vectorized, PacketSize);
  }

  // No backing buffer (RawAccess is false).
  EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return NULL; }
  /// required by sycl in order to extract the accessor
  const TensorEvaluator<LeftArgType, Device>& left_impl() const { return m_leftImpl; }
  /// required by sycl in order to extract the accessor
  const TensorEvaluator<RightArgType, Device>& right_impl() const { return m_rightImpl; }
  /// required by sycl in order to extract the accessor
  BinaryOp functor() const { return m_functor; }

 private:
  const BinaryOp m_functor;
  TensorEvaluator<LeftArgType, Device> m_leftImpl;
  TensorEvaluator<RightArgType, Device> m_rightImpl;
};

// -------------------- CwiseTernaryOp --------------------

// Evaluator for ternary expressions: combines the coefficients of the three
// argument evaluators through the functor.
template<typename TernaryOp, typename Arg1Type, typename Arg2Type, typename Arg3Type, typename Device>
struct TensorEvaluator<const TensorCwiseTernaryOp<TernaryOp, Arg1Type, Arg2Type, Arg3Type>, Device>
{
  typedef TensorCwiseTernaryOp<TernaryOp, Arg1Type, Arg2Type, Arg3Type> XprType;

  enum {
    IsAligned = TensorEvaluator<Arg1Type, Device>::IsAligned & TensorEvaluator<Arg2Type, Device>::IsAligned & TensorEvaluator<Arg3Type, Device>::IsAligned,
    // Vectorizable only if all three arguments and the functor support packets.
    PacketAccess = TensorEvaluator<Arg1Type, Device>::PacketAccess & TensorEvaluator<Arg2Type, Device>::PacketAccess & TensorEvaluator<Arg3Type, Device>::PacketAccess &
                   internal::functor_traits<TernaryOp>::PacketAccess,
    Layout = TensorEvaluator<Arg1Type, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
    : m_functor(op.functor()),
      m_arg1Impl(op.arg1Expression(), device),
      m_arg2Impl(op.arg2Expression(), device),
      m_arg3Impl(op.arg3Expression(), device)
  {
    // All arguments must agree on storage order (irrelevant for rank <= 1) as
    // well as on storage kind and index type; runtime dimensions must match.
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<Arg1Type, Device>::Layout) == static_cast<int>(TensorEvaluator<Arg3Type, Device>::Layout) || internal::traits<XprType>::NumDimensions <= 1), YOU_MADE_A_PROGRAMMING_MISTAKE);

    EIGEN_STATIC_ASSERT((internal::is_same<typename internal::traits<Arg1Type>::StorageKind,
                         typename internal::traits<Arg2Type>::StorageKind>::value),
                        STORAGE_KIND_MUST_MATCH)
    EIGEN_STATIC_ASSERT((internal::is_same<typename internal::traits<Arg1Type>::StorageKind,
                         typename internal::traits<Arg3Type>::StorageKind>::value),
                        STORAGE_KIND_MUST_MATCH)
    EIGEN_STATIC_ASSERT((internal::is_same<typename internal::traits<Arg1Type>::Index,
                         typename internal::traits<Arg2Type>::Index>::value),
                        STORAGE_INDEX_MUST_MATCH)
    EIGEN_STATIC_ASSERT((internal::is_same<typename internal::traits<Arg1Type>::Index,
                         typename internal::traits<Arg3Type>::Index>::value),
                        STORAGE_INDEX_MUST_MATCH)

    eigen_assert(dimensions_match(m_arg1Impl.dimensions(), m_arg2Impl.dimensions()) && dimensions_match(m_arg1Impl.dimensions(), m_arg3Impl.dimensions()));
  }

  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
  typedef typename TensorEvaluator<Arg1Type, Device>::Dimensions Dimensions;

  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
  {
    // TODO: use arg2 or arg3 dimensions if they are known at compile time.
    return m_arg1Impl.dimensions();
  }

  // Let all arguments materialize themselves if needed; this expression itself
  // always evaluates lazily (returns true).
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) {
    m_arg1Impl.evalSubExprsIfNeeded(NULL);
    m_arg2Impl.evalSubExprsIfNeeded(NULL);
    m_arg3Impl.evalSubExprsIfNeeded(NULL);
    return true;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_arg1Impl.cleanup();
    m_arg2Impl.cleanup();
    m_arg3Impl.cleanup();
  }

  EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
  {
    return m_functor(m_arg1Impl.coeff(index), m_arg2Impl.coeff(index), m_arg3Impl.coeff(index));
  }
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
  {
    return m_functor.packetOp(m_arg1Impl.template packet<LoadMode>(index),
                              m_arg2Impl.template packet<LoadMode>(index),
                              m_arg3Impl.template packet<LoadMode>(index));
  }

  // Cost = cost of the three arguments plus the functor's per-coefficient cost.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    const double functor_cost = internal::functor_traits<TernaryOp>::Cost;
    return m_arg1Impl.costPerCoeff(vectorized) +
           m_arg2Impl.costPerCoeff(vectorized) +
           m_arg3Impl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, functor_cost, vectorized, PacketSize);
  }

  // No backing buffer (RawAccess is false).
  EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return NULL; }
  /// required by sycl in order to extract the accessor
  const TensorEvaluator<Arg1Type, Device> & arg1Impl() const { return m_arg1Impl; }
  /// required by sycl in order to extract the accessor
  const TensorEvaluator<Arg2Type, Device>& arg2Impl() const { return m_arg2Impl; }
  /// required by sycl in order to extract the accessor
  const TensorEvaluator<Arg3Type, Device>& arg3Impl() const { return m_arg3Impl; }

 private:
  const TernaryOp m_functor;
  TensorEvaluator<Arg1Type, Device> m_arg1Impl;
  TensorEvaluator<Arg2Type, Device> m_arg2Impl;
  TensorEvaluator<Arg3Type, Device> m_arg3Impl;
};


// -------------------- SelectOp --------------------

// Evaluator for select expressions: per coefficient, picks the "then" value
// when the condition coefficient is true, the "else" value otherwise.
template<typename IfArgType, typename ThenArgType, typename ElseArgType, typename Device>
struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>, Device>
{
  typedef TensorSelectOp<IfArgType, ThenArgType, ElseArgType> XprType;
  typedef typename XprType::Scalar Scalar;

  enum {
    IsAligned = TensorEvaluator<ThenArgType, Device>::IsAligned & TensorEvaluator<ElseArgType, Device>::IsAligned,
    // Packet evaluation relies on pblend, so the scalar type must support it.
    PacketAccess = TensorEvaluator<ThenArgType, Device>::PacketAccess & TensorEvaluator<ElseArgType, Device>::PacketAccess &
                   internal::packet_traits<Scalar>::HasBlend,
    Layout = TensorEvaluator<IfArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
    : m_condImpl(op.ifExpression(), device),
      m_thenImpl(op.thenExpression(), device),
      m_elseImpl(op.elseExpression(), device)
  {
    // All three sub-expressions must share the same storage order and shape.
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<IfArgType, Device>::Layout) == static_cast<int>(TensorEvaluator<ThenArgType, Device>::Layout)), YOU_MADE_A_PROGRAMMING_MISTAKE);
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<IfArgType, Device>::Layout) == static_cast<int>(TensorEvaluator<ElseArgType, Device>::Layout)), YOU_MADE_A_PROGRAMMING_MISTAKE);
    eigen_assert(dimensions_match(m_condImpl.dimensions(), m_thenImpl.dimensions()));
    eigen_assert(dimensions_match(m_thenImpl.dimensions(), m_elseImpl.dimensions()));
  }

  typedef typename XprType::Index Index;
  typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
  typedef typename TensorEvaluator<IfArgType, Device>::Dimensions Dimensions;

  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
  {
    // TODO: use then or else impl instead if they happen to be known at compile time.
    return m_condImpl.dimensions();
  }

  // Let all three sub-expressions materialize themselves if needed; this
  // expression itself always evaluates lazily (returns true).
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) {
    m_condImpl.evalSubExprsIfNeeded(NULL);
    m_thenImpl.evalSubExprsIfNeeded(NULL);
    m_elseImpl.evalSubExprsIfNeeded(NULL);
    return true;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_condImpl.cleanup();
    m_thenImpl.cleanup();
    m_elseImpl.cleanup();
  }

  EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
  {
    return m_condImpl.coeff(index) ? m_thenImpl.coeff(index) : m_elseImpl.coeff(index);
  }
  // Packet evaluation: gather the condition coefficients one by one into a
  // selector mask, then blend the "then" and "else" packets with it.
  template<int LoadMode>
  EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const
  {
    internal::Selector<PacketSize> select;
    for (Index i = 0; i < PacketSize; ++i) {
      select.select[i] = m_condImpl.coeff(index+i);
    }
    return internal::pblend(select,
                            m_thenImpl.template packet<LoadMode>(index),
                            m_elseImpl.template packet<LoadMode>(index));
  }

  // Cost: the condition is always read; of the two branches only one value is
  // used per coefficient, so charge the more expensive of the two.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    return m_condImpl.costPerCoeff(vectorized) +
           m_thenImpl.costPerCoeff(vectorized)
               .cwiseMax(m_elseImpl.costPerCoeff(vectorized));
  }

  // No backing buffer (RawAccess is false).
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType* data() const { return NULL; }
  /// required by sycl in order to extract the accessor
  const TensorEvaluator<IfArgType, Device> & cond_impl() const { return m_condImpl; }
  /// required by sycl in order to extract the accessor
  const TensorEvaluator<ThenArgType, Device>& then_impl() const { return m_thenImpl; }
  /// required by sycl in order to extract the accessor
  const TensorEvaluator<ElseArgType, Device>& else_impl() const { return m_elseImpl; }

 private:
  TensorEvaluator<IfArgType, Device> m_condImpl;
  TensorEvaluator<ThenArgType, Device> m_thenImpl;
  TensorEvaluator<ElseArgType, Device> m_elseImpl;
};


} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_EVALUATOR_H