1 /* 2 * Copyright (C) 2008 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 /* ---- includes ----------------------------------------------------------- */ 18 19 #include "b_TensorEm/Int32Mat.h" 20 #include "b_TensorEm/Functions.h" 21 #include "b_BasicEm/Math.h" 22 #include "b_BasicEm/Functions.h" 23 #include "b_BasicEm/Memory.h" 24 25 /* ------------------------------------------------------------------------- */ 26 27 /* ========================================================================= */ 28 /* */ 29 /* ---- \ghd{ auxiliary functions } ---------------------------------------- */ 30 /* */ 31 /* ========================================================================= */ 32 33 /* ------------------------------------------------------------------------- */ 34 35 void bts_Int32Mat_reduceToNBits( int32* ptrA, uint32 sizeA, int32* bbpPtrA, uint32 nBitsA ) 36 { 37 int32 shiftL; 38 39 /* find max element */ 40 int32 maxL = 0; 41 int32* ptrL = ptrA; 42 int32 iL = sizeA; 43 while( iL-- ) 44 { 45 int32 xL = *ptrL++; 46 if( xL < 0 ) xL = -xL; 47 if( xL > maxL ) maxL = xL; 48 } 49 50 /* determine shift */ 51 shiftL = bts_absIntLog2( maxL ) + 1 - nBitsA; 52 53 if( shiftL > 0 ) 54 { 55 ptrL = ptrA; 56 iL = sizeA; 57 while( iL-- ) 58 { 59 *ptrL = ( ( *ptrL >> ( shiftL - 1 ) ) + 1 ) >> 1; 60 ptrL++; 61 } 62 63 *bbpPtrA -= shiftL; 64 } 65 } 66 67 /* ------------------------------------------------------------------------- */ 68 69 /* ========================================================================= */ 70 /* */ 71 /* ---- \ghd{ constructor / destructor } ----------------------------------- */ 72 /* */ 73 /* ========================================================================= */ 74 75 /* ------------------------------------------------------------------------- */ 76 77 void bts_Int32Mat_init( struct bbs_Context* cpA, 78 struct bts_Int32Mat* ptrA ) 79 { 80 ptrA->widthE = 0; 81 bbs_Int32Arr_init( cpA, &ptrA->arrE ); 82 } 83 84 /* ------------------------------------------------------------------------- */ 85 86 void bts_Int32Mat_exit( struct bbs_Context* cpA, 87 struct bts_Int32Mat* ptrA ) 88 { 89 ptrA->widthE = 0; 90 bbs_Int32Arr_exit( cpA, &ptrA->arrE ); 91 } 92 /* ------------------------------------------------------------------------- */ 93 94 /* ========================================================================= */ 95 /* */ 96 /* ---- \ghd{ operators } -------------------------------------------------- */ 97 /* */ 98 /* ========================================================================= */ 99 100 /* ------------------------------------------------------------------------- */ 101 102 /* ========================================================================= */ 103 /* */ 104 /* ---- \ghd{ query functions } -------------------------------------------- */ 105 /* */ 106 /* ========================================================================= */ 107 108 /* ------------------------------------------------------------------------- */ 109 110 /* ========================================================================= */ 111 /* */ 112 /* ---- \ghd{ modify functions } ------------------------------------------- */ 113 /* */ 114 /* ========================================================================= */ 115 116 /* ------------------------------------------------------------------------- */ 117 118 void bts_Int32Mat_create( struct bbs_Context* cpA, 119 struct bts_Int32Mat* ptrA, 120 int32 widthA, 121 struct bbs_MemSeg* mspA ) 122 { 123 if( bbs_Context_error( cpA ) ) return; 124 bbs_Int32Arr_create( cpA, &ptrA->arrE, widthA * widthA, mspA ); 125 ptrA->widthE = widthA; 126 } 127 128 /* ------------------------------------------------------------------------- */ 129 130 void bts_Int32Mat_copy( struct bbs_Context* cpA, 131 struct bts_Int32Mat* ptrA, 132 const struct bts_Int32Mat* srcPtrA ) 133 { 134 if( ptrA->widthE != srcPtrA->widthE ) 135 { 136 bbs_ERROR0( "void bts_Int32Mat_copy( struct bts_Int32Mat* ptrA, struct bts_Int32Mat* srcPtrA ):\n" 137 "size mismatch" ); 138 return; 139 } 140 141 bbs_Int32Arr_copy( cpA, &ptrA->arrE, &srcPtrA->arrE ); 142 } 143 144 /* ------------------------------------------------------------------------- */ 145 146 /* ========================================================================= */ 147 /* */ 148 /* ---- \ghd{ I/O } -------------------------------------------------------- */ 149 /* */ 150 /* ========================================================================= */ 151 152 /* ------------------------------------------------------------------------- */ 153 154 uint32 bts_Int32Mat_memSize( struct bbs_Context* cpA, 155 const struct bts_Int32Mat *ptrA ) 156 { 157 return bbs_SIZEOF16( uint32 ) 158 + bbs_SIZEOF16( uint32 ) /* version */ 159 + bbs_SIZEOF16( ptrA->widthE ) 160 + bbs_Int32Arr_memSize( cpA, &ptrA->arrE ); 161 } 162 163 /* ------------------------------------------------------------------------- */ 164 165 uint32 bts_Int32Mat_memWrite( struct bbs_Context* cpA, 166 const struct bts_Int32Mat* ptrA, 167 uint16* memPtrA ) 168 { 169 uint32 memSizeL = bts_Int32Mat_memSize( cpA, ptrA ); 170 memPtrA += bbs_memWrite32( &memSizeL, memPtrA ); 171 memPtrA += bbs_memWriteUInt32( bts_INT32MAT_VERSION, memPtrA ); 172 memPtrA += bbs_memWrite32( &ptrA->widthE, memPtrA ); 173 memPtrA += bbs_Int32Arr_memWrite( cpA, &ptrA->arrE, memPtrA ); 174 return memSizeL; 175 } 176 177 /* ------------------------------------------------------------------------- */ 178 179 uint32 bts_Int32Mat_memRead( struct bbs_Context* cpA, 180 struct bts_Int32Mat* ptrA, 181 const uint16* memPtrA, 182 struct bbs_MemSeg* mspA ) 183 { 184 uint32 memSizeL, versionL; 185 if( bbs_Context_error( cpA ) ) return 0; 186 memPtrA += bbs_memRead32( &memSizeL, memPtrA ); 187 memPtrA += bbs_memReadVersion32( cpA, &versionL, bts_INT32MAT_VERSION, memPtrA ); 188 memPtrA += bbs_memRead32( &ptrA->widthE, memPtrA ); 189 memPtrA += bbs_Int32Arr_memRead( cpA, &ptrA->arrE, memPtrA, mspA ); 190 191 if( memSizeL != bts_Int32Mat_memSize( cpA, ptrA ) ) 192 { 193 bbs_ERR0( bbs_ERR_CORRUPT_DATA, "uint32 bts_Int32Mat_memRead( const struct bts_Int32Mat* ptrA, const void* memPtrA ):\n" 194 "size mismatch" ); 195 } 196 return memSizeL; 197 } 198 199 /* ------------------------------------------------------------------------- */ 200 201 /* ========================================================================= */ 202 /* */ 203 /* ---- \ghd{ exec functions } --------------------------------------------- */ 204 /* */ 205 /* ========================================================================= */ 206 207 /* ------------------------------------------------------------------------- */ 208 209 flag bts_Int32Mat_solve( struct bbs_Context* cpA, 210 const int32* matA, 211 int32 matWidthA, 212 const int32* inVecA, 213 int32* outVecA, 214 int32 bbpA, 215 int32* tmpMatA, 216 int32* tmpVecA ) 217 { 218 bbs_memcpy32( tmpMatA, matA, ( matWidthA * matWidthA ) * bbs_SIZEOF32( int32 ) ); 219 220 return bts_Int32Mat_solve2( cpA, 221 tmpMatA, 222 matWidthA, 223 inVecA, 224 outVecA, 225 bbpA, 226 tmpVecA ); 227 } 228 229 /* ------------------------------------------------------------------------- */ 230 231 flag bts_Int32Mat_solve2( struct bbs_Context* cpA, 232 int32* matA, 233 int32 matWidthA, 234 const int32* inVecA, 235 int32* outVecA, 236 int32 bbpA, 237 int32* tmpVecA ) 238 { 239 int32 sizeL = matWidthA; 240 int32 bbpL = bbpA; 241 int32 iL, jL, kL; 242 int32 iPivL; 243 int32 jPivL; 244 245 int32* vecL = outVecA; 246 int32* matL = matA; 247 int32* checkArrL = tmpVecA; 248 249 for( iL = 0; iL < sizeL; iL++ ) 250 { 251 checkArrL[ iL ] = 0; 252 } 253 254 bbs_memcpy32( outVecA, inVecA, sizeL * bbs_SIZEOF32( int32 ) ); 255 256 iPivL = 0; 257 258 for( kL = 0; kL < sizeL; kL++ ) 259 { 260 /* find pivot */ 261 int32 maxAbsL = 0; 262 int32* pivRowL; 263 264 int32 bbp_pivRowL, bbp_vecL, shiftL; 265 266 jPivL = -1; 267 for( iL = 0; iL < sizeL; iL++ ) 268 { 269 if( checkArrL[ iL ] != 1 ) 270 { 271 int32* rowL = matL + ( iL * sizeL ); 272 for( jL = 0; jL < sizeL; jL++ ) 273 { 274 if( checkArrL[ jL ] == 0 ) 275 { 276 int32 absElemL = rowL[ jL ]; 277 if( absElemL < 0 ) absElemL = -absElemL; 278 if( maxAbsL < absElemL ) 279 { 280 maxAbsL = absElemL; 281 iPivL = iL; 282 jPivL = jL; 283 } 284 } 285 else if( checkArrL[ jL ] > 1 ) 286 { 287 return FALSE; 288 } 289 } 290 } 291 } 292 293 /* successfull ? */ 294 if( jPivL < 0 ) 295 { 296 return FALSE; 297 } 298 299 checkArrL[ jPivL ]++; 300 301 /* exchange rows to put pivot on diagonal, if neccessary */ 302 if( iPivL != jPivL ) 303 { 304 int32* row1PtrL = matL + ( iPivL * sizeL ); 305 int32* row2PtrL = matL + ( jPivL * sizeL ); 306 for( jL = 0; jL < sizeL; jL++ ) 307 { 308 int32 tmpL = *row1PtrL; 309 *row1PtrL++ = *row2PtrL; 310 *row2PtrL++ = tmpL; 311 } 312 313 { 314 int32 tmpL = vecL[ jPivL ]; 315 vecL[ jPivL ] = vecL[ iPivL ]; 316 vecL[ iPivL ] = tmpL; 317 } 318 } 319 /* now index jPivL specifies pivot row and maximum element */ 320 321 322 /** Overflow protection: only if the highest bit of the largest matrix element is set, 323 * we need to shift the whole matrix and the right side vector 1 bit to the right, 324 * to make sure there can be no overflow when the pivot row gets subtracted from the 325 * other rows. 326 * Getting that close to overflow is a rare event, so this shift will happen only 327 * occasionally, or not at all. 328 */ 329 if( maxAbsL & 1073741824 ) /*( 1 << 30 )*/ 330 { 331 /* right shift matrix by 1 */ 332 int32 iL = sizeL * sizeL; 333 int32* ptrL = matL; 334 while( iL-- ) 335 { 336 *ptrL = ( *ptrL + 1 ) >> 1; 337 ptrL++; 338 } 339 340 /* right shift right side vector by 1 */ 341 iL = sizeL; 342 ptrL = vecL; 343 while( iL-- ) 344 { 345 *ptrL = ( *ptrL + 1 ) >> 1; 346 ptrL++; 347 } 348 349 /* decrement bbpL */ 350 bbpL--; 351 } 352 353 354 /* reduce elements of pivot row to 15 bit */ 355 pivRowL = matL + jPivL * sizeL; 356 bbp_pivRowL = bbpL; 357 bts_Int32Mat_reduceToNBits( pivRowL, sizeL, &bbp_pivRowL, 15 ); 358 359 /* scale pivot row such that maximum equals 1 */ 360 { 361 int32 maxL = pivRowL[ jPivL ]; 362 int32 bbp_maxL = bbp_pivRowL; 363 int32 factorL = 1073741824 / maxL; /*( 1 << 30 )*/ 364 365 for( jL = 0; jL < sizeL; jL++ ) 366 { 367 pivRowL[ jL ] = ( pivRowL[ jL ] * factorL + ( 1 << 14 ) ) >> 15; 368 } 369 bbp_pivRowL = 15; 370 371 /* set to 1 to avoid computational errors */ 372 pivRowL[ jPivL ] = ( int32 )1 << bbp_pivRowL; 373 374 shiftL = 30 - bts_absIntLog2( vecL[ jPivL ] ); 375 376 vecL[ jPivL ] = ( vecL[ jPivL ] << shiftL ) / maxL; 377 bbp_vecL = bbpL + shiftL - bbp_maxL; 378 379 bbs_int32ReduceToNBits( &( vecL[ jPivL ] ), &bbp_vecL, 15 ); 380 } 381 382 /* subtract pivot row from all other rows */ 383 for( iL = 0; iL < sizeL; iL++ ) 384 { 385 if( iL != jPivL ) 386 { 387 int32* rowPtrL = matL + iL * sizeL; 388 389 int32 tmpL = *( rowPtrL + jPivL ); 390 int32 bbp_tmpL = bbpL; 391 bbs_int32ReduceToNBits( &tmpL, &bbp_tmpL, 15 ); 392 393 shiftL = bbp_tmpL + bbp_pivRowL - bbpL; 394 if( shiftL > 0 ) 395 { 396 for( jL = 0; jL < sizeL; jL++ ) 397 { 398 *rowPtrL++ -= ( ( ( tmpL * pivRowL[ jL ] ) >> ( shiftL - 1 ) ) + 1 ) >> 1; 399 } 400 } 401 else 402 { 403 for( jL = 0; jL < sizeL; jL++ ) 404 { 405 *rowPtrL++ -= ( tmpL * pivRowL[ jL ] ) << -shiftL; 406 } 407 } 408 409 shiftL = bbp_tmpL + bbp_vecL - bbpL; 410 if( shiftL > 0 ) 411 { 412 vecL[ iL ] -= ( ( ( tmpL * vecL[ jPivL ] ) >> ( shiftL - 1 ) ) + 1 ) >> 1; 413 } 414 else 415 { 416 vecL[ iL ] -= ( tmpL * vecL[ jPivL ] ) << -shiftL; 417 } 418 } 419 } 420 421 /* change bbp of pivot row back to bbpL */ 422 shiftL = bbpL - bbp_pivRowL; 423 if( shiftL >= 0 ) 424 { 425 for( jL = 0; jL < sizeL; jL++ ) 426 { 427 pivRowL[ jL ] <<= shiftL; 428 } 429 } 430 else 431 { 432 shiftL = -shiftL; 433 for( jL = 0; jL < sizeL; jL++ ) 434 { 435 pivRowL[ jL ] = ( ( pivRowL[ jL ] >> ( shiftL - 1 ) ) + 1 ) >> 1; 436 } 437 } 438 439 shiftL = bbpL - bbp_vecL; 440 if( shiftL >= 0 ) 441 { 442 vecL[ jPivL ] <<= shiftL; 443 } 444 else 445 { 446 shiftL = -shiftL; 447 vecL[ jPivL ] = ( ( vecL[ jPivL ] >> ( shiftL - 1 ) ) + 1 ) >> 1; 448 } 449 /* 450 if( sizeL <= 5 ) bts_Int32Mat_print( matL, vecL, sizeL, bbpL ); 451 */ 452 } /* of kL */ 453 454 /* in case bbpL has been decreased by the overflow protection, change it back now */ 455 if( bbpA > bbpL ) 456 { 457 /* find largest element of solution vector */ 458 int32 maxL = 0; 459 int32 iL, shiftL; 460 for( iL = 0; iL < sizeL; iL++ ) 461 { 462 int32 xL = vecL[ iL ]; 463 if( xL < 0 ) xL = -xL; 464 if( xL > maxL ) maxL = xL; 465 } 466 467 /* check whether we can left shift without overflow */ 468 shiftL = 30 - bts_absIntLog2( maxL ); 469 if( shiftL < ( bbpA - bbpL ) ) 470 { 471 /* 472 bbs_WARNING1( "flag bts_Int32Mat_solve2( ... ): getting overflow when trying to " 473 "compute solution vector with bbp = %d. Choose smaller bbp.\n", bbpA ); 474 */ 475 476 return FALSE; 477 } 478 479 /* shift left */ 480 shiftL = bbpA - bbpL; 481 for( iL = 0; iL < sizeL; iL++ ) vecL[ iL ] <<= shiftL; 482 } 483 484 return TRUE; 485 } 486 487 /* ------------------------------------------------------------------------- */ 488 489 /* ========================================================================= */ 490 491