1 /* 2 * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #ifndef WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_MATRIX_H_ 12 #define WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_MATRIX_H_ 13 14 #include <algorithm> 15 #include <cstring> 16 #include <string> 17 #include <vector> 18 19 #include "webrtc/base/checks.h" 20 #include "webrtc/base/constructormagic.h" 21 #include "webrtc/base/scoped_ptr.h" 22 23 namespace { 24 25 // Wrappers to get around the compiler warning resulting from the fact that 26 // there's no std::sqrt overload for ints. We cast all non-complex types to 27 // a double for the sqrt method. 28 template <typename T> 29 T sqrt_wrapper(T x) { 30 return sqrt(static_cast<double>(x)); 31 } 32 33 template <typename S> 34 std::complex<S> sqrt_wrapper(std::complex<S> x) { 35 return sqrt(x); 36 } 37 } // namespace 38 39 namespace webrtc { 40 41 // Matrix is a class for doing standard matrix operations on 2 dimensional 42 // matrices of any size. Results of matrix operations are stored in the 43 // calling object. Function overloads exist for both in-place (the calling 44 // object is used as both an operand and the result) and out-of-place (all 45 // operands are passed in as parameters) operations. If operand dimensions 46 // mismatch, the program crashes. Out-of-place operations change the size of 47 // the calling object, if necessary, before operating. 48 // 49 // 'In-place' operations that inherently change the size of the matrix (eg. 50 // Transpose, Multiply on different-sized matrices) must make temporary copies 51 // (|scratch_elements_| and |scratch_data_|) of existing data to complete the 52 // operations. 53 // 54 // The data is stored contiguously. Data can be accessed internally as a flat 55 // array, |data_|, or as an array of row pointers, |elements_|, but is 56 // available to users only as an array of row pointers through |elements()|. 57 // Memory for storage is allocated when a matrix is resized only if the new 58 // size overflows capacity. Memory needed temporarily for any operations is 59 // similarly resized only if the new size overflows capacity. 60 // 61 // If you pass in storage through the ctor, that storage is copied into the 62 // matrix. TODO(claguna): albeit tricky, allow for data to be referenced 63 // instead of copied, and owned by the user. 64 template <typename T> 65 class Matrix { 66 public: 67 Matrix() : num_rows_(0), num_columns_(0) {} 68 69 // Allocates space for the elements and initializes all values to zero. 70 Matrix(size_t num_rows, size_t num_columns) 71 : num_rows_(num_rows), num_columns_(num_columns) { 72 Resize(); 73 scratch_data_.resize(num_rows_ * num_columns_); 74 scratch_elements_.resize(num_rows_); 75 } 76 77 // Copies |data| into the new Matrix. 78 Matrix(const T* data, size_t num_rows, size_t num_columns) 79 : num_rows_(0), num_columns_(0) { 80 CopyFrom(data, num_rows, num_columns); 81 scratch_data_.resize(num_rows_ * num_columns_); 82 scratch_elements_.resize(num_rows_); 83 } 84 85 virtual ~Matrix() {} 86 87 // Deep copy an existing matrix. 88 void CopyFrom(const Matrix& other) { 89 CopyFrom(&other.data_[0], other.num_rows_, other.num_columns_); 90 } 91 92 // Copy |data| into the Matrix. The current data is lost. 93 void CopyFrom(const T* const data, size_t num_rows, size_t num_columns) { 94 Resize(num_rows, num_columns); 95 memcpy(&data_[0], data, num_rows_ * num_columns_ * sizeof(data_[0])); 96 } 97 98 Matrix& CopyFromColumn(const T* const* src, 99 size_t column_index, 100 size_t num_rows) { 101 Resize(1, num_rows); 102 for (size_t i = 0; i < num_columns_; ++i) { 103 data_[i] = src[i][column_index]; 104 } 105 106 return *this; 107 } 108 109 void Resize(size_t num_rows, size_t num_columns) { 110 if (num_rows != num_rows_ || num_columns != num_columns_) { 111 num_rows_ = num_rows; 112 num_columns_ = num_columns; 113 Resize(); 114 } 115 } 116 117 // Accessors and mutators. 118 size_t num_rows() const { return num_rows_; } 119 size_t num_columns() const { return num_columns_; } 120 T* const* elements() { return &elements_[0]; } 121 const T* const* elements() const { return &elements_[0]; } 122 123 T Trace() { 124 RTC_CHECK_EQ(num_rows_, num_columns_); 125 126 T trace = 0; 127 for (size_t i = 0; i < num_rows_; ++i) { 128 trace += elements_[i][i]; 129 } 130 return trace; 131 } 132 133 // Matrix Operations. Returns *this to support method chaining. 134 Matrix& Transpose() { 135 CopyDataToScratch(); 136 Resize(num_columns_, num_rows_); 137 return Transpose(scratch_elements()); 138 } 139 140 Matrix& Transpose(const Matrix& operand) { 141 RTC_CHECK_EQ(operand.num_rows_, num_columns_); 142 RTC_CHECK_EQ(operand.num_columns_, num_rows_); 143 144 return Transpose(operand.elements()); 145 } 146 147 template <typename S> 148 Matrix& Scale(const S& scalar) { 149 for (size_t i = 0; i < data_.size(); ++i) { 150 data_[i] *= scalar; 151 } 152 153 return *this; 154 } 155 156 template <typename S> 157 Matrix& Scale(const Matrix& operand, const S& scalar) { 158 CopyFrom(operand); 159 return Scale(scalar); 160 } 161 162 Matrix& Add(const Matrix& operand) { 163 RTC_CHECK_EQ(num_rows_, operand.num_rows_); 164 RTC_CHECK_EQ(num_columns_, operand.num_columns_); 165 166 for (size_t i = 0; i < data_.size(); ++i) { 167 data_[i] += operand.data_[i]; 168 } 169 170 return *this; 171 } 172 173 Matrix& Add(const Matrix& lhs, const Matrix& rhs) { 174 CopyFrom(lhs); 175 return Add(rhs); 176 } 177 178 Matrix& Subtract(const Matrix& operand) { 179 RTC_CHECK_EQ(num_rows_, operand.num_rows_); 180 RTC_CHECK_EQ(num_columns_, operand.num_columns_); 181 182 for (size_t i = 0; i < data_.size(); ++i) { 183 data_[i] -= operand.data_[i]; 184 } 185 186 return *this; 187 } 188 189 Matrix& Subtract(const Matrix& lhs, const Matrix& rhs) { 190 CopyFrom(lhs); 191 return Subtract(rhs); 192 } 193 194 Matrix& PointwiseMultiply(const Matrix& operand) { 195 RTC_CHECK_EQ(num_rows_, operand.num_rows_); 196 RTC_CHECK_EQ(num_columns_, operand.num_columns_); 197 198 for (size_t i = 0; i < data_.size(); ++i) { 199 data_[i] *= operand.data_[i]; 200 } 201 202 return *this; 203 } 204 205 Matrix& PointwiseMultiply(const Matrix& lhs, const Matrix& rhs) { 206 CopyFrom(lhs); 207 return PointwiseMultiply(rhs); 208 } 209 210 Matrix& PointwiseDivide(const Matrix& operand) { 211 RTC_CHECK_EQ(num_rows_, operand.num_rows_); 212 RTC_CHECK_EQ(num_columns_, operand.num_columns_); 213 214 for (size_t i = 0; i < data_.size(); ++i) { 215 data_[i] /= operand.data_[i]; 216 } 217 218 return *this; 219 } 220 221 Matrix& PointwiseDivide(const Matrix& lhs, const Matrix& rhs) { 222 CopyFrom(lhs); 223 return PointwiseDivide(rhs); 224 } 225 226 Matrix& PointwiseSquareRoot() { 227 for (size_t i = 0; i < data_.size(); ++i) { 228 data_[i] = sqrt_wrapper(data_[i]); 229 } 230 231 return *this; 232 } 233 234 Matrix& PointwiseSquareRoot(const Matrix& operand) { 235 CopyFrom(operand); 236 return PointwiseSquareRoot(); 237 } 238 239 Matrix& PointwiseAbsoluteValue() { 240 for (size_t i = 0; i < data_.size(); ++i) { 241 data_[i] = abs(data_[i]); 242 } 243 244 return *this; 245 } 246 247 Matrix& PointwiseAbsoluteValue(const Matrix& operand) { 248 CopyFrom(operand); 249 return PointwiseAbsoluteValue(); 250 } 251 252 Matrix& PointwiseSquare() { 253 for (size_t i = 0; i < data_.size(); ++i) { 254 data_[i] *= data_[i]; 255 } 256 257 return *this; 258 } 259 260 Matrix& PointwiseSquare(const Matrix& operand) { 261 CopyFrom(operand); 262 return PointwiseSquare(); 263 } 264 265 Matrix& Multiply(const Matrix& lhs, const Matrix& rhs) { 266 RTC_CHECK_EQ(lhs.num_columns_, rhs.num_rows_); 267 RTC_CHECK_EQ(num_rows_, lhs.num_rows_); 268 RTC_CHECK_EQ(num_columns_, rhs.num_columns_); 269 270 return Multiply(lhs.elements(), rhs.num_rows_, rhs.elements()); 271 } 272 273 Matrix& Multiply(const Matrix& rhs) { 274 RTC_CHECK_EQ(num_columns_, rhs.num_rows_); 275 276 CopyDataToScratch(); 277 Resize(num_rows_, rhs.num_columns_); 278 return Multiply(scratch_elements(), rhs.num_rows_, rhs.elements()); 279 } 280 281 std::string ToString() const { 282 std::ostringstream ss; 283 ss << std::endl << "Matrix" << std::endl; 284 285 for (size_t i = 0; i < num_rows_; ++i) { 286 for (size_t j = 0; j < num_columns_; ++j) { 287 ss << elements_[i][j] << " "; 288 } 289 ss << std::endl; 290 } 291 ss << std::endl; 292 293 return ss.str(); 294 } 295 296 protected: 297 void SetNumRows(const size_t num_rows) { num_rows_ = num_rows; } 298 void SetNumColumns(const size_t num_columns) { num_columns_ = num_columns; } 299 T* data() { return &data_[0]; } 300 const T* data() const { return &data_[0]; } 301 const T* const* scratch_elements() const { return &scratch_elements_[0]; } 302 303 // Resize the matrix. If an increase in capacity is required, the current 304 // data is lost. 305 void Resize() { 306 size_t size = num_rows_ * num_columns_; 307 data_.resize(size); 308 elements_.resize(num_rows_); 309 310 for (size_t i = 0; i < num_rows_; ++i) { 311 elements_[i] = &data_[i * num_columns_]; 312 } 313 } 314 315 // Copies data_ into scratch_data_ and updates scratch_elements_ accordingly. 316 void CopyDataToScratch() { 317 scratch_data_ = data_; 318 scratch_elements_.resize(num_rows_); 319 320 for (size_t i = 0; i < num_rows_; ++i) { 321 scratch_elements_[i] = &scratch_data_[i * num_columns_]; 322 } 323 } 324 325 private: 326 size_t num_rows_; 327 size_t num_columns_; 328 std::vector<T> data_; 329 std::vector<T*> elements_; 330 331 // Stores temporary copies of |data_| and |elements_| for in-place operations 332 // where referring to original data is necessary. 333 std::vector<T> scratch_data_; 334 std::vector<T*> scratch_elements_; 335 336 // Helpers for Transpose and Multiply operations that unify in-place and 337 // out-of-place solutions. 338 Matrix& Transpose(const T* const* src) { 339 for (size_t i = 0; i < num_rows_; ++i) { 340 for (size_t j = 0; j < num_columns_; ++j) { 341 elements_[i][j] = src[j][i]; 342 } 343 } 344 345 return *this; 346 } 347 348 Matrix& Multiply(const T* const* lhs, 349 size_t num_rows_rhs, 350 const T* const* rhs) { 351 for (size_t row = 0; row < num_rows_; ++row) { 352 for (size_t col = 0; col < num_columns_; ++col) { 353 T cur_element = 0; 354 for (size_t i = 0; i < num_rows_rhs; ++i) { 355 cur_element += lhs[row][i] * rhs[i][col]; 356 } 357 358 elements_[row][col] = cur_element; 359 } 360 } 361 362 return *this; 363 } 364 365 RTC_DISALLOW_COPY_AND_ASSIGN(Matrix); 366 }; 367 368 } // namespace webrtc 369 370 #endif // WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_MATRIX_H_ 371