Home | History | Annotate | Download | only in public
      1 // Copyright 2015 Google Inc. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 // map.h: a minimalist view-existing-buffer-as-a-matrix class,
     16 // which is how gemmlowp interfaces with external matrix data.
     17 
     18 #ifndef GEMMLOWP_PUBLIC_MAP_H_
     19 #define GEMMLOWP_PUBLIC_MAP_H_
     20 
#include <cstddef>

#include "../internal/common.h"
#include "../internal/iterator.h"
     23 
     24 namespace gemmlowp {
     25 
     26 // The two storage orders allowed to map buffers as matrices: ColMajor
     27 // means column-major, RowMajor means row-major.
     28 enum class MapOrder { ColMajor, RowMajor };
     29 
     30 // A MatrixMap is a view of an existing buffer as a matrix. It does not own
     31 // the buffer.
     32 template <typename tScalar, MapOrder tOrder>
     33 class MatrixMap {
     34  public:
     35   typedef tScalar Scalar;
     36   static const MapOrder kOrder = tOrder;
     37 
     38  protected:
     39   Scalar* data_;  // not owned.
     40   int rows_, cols_, stride_;
     41 
     42  public:
     43   MatrixMap() : data_(nullptr), rows_(0), cols_(0), stride_(0) {}
     44   MatrixMap(Scalar* data, int rows, int cols, int stride)
     45       : data_(data), rows_(rows), cols_(cols), stride_(stride) {}
     46   MatrixMap(const MatrixMap& other)
     47       : data_(other.data_),
     48         rows_(other.rows_),
     49         cols_(other.cols_),
     50         stride_(other.stride_) {}
     51 
     52   int rows() const { return rows_; }
     53   int cols() const { return cols_; }
     54   int stride() const { return stride_; }
     55   int rows_stride() const { return kOrder == MapOrder::ColMajor ? 1 : stride_; }
     56   int cols_stride() const { return kOrder == MapOrder::RowMajor ? 1 : stride_; }
     57   Scalar* data() const { return data_; }
     58   Scalar* data(int row, int col) const {
     59     return data_ + row * rows_stride() + col * cols_stride();
     60   }
     61   Scalar& operator()(int row, int col) const { return *data(row, col); }
     62 
     63   MatrixMap block(int start_row, int start_col, int block_rows,
     64                   int block_cols) const {
     65     assert(start_row >= 0);
     66     assert(start_row + block_rows <= rows_);
     67     assert(start_col >= 0);
     68     assert(start_col + block_cols <= cols_);
     69 
     70     return MatrixMap(data(start_row, start_col), block_rows, block_cols,
     71                      stride_);
     72   }
     73 };
     74 
     75 enum class VectorShape { Col, Row };
     76 
     77 // A VectorMap is a view of an existing buffer as a vector. It does not own
     78 // the buffer.
     79 template <typename tScalar, VectorShape tShape>
     80 class VectorMap {
     81  public:
     82   typedef tScalar Scalar;
     83   static const VectorShape kShape = tShape;
     84 
     85  protected:
     86   Scalar* data_;  // not owned.
     87   int size_;
     88 
     89  public:
     90   VectorMap() : data_(nullptr), size_(0) {}
     91   VectorMap(Scalar* data, int size) : data_(data), size_(size) {}
     92   VectorMap(const VectorMap& other) : data_(other.data_), size_(other.size_) {}
     93 
     94   int size() const { return size_; }
     95   Scalar* data() const { return data_; }
     96   Scalar* data(int index) const { return data_ + index; }
     97   Scalar& operator()(int index) const { return *data(index); }
     98 };
     99 
    100 // A VectorDup is a (duplicated value) vector where all components are the same.
    101 template <typename tScalar, VectorShape tShape>
    102 class VectorDup {
    103  public:
    104   typedef tScalar Scalar;
    105   static const VectorShape kShape = tShape;
    106 
    107  protected:
    108   Scalar data_;
    109   int size_;
    110 
    111  public:
    112   VectorDup() : data_(0), size_(0) {}
    113   VectorDup(Scalar data, int size) : data_(data), size_(size) {}
    114   VectorDup(const VectorDup& other) : data_(other.data_), size_(other.size_) {}
    115 
    116   int size() const { return size_; }
    117   Scalar& operator()(int index) const { return data_; }
    118 };
    119 
    120 }  // namespace gemmlowp
    121 
    122 #endif  // GEMMLOWP_PUBLIC_MAP_H_
    123