Home | History | Annotate | Download | only in beamformer
      1 /*
      2  *  Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
      3  *
      4  *  Use of this source code is governed by a BSD-style license
      5  *  that can be found in the LICENSE file in the root of the source
      6  *  tree. An additional intellectual property rights grant can be found
      7  *  in the file PATENTS.  All contributing project authors may
      8  *  be found in the AUTHORS file in the root of the source tree.
      9  */
     10 
     11 #ifndef WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_MATRIX_H_
     12 #define WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_MATRIX_H_
     13 
#include <algorithm>
#include <cmath>
#include <complex>
#include <cstdlib>
#include <cstring>
#include <sstream>
#include <string>
#include <vector>

#include "webrtc/base/checks.h"
#include "webrtc/base/constructormagic.h"
#include "webrtc/base/scoped_ptr.h"
     22 
     23 namespace {
     24 
     25 // Wrappers to get around the compiler warning resulting from the fact that
     26 // there's no std::sqrt overload for ints. We cast all non-complex types to
     27 // a double for the sqrt method.
     28 template <typename T>
     29 T sqrt_wrapper(T x) {
     30   return sqrt(static_cast<double>(x));
     31 }
     32 
     33 template <typename S>
     34 std::complex<S> sqrt_wrapper(std::complex<S> x) {
     35   return sqrt(x);
     36 }
     37 } // namespace
     38 
     39 namespace webrtc {
     40 
     41 // Matrix is a class for doing standard matrix operations on 2 dimensional
     42 // matrices of any size. Results of matrix operations are stored in the
     43 // calling object. Function overloads exist for both in-place (the calling
     44 // object is used as both an operand and the result) and out-of-place (all
     45 // operands are passed in as parameters) operations. If operand dimensions
     46 // mismatch, the program crashes. Out-of-place operations change the size of
     47 // the calling object, if necessary, before operating.
     48 //
     49 // 'In-place' operations that inherently change the size of the matrix (eg.
     50 // Transpose, Multiply on different-sized matrices) must make temporary copies
     51 // (|scratch_elements_| and |scratch_data_|) of existing data to complete the
     52 // operations.
     53 //
     54 // The data is stored contiguously. Data can be accessed internally as a flat
     55 // array, |data_|, or as an array of row pointers, |elements_|, but is
     56 // available to users only as an array of row pointers through |elements()|.
     57 // Memory for storage is allocated when a matrix is resized only if the new
     58 // size overflows capacity. Memory needed temporarily for any operations is
     59 // similarly resized only if the new size overflows capacity.
     60 //
     61 // If you pass in storage through the ctor, that storage is copied into the
     62 // matrix. TODO(claguna): albeit tricky, allow for data to be referenced
     63 // instead of copied, and owned by the user.
     64 template <typename T>
     65 class Matrix {
     66  public:
     67   Matrix() : num_rows_(0), num_columns_(0) {}
     68 
     69   // Allocates space for the elements and initializes all values to zero.
     70   Matrix(size_t num_rows, size_t num_columns)
     71       : num_rows_(num_rows), num_columns_(num_columns) {
     72     Resize();
     73     scratch_data_.resize(num_rows_ * num_columns_);
     74     scratch_elements_.resize(num_rows_);
     75   }
     76 
     77   // Copies |data| into the new Matrix.
     78   Matrix(const T* data, size_t num_rows, size_t num_columns)
     79       : num_rows_(0), num_columns_(0) {
     80     CopyFrom(data, num_rows, num_columns);
     81     scratch_data_.resize(num_rows_ * num_columns_);
     82     scratch_elements_.resize(num_rows_);
     83   }
     84 
     85   virtual ~Matrix() {}
     86 
     87   // Deep copy an existing matrix.
     88   void CopyFrom(const Matrix& other) {
     89     CopyFrom(&other.data_[0], other.num_rows_, other.num_columns_);
     90   }
     91 
     92   // Copy |data| into the Matrix. The current data is lost.
     93   void CopyFrom(const T* const data, size_t num_rows, size_t num_columns) {
     94     Resize(num_rows, num_columns);
     95     memcpy(&data_[0], data, num_rows_ * num_columns_ * sizeof(data_[0]));
     96   }
     97 
     98   Matrix& CopyFromColumn(const T* const* src,
     99                          size_t column_index,
    100                          size_t num_rows) {
    101     Resize(1, num_rows);
    102     for (size_t i = 0; i < num_columns_; ++i) {
    103       data_[i] = src[i][column_index];
    104     }
    105 
    106     return *this;
    107   }
    108 
    109   void Resize(size_t num_rows, size_t num_columns) {
    110     if (num_rows != num_rows_ || num_columns != num_columns_) {
    111       num_rows_ = num_rows;
    112       num_columns_ = num_columns;
    113       Resize();
    114     }
    115   }
    116 
    117   // Accessors and mutators.
    118   size_t num_rows() const { return num_rows_; }
    119   size_t num_columns() const { return num_columns_; }
    120   T* const* elements() { return &elements_[0]; }
    121   const T* const* elements() const { return &elements_[0]; }
    122 
    123   T Trace() {
    124     RTC_CHECK_EQ(num_rows_, num_columns_);
    125 
    126     T trace = 0;
    127     for (size_t i = 0; i < num_rows_; ++i) {
    128       trace += elements_[i][i];
    129     }
    130     return trace;
    131   }
    132 
    133   // Matrix Operations. Returns *this to support method chaining.
    134   Matrix& Transpose() {
    135     CopyDataToScratch();
    136     Resize(num_columns_, num_rows_);
    137     return Transpose(scratch_elements());
    138   }
    139 
    140   Matrix& Transpose(const Matrix& operand) {
    141     RTC_CHECK_EQ(operand.num_rows_, num_columns_);
    142     RTC_CHECK_EQ(operand.num_columns_, num_rows_);
    143 
    144     return Transpose(operand.elements());
    145   }
    146 
    147   template <typename S>
    148   Matrix& Scale(const S& scalar) {
    149     for (size_t i = 0; i < data_.size(); ++i) {
    150       data_[i] *= scalar;
    151     }
    152 
    153     return *this;
    154   }
    155 
    156   template <typename S>
    157   Matrix& Scale(const Matrix& operand, const S& scalar) {
    158     CopyFrom(operand);
    159     return Scale(scalar);
    160   }
    161 
    162   Matrix& Add(const Matrix& operand) {
    163     RTC_CHECK_EQ(num_rows_, operand.num_rows_);
    164     RTC_CHECK_EQ(num_columns_, operand.num_columns_);
    165 
    166     for (size_t i = 0; i < data_.size(); ++i) {
    167       data_[i] += operand.data_[i];
    168     }
    169 
    170     return *this;
    171   }
    172 
    173   Matrix& Add(const Matrix& lhs, const Matrix& rhs) {
    174     CopyFrom(lhs);
    175     return Add(rhs);
    176   }
    177 
    178   Matrix& Subtract(const Matrix& operand) {
    179     RTC_CHECK_EQ(num_rows_, operand.num_rows_);
    180     RTC_CHECK_EQ(num_columns_, operand.num_columns_);
    181 
    182     for (size_t i = 0; i < data_.size(); ++i) {
    183       data_[i] -= operand.data_[i];
    184     }
    185 
    186     return *this;
    187   }
    188 
    189   Matrix& Subtract(const Matrix& lhs, const Matrix& rhs) {
    190     CopyFrom(lhs);
    191     return Subtract(rhs);
    192   }
    193 
    194   Matrix& PointwiseMultiply(const Matrix& operand) {
    195     RTC_CHECK_EQ(num_rows_, operand.num_rows_);
    196     RTC_CHECK_EQ(num_columns_, operand.num_columns_);
    197 
    198     for (size_t i = 0; i < data_.size(); ++i) {
    199       data_[i] *= operand.data_[i];
    200     }
    201 
    202     return *this;
    203   }
    204 
    205   Matrix& PointwiseMultiply(const Matrix& lhs, const Matrix& rhs) {
    206     CopyFrom(lhs);
    207     return PointwiseMultiply(rhs);
    208   }
    209 
    210   Matrix& PointwiseDivide(const Matrix& operand) {
    211     RTC_CHECK_EQ(num_rows_, operand.num_rows_);
    212     RTC_CHECK_EQ(num_columns_, operand.num_columns_);
    213 
    214     for (size_t i = 0; i < data_.size(); ++i) {
    215       data_[i] /= operand.data_[i];
    216     }
    217 
    218     return *this;
    219   }
    220 
    221   Matrix& PointwiseDivide(const Matrix& lhs, const Matrix& rhs) {
    222     CopyFrom(lhs);
    223     return PointwiseDivide(rhs);
    224   }
    225 
    226   Matrix& PointwiseSquareRoot() {
    227     for (size_t i = 0; i < data_.size(); ++i) {
    228       data_[i] = sqrt_wrapper(data_[i]);
    229     }
    230 
    231     return *this;
    232   }
    233 
    234   Matrix& PointwiseSquareRoot(const Matrix& operand) {
    235     CopyFrom(operand);
    236     return PointwiseSquareRoot();
    237   }
    238 
    239   Matrix& PointwiseAbsoluteValue() {
    240     for (size_t i = 0; i < data_.size(); ++i) {
    241       data_[i] = abs(data_[i]);
    242     }
    243 
    244     return *this;
    245   }
    246 
    247   Matrix& PointwiseAbsoluteValue(const Matrix& operand) {
    248     CopyFrom(operand);
    249     return PointwiseAbsoluteValue();
    250   }
    251 
    252   Matrix& PointwiseSquare() {
    253     for (size_t i = 0; i < data_.size(); ++i) {
    254       data_[i] *= data_[i];
    255     }
    256 
    257     return *this;
    258   }
    259 
    260   Matrix& PointwiseSquare(const Matrix& operand) {
    261     CopyFrom(operand);
    262     return PointwiseSquare();
    263   }
    264 
    265   Matrix& Multiply(const Matrix& lhs, const Matrix& rhs) {
    266     RTC_CHECK_EQ(lhs.num_columns_, rhs.num_rows_);
    267     RTC_CHECK_EQ(num_rows_, lhs.num_rows_);
    268     RTC_CHECK_EQ(num_columns_, rhs.num_columns_);
    269 
    270     return Multiply(lhs.elements(), rhs.num_rows_, rhs.elements());
    271   }
    272 
    273   Matrix& Multiply(const Matrix& rhs) {
    274     RTC_CHECK_EQ(num_columns_, rhs.num_rows_);
    275 
    276     CopyDataToScratch();
    277     Resize(num_rows_, rhs.num_columns_);
    278     return Multiply(scratch_elements(), rhs.num_rows_, rhs.elements());
    279   }
    280 
    281   std::string ToString() const {
    282     std::ostringstream ss;
    283     ss << std::endl << "Matrix" << std::endl;
    284 
    285     for (size_t i = 0; i < num_rows_; ++i) {
    286       for (size_t j = 0; j < num_columns_; ++j) {
    287         ss << elements_[i][j] << " ";
    288       }
    289       ss << std::endl;
    290     }
    291     ss << std::endl;
    292 
    293     return ss.str();
    294   }
    295 
    296  protected:
    297   void SetNumRows(const size_t num_rows) { num_rows_ = num_rows; }
    298   void SetNumColumns(const size_t num_columns) { num_columns_ = num_columns; }
    299   T* data() { return &data_[0]; }
    300   const T* data() const { return &data_[0]; }
    301   const T* const* scratch_elements() const { return &scratch_elements_[0]; }
    302 
    303   // Resize the matrix. If an increase in capacity is required, the current
    304   // data is lost.
    305   void Resize() {
    306     size_t size = num_rows_ * num_columns_;
    307     data_.resize(size);
    308     elements_.resize(num_rows_);
    309 
    310     for (size_t i = 0; i < num_rows_; ++i) {
    311       elements_[i] = &data_[i * num_columns_];
    312     }
    313   }
    314 
    315   // Copies data_ into scratch_data_ and updates scratch_elements_ accordingly.
    316   void CopyDataToScratch() {
    317     scratch_data_ = data_;
    318     scratch_elements_.resize(num_rows_);
    319 
    320     for (size_t i = 0; i < num_rows_; ++i) {
    321       scratch_elements_[i] = &scratch_data_[i * num_columns_];
    322     }
    323   }
    324 
    325  private:
    326   size_t num_rows_;
    327   size_t num_columns_;
    328   std::vector<T> data_;
    329   std::vector<T*> elements_;
    330 
    331   // Stores temporary copies of |data_| and |elements_| for in-place operations
    332   // where referring to original data is necessary.
    333   std::vector<T> scratch_data_;
    334   std::vector<T*> scratch_elements_;
    335 
    336   // Helpers for Transpose and Multiply operations that unify in-place and
    337   // out-of-place solutions.
    338   Matrix& Transpose(const T* const* src) {
    339     for (size_t i = 0; i < num_rows_; ++i) {
    340       for (size_t j = 0; j < num_columns_; ++j) {
    341         elements_[i][j] = src[j][i];
    342       }
    343     }
    344 
    345     return *this;
    346   }
    347 
    348   Matrix& Multiply(const T* const* lhs,
    349                    size_t num_rows_rhs,
    350                    const T* const* rhs) {
    351     for (size_t row = 0; row < num_rows_; ++row) {
    352       for (size_t col = 0; col < num_columns_; ++col) {
    353         T cur_element = 0;
    354         for (size_t i = 0; i < num_rows_rhs; ++i) {
    355           cur_element += lhs[row][i] * rhs[i][col];
    356         }
    357 
    358         elements_[row][col] = cur_element;
    359       }
    360     }
    361 
    362     return *this;
    363   }
    364 
    365   RTC_DISALLOW_COPY_AND_ASSIGN(Matrix);
    366 };
    367 
    368 }  // namespace webrtc
    369 
    370 #endif  // WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_MATRIX_H_
    371