Home | History | Annotate | Download | only in Tensor
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog (at) gmail.com>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_IO_H
     11 #define EIGEN_CXX11_TENSOR_TENSOR_IO_H
     12 
     13 namespace Eigen {
     14 
     15 namespace internal {
     16 
     17 // Print the tensor as a 2d matrix
     18 template <typename Tensor, int Rank>
     19 struct TensorPrinter {
     20   static void run (std::ostream& os, const Tensor& tensor) {
     21     typedef typename internal::remove_const<typename Tensor::Scalar>::type Scalar;
     22     typedef typename Tensor::Index Index;
     23     const Index total_size = internal::array_prod(tensor.dimensions());
     24     if (total_size > 0) {
     25       const Index first_dim = Eigen::internal::array_get<0>(tensor.dimensions());
     26       static const int layout = Tensor::Layout;
     27       Map<const Array<Scalar, Dynamic, Dynamic, layout> > matrix(const_cast<Scalar*>(tensor.data()), first_dim, total_size/first_dim);
     28       os << matrix;
     29     }
     30   }
     31 };
     32 
     33 
     34 // Print the tensor as a vector
     35 template <typename Tensor>
     36 struct TensorPrinter<Tensor, 1> {
     37   static void run (std::ostream& os, const Tensor& tensor) {
     38     typedef typename internal::remove_const<typename Tensor::Scalar>::type Scalar;
     39     typedef typename Tensor::Index Index;
     40     const Index total_size = internal::array_prod(tensor.dimensions());
     41     if (total_size > 0) {
     42       Map<const Array<Scalar, Dynamic, 1> > array(const_cast<Scalar*>(tensor.data()), total_size);
     43       os << array;
     44     }
     45   }
     46 };
     47 
     48 
     49 // Print the tensor as a scalar
     50 template <typename Tensor>
     51 struct TensorPrinter<Tensor, 0> {
     52   static void run (std::ostream& os, const Tensor& tensor) {
     53     os << tensor.coeff(0);
     54   }
     55 };
     56 }
     57 
     58 template <typename T>
     59 std::ostream& operator << (std::ostream& os, const TensorBase<T, ReadOnlyAccessors>& expr) {
     60   typedef TensorEvaluator<const TensorForcedEvalOp<const T>, DefaultDevice> Evaluator;
     61   typedef typename Evaluator::Dimensions Dimensions;
     62 
     63   // Evaluate the expression if needed
     64   TensorForcedEvalOp<const T> eval = expr.eval();
     65   Evaluator tensor(eval, DefaultDevice());
     66   tensor.evalSubExprsIfNeeded(NULL);
     67 
     68   // Print the result
     69   static const int rank = internal::array_size<Dimensions>::value;
     70   internal::TensorPrinter<Evaluator, rank>::run(os, tensor);
     71 
     72   // Cleanup.
     73   tensor.cleanup();
     74   return os;
     75 }
     76 
     77 } // end namespace Eigen
     78 
     79 #endif // EIGEN_CXX11_TENSOR_TENSOR_IO_H
     80