Home | History | Annotate | Download | only in PBQP
      1 //===- Math.h - PBQP Vector and Matrix classes ------------------*- C++ -*-===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 
     10 #ifndef LLVM_CODEGEN_PBQP_MATH_H
     11 #define LLVM_CODEGEN_PBQP_MATH_H
     12 
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/STLExtras.h"
#include <algorithm>
#include <cassert>
#include <cstring>
#include <functional>
#include <memory>
     19 
     20 namespace llvm {
     21 namespace PBQP {
     22 
     23 using PBQPNum = float;
     24 
     25 /// \brief PBQP Vector class.
     26 class Vector {
     27   friend hash_code hash_value(const Vector &);
     28 
     29 public:
     30   /// \brief Construct a PBQP vector of the given size.
     31   explicit Vector(unsigned Length)
     32     : Length(Length), Data(llvm::make_unique<PBQPNum []>(Length)) {}
     33 
     34   /// \brief Construct a PBQP vector with initializer.
     35   Vector(unsigned Length, PBQPNum InitVal)
     36     : Length(Length), Data(llvm::make_unique<PBQPNum []>(Length)) {
     37     std::fill(Data.get(), Data.get() + Length, InitVal);
     38   }
     39 
     40   /// \brief Copy construct a PBQP vector.
     41   Vector(const Vector &V)
     42     : Length(V.Length), Data(llvm::make_unique<PBQPNum []>(Length)) {
     43     std::copy(V.Data.get(), V.Data.get() + Length, Data.get());
     44   }
     45 
     46   /// \brief Move construct a PBQP vector.
     47   Vector(Vector &&V)
     48     : Length(V.Length), Data(std::move(V.Data)) {
     49     V.Length = 0;
     50   }
     51 
     52   /// \brief Comparison operator.
     53   bool operator==(const Vector &V) const {
     54     assert(Length != 0 && Data && "Invalid vector");
     55     if (Length != V.Length)
     56       return false;
     57     return std::equal(Data.get(), Data.get() + Length, V.Data.get());
     58   }
     59 
     60   /// \brief Return the length of the vector
     61   unsigned getLength() const {
     62     assert(Length != 0 && Data && "Invalid vector");
     63     return Length;
     64   }
     65 
     66   /// \brief Element access.
     67   PBQPNum& operator[](unsigned Index) {
     68     assert(Length != 0 && Data && "Invalid vector");
     69     assert(Index < Length && "Vector element access out of bounds.");
     70     return Data[Index];
     71   }
     72 
     73   /// \brief Const element access.
     74   const PBQPNum& operator[](unsigned Index) const {
     75     assert(Length != 0 && Data && "Invalid vector");
     76     assert(Index < Length && "Vector element access out of bounds.");
     77     return Data[Index];
     78   }
     79 
     80   /// \brief Add another vector to this one.
     81   Vector& operator+=(const Vector &V) {
     82     assert(Length != 0 && Data && "Invalid vector");
     83     assert(Length == V.Length && "Vector length mismatch.");
     84     std::transform(Data.get(), Data.get() + Length, V.Data.get(), Data.get(),
     85                    std::plus<PBQPNum>());
     86     return *this;
     87   }
     88 
     89   /// \brief Returns the index of the minimum value in this vector
     90   unsigned minIndex() const {
     91     assert(Length != 0 && Data && "Invalid vector");
     92     return std::min_element(Data.get(), Data.get() + Length) - Data.get();
     93   }
     94 
     95 private:
     96   unsigned Length;
     97   std::unique_ptr<PBQPNum []> Data;
     98 };
     99 
    100 /// \brief Return a hash_value for the given vector.
    101 inline hash_code hash_value(const Vector &V) {
    102   unsigned *VBegin = reinterpret_cast<unsigned*>(V.Data.get());
    103   unsigned *VEnd = reinterpret_cast<unsigned*>(V.Data.get() + V.Length);
    104   return hash_combine(V.Length, hash_combine_range(VBegin, VEnd));
    105 }
    106 
    107 /// \brief Output a textual representation of the given vector on the given
    108 ///        output stream.
    109 template <typename OStream>
    110 OStream& operator<<(OStream &OS, const Vector &V) {
    111   assert((V.getLength() != 0) && "Zero-length vector badness.");
    112 
    113   OS << "[ " << V[0];
    114   for (unsigned i = 1; i < V.getLength(); ++i)
    115     OS << ", " << V[i];
    116   OS << " ]";
    117 
    118   return OS;
    119 }
    120 
    121 /// \brief PBQP Matrix class
    122 class Matrix {
    123 private:
    124   friend hash_code hash_value(const Matrix &);
    125 
    126 public:
    127   /// \brief Construct a PBQP Matrix with the given dimensions.
    128   Matrix(unsigned Rows, unsigned Cols) :
    129     Rows(Rows), Cols(Cols), Data(llvm::make_unique<PBQPNum []>(Rows * Cols)) {
    130   }
    131 
    132   /// \brief Construct a PBQP Matrix with the given dimensions and initial
    133   /// value.
    134   Matrix(unsigned Rows, unsigned Cols, PBQPNum InitVal)
    135     : Rows(Rows), Cols(Cols),
    136       Data(llvm::make_unique<PBQPNum []>(Rows * Cols)) {
    137     std::fill(Data.get(), Data.get() + (Rows * Cols), InitVal);
    138   }
    139 
    140   /// \brief Copy construct a PBQP matrix.
    141   Matrix(const Matrix &M)
    142     : Rows(M.Rows), Cols(M.Cols),
    143       Data(llvm::make_unique<PBQPNum []>(Rows * Cols)) {
    144     std::copy(M.Data.get(), M.Data.get() + (Rows * Cols), Data.get());
    145   }
    146 
    147   /// \brief Move construct a PBQP matrix.
    148   Matrix(Matrix &&M)
    149     : Rows(M.Rows), Cols(M.Cols), Data(std::move(M.Data)) {
    150     M.Rows = M.Cols = 0;
    151   }
    152 
    153   /// \brief Comparison operator.
    154   bool operator==(const Matrix &M) const {
    155     assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
    156     if (Rows != M.Rows || Cols != M.Cols)
    157       return false;
    158     return std::equal(Data.get(), Data.get() + (Rows * Cols), M.Data.get());
    159   }
    160 
    161   /// \brief Return the number of rows in this matrix.
    162   unsigned getRows() const {
    163     assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
    164     return Rows;
    165   }
    166 
    167   /// \brief Return the number of cols in this matrix.
    168   unsigned getCols() const {
    169     assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
    170     return Cols;
    171   }
    172 
    173   /// \brief Matrix element access.
    174   PBQPNum* operator[](unsigned R) {
    175     assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
    176     assert(R < Rows && "Row out of bounds.");
    177     return Data.get() + (R * Cols);
    178   }
    179 
    180   /// \brief Matrix element access.
    181   const PBQPNum* operator[](unsigned R) const {
    182     assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
    183     assert(R < Rows && "Row out of bounds.");
    184     return Data.get() + (R * Cols);
    185   }
    186 
    187   /// \brief Returns the given row as a vector.
    188   Vector getRowAsVector(unsigned R) const {
    189     assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
    190     Vector V(Cols);
    191     for (unsigned C = 0; C < Cols; ++C)
    192       V[C] = (*this)[R][C];
    193     return V;
    194   }
    195 
    196   /// \brief Returns the given column as a vector.
    197   Vector getColAsVector(unsigned C) const {
    198     assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
    199     Vector V(Rows);
    200     for (unsigned R = 0; R < Rows; ++R)
    201       V[R] = (*this)[R][C];
    202     return V;
    203   }
    204 
    205   /// \brief Matrix transpose.
    206   Matrix transpose() const {
    207     assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
    208     Matrix M(Cols, Rows);
    209     for (unsigned r = 0; r < Rows; ++r)
    210       for (unsigned c = 0; c < Cols; ++c)
    211         M[c][r] = (*this)[r][c];
    212     return M;
    213   }
    214 
    215   /// \brief Add the given matrix to this one.
    216   Matrix& operator+=(const Matrix &M) {
    217     assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
    218     assert(Rows == M.Rows && Cols == M.Cols &&
    219            "Matrix dimensions mismatch.");
    220     std::transform(Data.get(), Data.get() + (Rows * Cols), M.Data.get(),
    221                    Data.get(), std::plus<PBQPNum>());
    222     return *this;
    223   }
    224 
    225   Matrix operator+(const Matrix &M) {
    226     assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
    227     Matrix Tmp(*this);
    228     Tmp += M;
    229     return Tmp;
    230   }
    231 
    232 private:
    233   unsigned Rows, Cols;
    234   std::unique_ptr<PBQPNum []> Data;
    235 };
    236 
    237 /// \brief Return a hash_code for the given matrix.
    238 inline hash_code hash_value(const Matrix &M) {
    239   unsigned *MBegin = reinterpret_cast<unsigned*>(M.Data.get());
    240   unsigned *MEnd =
    241     reinterpret_cast<unsigned*>(M.Data.get() + (M.Rows * M.Cols));
    242   return hash_combine(M.Rows, M.Cols, hash_combine_range(MBegin, MEnd));
    243 }
    244 
    245 /// \brief Output a textual representation of the given matrix on the given
    246 ///        output stream.
    247 template <typename OStream>
    248 OStream& operator<<(OStream &OS, const Matrix &M) {
    249   assert((M.getRows() != 0) && "Zero-row matrix badness.");
    250   for (unsigned i = 0; i < M.getRows(); ++i)
    251     OS << M.getRowAsVector(i) << "\n";
    252   return OS;
    253 }
    254 
    255 template <typename Metadata>
    256 class MDVector : public Vector {
    257 public:
    258   MDVector(const Vector &v) : Vector(v), md(*this) {}
    259   MDVector(Vector &&v) : Vector(std::move(v)), md(*this) { }
    260 
    261   const Metadata& getMetadata() const { return md; }
    262 
    263 private:
    264   Metadata md;
    265 };
    266 
    267 template <typename Metadata>
    268 inline hash_code hash_value(const MDVector<Metadata> &V) {
    269   return hash_value(static_cast<const Vector&>(V));
    270 }
    271 
    272 template <typename Metadata>
    273 class MDMatrix : public Matrix {
    274 public:
    275   MDMatrix(const Matrix &m) : Matrix(m), md(*this) {}
    276   MDMatrix(Matrix &&m) : Matrix(std::move(m)), md(*this) { }
    277 
    278   const Metadata& getMetadata() const { return md; }
    279 
    280 private:
    281   Metadata md;
    282 };
    283 
    284 template <typename Metadata>
    285 inline hash_code hash_value(const MDMatrix<Metadata> &M) {
    286   return hash_value(static_cast<const Matrix&>(M));
    287 }
    288 
    289 } // end namespace PBQP
    290 } // end namespace llvm
    291 
    292 #endif // LLVM_CODEGEN_PBQP_MATH_H
    293