Home | History | Annotate | Download | only in b_TensorEm
      1 /*
      2  * Copyright (C) 2008 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 /* ---- includes ----------------------------------------------------------- */
     18 
     19 #include "b_TensorEm/Int32Mat.h"
     20 #include "b_TensorEm/Functions.h"
     21 #include "b_BasicEm/Math.h"
     22 #include "b_BasicEm/Functions.h"
     23 #include "b_BasicEm/Memory.h"
     24 
     25 /* ------------------------------------------------------------------------- */
     26 
     27 /* ========================================================================= */
     28 /*                                                                           */
     29 /* ---- \ghd{ auxiliary functions } ---------------------------------------- */
     30 /*                                                                           */
     31 /* ========================================================================= */
     32 
     33 /* ------------------------------------------------------------------------- */
     34 
     35 void bts_Int32Mat_reduceToNBits( int32* ptrA, uint32 sizeA, int32* bbpPtrA, uint32 nBitsA )
     36 {
     37 	int32 shiftL;
     38 
     39 	/* find max element */
     40 	int32 maxL = 0;
     41 	int32* ptrL = ptrA;
     42 	int32 iL = sizeA;
     43 	while( iL-- )
     44 	{
     45 		int32 xL = *ptrL++;
     46 		if( xL < 0 ) xL = -xL;
     47 		if( xL > maxL ) maxL = xL;
     48 	}
     49 
     50 	/* determine shift */
     51 	shiftL = bts_absIntLog2( maxL ) + 1 - nBitsA;
     52 
     53 	if( shiftL > 0 )
     54 	{
     55 		ptrL = ptrA;
     56 		iL = sizeA;
     57 		while( iL-- )
     58 		{
     59 			*ptrL = ( ( *ptrL >> ( shiftL - 1 ) ) + 1 ) >> 1;
     60 			ptrL++;
     61 		}
     62 
     63 		*bbpPtrA -= shiftL;
     64 	}
     65 }
     66 
     67 /* ------------------------------------------------------------------------- */
     68 
     69 /* ========================================================================= */
     70 /*                                                                           */
     71 /* ---- \ghd{ constructor / destructor } ----------------------------------- */
     72 /*                                                                           */
     73 /* ========================================================================= */
     74 
     75 /* ------------------------------------------------------------------------- */
     76 
     77 void bts_Int32Mat_init( struct bbs_Context* cpA,
     78 					    struct bts_Int32Mat* ptrA )
     79 {
     80 	ptrA->widthE = 0;
     81 	bbs_Int32Arr_init( cpA, &ptrA->arrE );
     82 }
     83 
     84 /* ------------------------------------------------------------------------- */
     85 
     86 void bts_Int32Mat_exit( struct bbs_Context* cpA,
     87 					    struct bts_Int32Mat* ptrA )
     88 {
     89 	ptrA->widthE = 0;
     90 	bbs_Int32Arr_exit( cpA, &ptrA->arrE );
     91 }
     92 /* ------------------------------------------------------------------------- */
     93 
     94 /* ========================================================================= */
     95 /*                                                                           */
     96 /* ---- \ghd{ operators } -------------------------------------------------- */
     97 /*                                                                           */
     98 /* ========================================================================= */
     99 
    100 /* ------------------------------------------------------------------------- */
    101 
    102 /* ========================================================================= */
    103 /*                                                                           */
    104 /* ---- \ghd{ query functions } -------------------------------------------- */
    105 /*                                                                           */
    106 /* ========================================================================= */
    107 
    108 /* ------------------------------------------------------------------------- */
    109 
    110 /* ========================================================================= */
    111 /*                                                                           */
    112 /* ---- \ghd{ modify functions } ------------------------------------------- */
    113 /*                                                                           */
    114 /* ========================================================================= */
    115 
    116 /* ------------------------------------------------------------------------- */
    117 
    118 void bts_Int32Mat_create( struct bbs_Context* cpA,
    119 						  struct bts_Int32Mat* ptrA,
    120 						  int32 widthA,
    121 				          struct bbs_MemSeg* mspA )
    122 {
    123 	if( bbs_Context_error( cpA ) ) return;
    124 	bbs_Int32Arr_create( cpA, &ptrA->arrE, widthA * widthA, mspA );
    125 	ptrA->widthE = widthA;
    126 }
    127 
    128 /* ------------------------------------------------------------------------- */
    129 
    130 void bts_Int32Mat_copy( struct bbs_Context* cpA,
    131 					    struct bts_Int32Mat* ptrA,
    132 						const struct bts_Int32Mat* srcPtrA )
    133 {
    134 	if( ptrA->widthE != srcPtrA->widthE )
    135 	{
    136 		bbs_ERROR0( "void bts_Int32Mat_copy( struct bts_Int32Mat* ptrA, struct bts_Int32Mat* srcPtrA ):\n"
    137 			       "size mismatch" );
    138 		return;
    139 	}
    140 
    141 	bbs_Int32Arr_copy( cpA, &ptrA->arrE, &srcPtrA->arrE );
    142 }
    143 
    144 /* ------------------------------------------------------------------------- */
    145 
    146 /* ========================================================================= */
    147 /*                                                                           */
    148 /* ---- \ghd{ I/O } -------------------------------------------------------- */
    149 /*                                                                           */
    150 /* ========================================================================= */
    151 
    152 /* ------------------------------------------------------------------------- */
    153 
    154 uint32 bts_Int32Mat_memSize( struct bbs_Context* cpA,
    155 							 const struct bts_Int32Mat *ptrA )
    156 {
    157 	return  bbs_SIZEOF16( uint32 )
    158 		  + bbs_SIZEOF16( uint32 ) /* version */
    159 		  + bbs_SIZEOF16( ptrA->widthE )
    160 		  + bbs_Int32Arr_memSize( cpA, &ptrA->arrE );
    161 }
    162 
    163 /* ------------------------------------------------------------------------- */
    164 
    165 uint32 bts_Int32Mat_memWrite( struct bbs_Context* cpA,
    166 							  const struct bts_Int32Mat* ptrA,
    167 							  uint16* memPtrA )
    168 {
    169 	uint32 memSizeL = bts_Int32Mat_memSize( cpA, ptrA );
    170 	memPtrA += bbs_memWrite32( &memSizeL, memPtrA );
    171 	memPtrA += bbs_memWriteUInt32( bts_INT32MAT_VERSION, memPtrA );
    172 	memPtrA += bbs_memWrite32( &ptrA->widthE, memPtrA );
    173 	memPtrA += bbs_Int32Arr_memWrite( cpA, &ptrA->arrE, memPtrA );
    174 	return memSizeL;
    175 }
    176 
    177 /* ------------------------------------------------------------------------- */
    178 
    179 uint32 bts_Int32Mat_memRead( struct bbs_Context* cpA,
    180 							 struct bts_Int32Mat* ptrA,
    181 							 const uint16* memPtrA,
    182 				             struct bbs_MemSeg* mspA )
    183 {
    184 	uint32 memSizeL, versionL;
    185 	if( bbs_Context_error( cpA ) ) return 0;
    186 	memPtrA += bbs_memRead32( &memSizeL, memPtrA );
    187 	memPtrA += bbs_memReadVersion32( cpA, &versionL, bts_INT32MAT_VERSION, memPtrA );
    188 	memPtrA += bbs_memRead32( &ptrA->widthE, memPtrA );
    189 	memPtrA += bbs_Int32Arr_memRead( cpA, &ptrA->arrE, memPtrA, mspA );
    190 
    191 	if( memSizeL != bts_Int32Mat_memSize( cpA, ptrA ) )
    192 	{
    193 		bbs_ERR0( bbs_ERR_CORRUPT_DATA, "uint32 bts_Int32Mat_memRead( const struct bts_Int32Mat* ptrA, const void* memPtrA ):\n"
    194                   "size mismatch" );
    195 	}
    196 	return memSizeL;
    197 }
    198 
    199 /* ------------------------------------------------------------------------- */
    200 
    201 /* ========================================================================= */
    202 /*                                                                           */
    203 /* ---- \ghd{ exec functions } --------------------------------------------- */
    204 /*                                                                           */
    205 /* ========================================================================= */
    206 
    207 /* ------------------------------------------------------------------------- */
    208 
    209 flag bts_Int32Mat_solve( struct bbs_Context* cpA,
    210 						 const int32* matA,
    211 						 int32 matWidthA,
    212 						 const int32* inVecA,
    213 						 int32* outVecA,
    214 						 int32 bbpA,
    215 						 int32* tmpMatA,
    216 						 int32* tmpVecA )
    217 {
    218 	bbs_memcpy32( tmpMatA, matA, ( matWidthA * matWidthA ) * bbs_SIZEOF32( int32 ) );
    219 
    220 	return bts_Int32Mat_solve2( cpA,
    221 		                        tmpMatA,
    222 								matWidthA,
    223 								inVecA,
    224 								outVecA,
    225 								bbpA,
    226 								tmpVecA );
    227 }
    228 
    229 /* ------------------------------------------------------------------------- */
    230 
    231 flag bts_Int32Mat_solve2( struct bbs_Context* cpA,
    232 						  int32* matA,
    233 						  int32 matWidthA,
    234 						  const int32* inVecA,
    235 						  int32* outVecA,
    236 						  int32 bbpA,
    237 						  int32* tmpVecA )
    238 {
    239 	int32 sizeL = matWidthA;
    240 	int32 bbpL = bbpA;
    241 	int32 iL, jL, kL;
    242 	int32 iPivL;
    243 	int32 jPivL;
    244 
    245 	int32* vecL      = outVecA;
    246 	int32* matL      = matA;
    247 	int32* checkArrL = tmpVecA;
    248 
    249 	for( iL = 0; iL < sizeL; iL++ )
    250 	{
    251 		checkArrL[ iL ] = 0;
    252 	}
    253 
    254 	bbs_memcpy32( outVecA, inVecA, sizeL * bbs_SIZEOF32( int32 ) );
    255 
    256 	iPivL = 0;
    257 
    258 	for( kL = 0; kL < sizeL; kL++ )
    259 	{
    260 		/* find pivot */
    261 		int32 maxAbsL = 0;
    262 		int32* pivRowL;
    263 
    264 		int32 bbp_pivRowL, bbp_vecL, shiftL;
    265 
    266 		jPivL = -1;
    267 		for( iL = 0; iL < sizeL; iL++ )
    268 		{
    269 			if( checkArrL[ iL ] != 1 )
    270 			{
    271 				int32* rowL = matL + ( iL * sizeL );
    272 				for( jL = 0; jL < sizeL; jL++ )
    273 				{
    274 					if( checkArrL[ jL ] == 0 )
    275 					{
    276 						int32 absElemL = rowL[ jL ];
    277 						if( absElemL < 0 ) absElemL = -absElemL;
    278 						if( maxAbsL < absElemL )
    279 						{
    280 							maxAbsL = absElemL;
    281 							iPivL = iL;
    282 							jPivL = jL;
    283 						}
    284 					}
    285 					else if( checkArrL[ jL ] > 1 )
    286 					{
    287 						return FALSE;
    288 					}
    289 				}
    290 			}
    291 		}
    292 
    293 		/* successfull ? */
    294 		if( jPivL < 0 )
    295 		{
    296 			return FALSE;
    297 		}
    298 
    299 		checkArrL[ jPivL ]++;
    300 
    301 		/* exchange rows to put pivot on diagonal, if neccessary */
    302 		if( iPivL != jPivL )
    303 		{
    304 			int32* row1PtrL = matL + ( iPivL * sizeL );
    305 			int32* row2PtrL = matL + ( jPivL * sizeL );
    306 			for( jL = 0; jL < sizeL; jL++ )
    307 			{
    308 				int32 tmpL = *row1PtrL;
    309 				*row1PtrL++ = *row2PtrL;
    310 				*row2PtrL++ = tmpL;
    311 			}
    312 
    313 			{
    314 				int32 tmpL = vecL[ jPivL ];
    315 				vecL[ jPivL ] = vecL[ iPivL ];
    316 				vecL[ iPivL ] = tmpL;
    317 			}
    318 		}
    319 		/* now index jPivL specifies pivot row and maximum element */
    320 
    321 
    322 		/**	Overflow protection: only if the highest bit of the largest matrix element is set,
    323 		 *	we need to shift the whole matrix and the right side vector 1 bit to the right,
    324 		 *	to make sure there can be no overflow when the pivot row gets subtracted from the
    325 		 *	other rows.
    326 		 *	Getting that close to overflow is a rare event, so this shift will happen only
    327 		 *	occasionally, or not at all.
    328 		 */
    329 		if( maxAbsL & 1073741824 )  /*( 1 << 30 )*/
    330 		{
    331 			/* right shift matrix by 1 */
    332 			int32 iL = sizeL * sizeL;
    333 			int32* ptrL = matL;
    334 			while( iL-- )
    335 			{
    336 				*ptrL = ( *ptrL + 1 ) >> 1;
    337 				ptrL++;
    338 			}
    339 
    340 			/* right shift right side vector by 1 */
    341 			iL = sizeL;
    342 			ptrL = vecL;
    343 			while( iL-- )
    344 			{
    345 				*ptrL = ( *ptrL + 1 ) >> 1;
    346 				ptrL++;
    347 			}
    348 
    349 			/* decrement bbpL */
    350 			bbpL--;
    351 		}
    352 
    353 
    354 		/* reduce elements of pivot row to 15 bit */
    355 		pivRowL = matL + jPivL * sizeL;
    356 		bbp_pivRowL = bbpL;
    357 		bts_Int32Mat_reduceToNBits( pivRowL, sizeL, &bbp_pivRowL, 15 );
    358 
    359 		/* scale pivot row such that maximum equals 1 */
    360 		{
    361 			int32 maxL = pivRowL[ jPivL ];
    362 			int32 bbp_maxL = bbp_pivRowL;
    363 			int32 factorL = 1073741824 / maxL; /*( 1 << 30 )*/
    364 
    365 			for( jL = 0; jL < sizeL; jL++ )
    366 			{
    367 				pivRowL[ jL ] = ( pivRowL[ jL ] * factorL + ( 1 << 14 ) ) >> 15;
    368 			}
    369 			bbp_pivRowL = 15;
    370 
    371 			/* set to 1 to avoid computational errors */
    372 			pivRowL[ jPivL ] = ( int32 )1 << bbp_pivRowL;
    373 
    374 			shiftL = 30 - bts_absIntLog2( vecL[ jPivL ] );
    375 
    376 			vecL[ jPivL ] = ( vecL[ jPivL ] << shiftL ) / maxL;
    377 			bbp_vecL = bbpL + shiftL - bbp_maxL;
    378 
    379 			bbs_int32ReduceToNBits( &( vecL[ jPivL ] ), &bbp_vecL, 15 );
    380 		}
    381 
    382 		/* subtract pivot row from all other rows */
    383 		for( iL = 0; iL < sizeL; iL++ )
    384 		{
    385 			if( iL != jPivL )
    386 			{
    387 				int32* rowPtrL = matL + iL * sizeL;
    388 
    389 				int32 tmpL = *( rowPtrL + jPivL );
    390 				int32 bbp_tmpL = bbpL;
    391 				bbs_int32ReduceToNBits( &tmpL, &bbp_tmpL, 15 );
    392 
    393 				shiftL = bbp_tmpL + bbp_pivRowL - bbpL;
    394 				if( shiftL > 0 )
    395 				{
    396 					for( jL = 0; jL < sizeL; jL++ )
    397 					{
    398 						*rowPtrL++ -= ( ( ( tmpL * pivRowL[ jL ] ) >> ( shiftL - 1 ) ) + 1 ) >> 1;
    399 					}
    400 				}
    401 				else
    402 				{
    403 					for( jL = 0; jL < sizeL; jL++ )
    404 					{
    405 						*rowPtrL++ -= ( tmpL * pivRowL[ jL ] ) << -shiftL;
    406 					}
    407 				}
    408 
    409 				shiftL = bbp_tmpL + bbp_vecL - bbpL;
    410 				if( shiftL > 0 )
    411 				{
    412 					vecL[ iL ] -= ( ( ( tmpL * vecL[ jPivL ] ) >> ( shiftL - 1 ) ) + 1 ) >> 1;
    413 				}
    414 				else
    415 				{
    416 					vecL[ iL ] -= ( tmpL * vecL[ jPivL ] ) << -shiftL;
    417 				}
    418 			}
    419 		}
    420 
    421 		/* change bbp of pivot row back to bbpL */
    422 		shiftL = bbpL - bbp_pivRowL;
    423 		if( shiftL >= 0 )
    424 		{
    425 			for( jL = 0; jL < sizeL; jL++ )
    426 			{
    427 				pivRowL[ jL ] <<= shiftL;
    428 			}
    429 		}
    430 		else
    431 		{
    432 			shiftL = -shiftL;
    433 			for( jL = 0; jL < sizeL; jL++ )
    434 			{
    435 				pivRowL[ jL ] = ( ( pivRowL[ jL ] >> ( shiftL - 1 ) ) + 1 ) >> 1;
    436 			}
    437 		}
    438 
    439 		shiftL = bbpL - bbp_vecL;
    440 		if( shiftL >= 0 )
    441 		{
    442 			vecL[ jPivL ] <<= shiftL;
    443 		}
    444 		else
    445 		{
    446 			shiftL = -shiftL;
    447 			vecL[ jPivL ] = ( ( vecL[ jPivL ] >> ( shiftL - 1 ) ) + 1 ) >> 1;
    448 		}
    449 /*
    450 if( sizeL <= 5 ) bts_Int32Mat_print( matL, vecL, sizeL, bbpL );
    451 */
    452 	}	/* of kL */
    453 
    454 	/* in case bbpL has been decreased by the overflow protection, change it back now */
    455 	if( bbpA > bbpL )
    456 	{
    457 		/* find largest element of solution vector */
    458 		int32 maxL = 0;
    459 		int32 iL, shiftL;
    460 		for( iL = 0; iL < sizeL; iL++ )
    461 		{
    462 			int32 xL = vecL[ iL ];
    463 			if( xL < 0 ) xL = -xL;
    464 			if( xL > maxL ) maxL = xL;
    465 		}
    466 
    467 		/* check whether we can left shift without overflow */
    468 		shiftL = 30 - bts_absIntLog2( maxL );
    469 		if( shiftL < ( bbpA - bbpL ) )
    470 		{
    471 			/*
    472 			    bbs_WARNING1( "flag bts_Int32Mat_solve2( ... ): getting overflow when trying to "
    473 				"compute solution vector with bbp = %d. Choose smaller bbp.\n", bbpA );
    474 			*/
    475 
    476 			return FALSE;
    477 		}
    478 
    479 		/* shift left */
    480 		shiftL = bbpA - bbpL;
    481 		for( iL = 0; iL < sizeL; iL++ ) vecL[ iL ] <<= shiftL;
    482 	}
    483 
    484 	return TRUE;
    485 }
    486 
    487 /* ------------------------------------------------------------------------- */
    488 
    489 /* ========================================================================= */
    490 
    491