Home | History | Annotate | Download | only in b_TensorEm
      1 /*
      2  * Copyright (C) 2008 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 /* ---- includes ----------------------------------------------------------- */
     18 
     19 #include "b_TensorEm/CompactMat.h"
     20 #include "b_TensorEm/Functions.h"
     21 #include "b_BasicEm/Math.h"
     22 #include "b_BasicEm/Functions.h"
     23 #include "b_BasicEm/Memory.h"
     24 
     25 /* ------------------------------------------------------------------------- */
     26 
     27 /* ========================================================================= */
     28 /*                                                                           */
     29 /* ---- \ghd{ auxiliary functions } ---------------------------------------- */
     30 /*                                                                           */
     31 /* ========================================================================= */
     32 
     33 /* ------------------------------------------------------------------------- */
     34 
     35 /** Returns dot product of inVec with indexed row
     36     The result is a floating point expresstion:
     37 		upper 16 bit: signed value
     38 		lower 16 bit: signed exponent
     39  */
     40 int32 bts_CompactMat_fltDotPrdRow( struct bbs_Context* cpA,
     41 								   struct bts_CompactMat* ptrA,
     42 							       const int16* inVecA,
     43 							       uint32 inNormBitsA,
     44 							       uint32 rowA )
     45 {
     46 	const int16* rowPtrL = ptrA->cpsArrE.arrPtrE + ptrA->wordsPerRowE * rowA;
     47 
     48 	/* extract row-header info */
     49 	uint32 offsL = *rowPtrL++;
     50 	uint32 sizeL = *rowPtrL++;
     51 	int32 factorManL = *rowPtrL++;
     52 	int32 factorExpL = *rowPtrL++;
     53 	uint32 rowNormBitsL = *rowPtrL++;
     54 
     55 	/* consider possible overflow */
     56 	uint16 overflowBitsL = ( inNormBitsA + rowNormBitsL >= 31 ) ? inNormBitsA + rowNormBitsL - 31 : 0;
     57 
     58 	const int16* inPtrL = inVecA + offsL;
     59 
     60 	count_t iL;
     61 	int32 sumL = 0;
     62 
     63 	if( overflowBitsL == 0 ) /* raw dot product fits in int32 */
     64 	{
     65 		switch( ptrA->bitsPerValueE )
     66 		{
     67 			case 16:
     68 			{
     69 				for( iL = sizeL; iL > 0; iL-- ) sumL += ( ( int32 )*rowPtrL++ * ( int32 )*inPtrL++ );
     70 			}
     71 			break;
     72 
     73 			#ifndef HW_TMS320C5x /* platforms that don't have int8 must use the 'default' implementation */
     74 
     75 			case 8:
     76 			{
     77 				const uint16* dpL = ( uint16* )rowPtrL;
     78 				for( iL = sizeL; iL >= 8; iL -= 8 )
     79 				{
     80 					sumL += ( ( int8 )  dpL[ 0 ]         * ( int32 )inPtrL[ 0 ] );
     81 					sumL += ( ( int8 )( dpL[ 0 ] >>  8 ) * ( int32 )inPtrL[ 1 ] );
     82 					sumL += ( ( int8 )  dpL[ 1 ]         * ( int32 )inPtrL[ 2 ] );
     83 					sumL += ( ( int8 )( dpL[ 1 ] >>  8 ) * ( int32 )inPtrL[ 3 ] );
     84 					sumL += ( ( int8 )  dpL[ 2 ]         * ( int32 )inPtrL[ 4 ] );
     85 					sumL += ( ( int8 )( dpL[ 2 ] >>  8 ) * ( int32 )inPtrL[ 5 ] );
     86 					sumL += ( ( int8 )  dpL[ 3 ]         * ( int32 )inPtrL[ 6 ] );
     87 					sumL += ( ( int8 )( dpL[ 3 ] >>  8 ) * ( int32 )inPtrL[ 7 ] );
     88 					dpL += 4;
     89 					inPtrL += 8;
     90 				}
     91 				for( ; iL >= 2; iL -= 2 )
     92 				{
     93 					sumL += ( ( int8 )  *dpL         * ( int32 )inPtrL[ 0 ] );
     94 					sumL += ( ( int8 )( *dpL >>  8 ) * ( int32 )inPtrL[ 1 ] );
     95 					dpL++;
     96 					inPtrL += 2;
     97 				}
     98 				if( iL > 0 )
     99 				{
    100 					sumL += ( ( int8 )*dpL++ * ( int32 )inPtrL[ 0 ] );
    101 				}
    102 			}
    103 			break;
    104 
    105 			case 6:
    106 			{
    107 				const uint16* dpL = ( uint16* )rowPtrL;
    108 				for( iL = sizeL; iL >= 8; iL -= 8 )
    109 				{
    110 					int32 lSumL = 0;
    111 					lSumL += ( ( int8 )     ( dpL[ 0 ] <<  2 )                                  * ( int32 )inPtrL[ 0 ] );
    112 					lSumL += ( ( int8 ) (   ( dpL[ 0 ] >>  4 )                       & 0x00FC ) * ( int32 )inPtrL[ 1 ] );
    113 					lSumL += ( ( int8 ) ( ( ( dpL[ 0 ] >> 10 ) | ( dpL[ 1 ] << 6 ) ) & 0x00FC ) * ( int32 )inPtrL[ 2 ] );
    114 					lSumL += ( ( int8 ) (   ( dpL[ 1 ]       )                       & 0x00FC ) * ( int32 )inPtrL[ 3 ] );
    115 					lSumL += ( ( int8 ) (   ( dpL[ 1 ] >>  6 )                       & 0x00FC ) * ( int32 )inPtrL[ 4 ] );
    116 					lSumL += ( ( int8 ) ( ( ( dpL[ 1 ] >> 12 ) | ( dpL[ 2 ] << 4 ) ) & 0x00FC ) * ( int32 )inPtrL[ 5 ] );
    117 					lSumL += ( ( int8 ) (   ( dpL[ 2 ] >>  2 )                       & 0x00FC ) * ( int32 )inPtrL[ 6 ] );
    118 					lSumL += ( ( int8 ) (   ( dpL[ 2 ] >>  8 )                       & 0x00FC ) * ( int32 )inPtrL[ 7 ] );
    119 					sumL += ( lSumL >> 2 );
    120 					dpL += 3;
    121 					inPtrL += 8;
    122 				}
    123 
    124 				{
    125 					int32 lSumL = 0;
    126 					if( iL > 0 ) lSumL += ( ( int8 )     ( dpL[ 0 ] <<  2 )                                  * ( int32 )inPtrL[ 0 ] );
    127 					if( iL > 1 ) lSumL += ( ( int8 ) (   ( dpL[ 0 ] >>  4 )                       & 0x00FC ) * ( int32 )inPtrL[ 1 ] );
    128 					if( iL > 2 ) lSumL += ( ( int8 ) ( ( ( dpL[ 0 ] >> 10 ) | ( dpL[ 1 ] << 6 ) ) & 0x00FC ) * ( int32 )inPtrL[ 2 ] );
    129 					if( iL > 3 ) lSumL += ( ( int8 ) (   ( dpL[ 1 ]       )                       & 0x00FC ) * ( int32 )inPtrL[ 3 ] );
    130 					if( iL > 4 ) lSumL += ( ( int8 ) (   ( dpL[ 1 ] >>  6 )                       & 0x00FC ) * ( int32 )inPtrL[ 4 ] );
    131 					if( iL > 5 ) lSumL += ( ( int8 ) ( ( ( dpL[ 1 ] >> 12 ) | ( dpL[ 2 ] << 4 ) ) & 0x00FC ) * ( int32 )inPtrL[ 5 ] );
    132 					if( iL > 6 ) lSumL += ( ( int8 ) (   ( dpL[ 2 ] >>  2 )                       & 0x00FC ) * ( int32 )inPtrL[ 6 ] );
    133 					sumL += ( lSumL >> 2 );
    134 				}
    135 			}
    136 			break;
    137 
    138 			case 5:
    139 			{
    140 				const uint16* dpL = ( uint16* )rowPtrL;
    141 				for( iL = sizeL; iL >= 16; iL -= 16 )
    142 				{
    143 					int32 lSumL = 0;
    144 					lSumL += ( ( int8 )     ( dpL[ 0 ] <<  3 )                                  * ( int32 )inPtrL[  0 ] );
    145 					lSumL += ( ( int8 ) (   ( dpL[ 0 ] >>  2 )                       & 0x00F8 ) * ( int32 )inPtrL[  1 ] );
    146 					lSumL += ( ( int8 ) (   ( dpL[ 0 ] >>  7 )                       & 0x00F8 ) * ( int32 )inPtrL[  2 ] );
    147 					lSumL += ( ( int8 ) ( ( ( dpL[ 0 ] >> 12 ) | ( dpL[ 1 ] << 4 ) ) & 0x00F8 ) * ( int32 )inPtrL[  3 ] );
    148 					lSumL += ( ( int8 ) (   ( dpL[ 1 ] >>  1 )                       & 0x00F8 ) * ( int32 )inPtrL[  4 ] );
    149 					lSumL += ( ( int8 ) (   ( dpL[ 1 ] >>  6 )                       & 0x00F8 ) * ( int32 )inPtrL[  5 ] );
    150 					lSumL += ( ( int8 ) ( ( ( dpL[ 1 ] >> 11 ) | ( dpL[ 2 ] << 5 ) ) & 0x00F8 ) * ( int32 )inPtrL[  6 ] );
    151 					lSumL += ( ( int8 ) (   ( dpL[ 2 ]       )                       & 0x00F8 ) * ( int32 )inPtrL[  7 ] );
    152 					lSumL += ( ( int8 ) (   ( dpL[ 2 ] >>  5 )                       & 0x00F8 ) * ( int32 )inPtrL[  8 ] );
    153 					lSumL += ( ( int8 ) ( ( ( dpL[ 2 ] >> 10 ) | ( dpL[ 3 ] << 6 ) ) & 0x00F8 ) * ( int32 )inPtrL[  9 ] );
    154 					lSumL += ( ( int8 ) (   ( dpL[ 3 ] <<  1 )                       & 0x00F8 ) * ( int32 )inPtrL[ 10 ] );
    155 					lSumL += ( ( int8 ) (   ( dpL[ 3 ] >>  4 )                       & 0x00F8 ) * ( int32 )inPtrL[ 11 ] );
    156 					lSumL += ( ( int8 ) ( ( ( dpL[ 3 ] >>  9 ) | ( dpL[ 4 ] << 7 ) ) & 0x00F8 ) * ( int32 )inPtrL[ 12 ] );
    157 					lSumL += ( ( int8 ) (   ( dpL[ 4 ] <<  2 )                       & 0x00F8 ) * ( int32 )inPtrL[ 13 ] );
    158 					lSumL += ( ( int8 ) (   ( dpL[ 4 ] >>  3 )                       & 0x00F8 ) * ( int32 )inPtrL[ 14 ] );
    159 					lSumL += ( ( int8 ) (   ( dpL[ 4 ] >>  8 )                       & 0x00F8 ) * ( int32 )inPtrL[ 15 ] );
    160 					sumL += ( lSumL >> 3 );
    161 					dpL += 5;
    162 					inPtrL += 16;
    163 				}
    164 
    165 				{
    166 					int32 lSumL = 0;
    167 					if( iL >  0 ) lSumL += ( ( int8 )     ( dpL[ 0 ] <<  3 )                                  * ( int32 )inPtrL[  0 ] );
    168 					if( iL >  1 ) lSumL += ( ( int8 ) (   ( dpL[ 0 ] >>  2 )                       & 0x00F8 ) * ( int32 )inPtrL[  1 ] );
    169 					if( iL >  2 ) lSumL += ( ( int8 ) (   ( dpL[ 0 ] >>  7 )                       & 0x00F8 ) * ( int32 )inPtrL[  2 ] );
    170 					if( iL >  3 ) lSumL += ( ( int8 ) ( ( ( dpL[ 0 ] >> 12 ) | ( dpL[ 1 ] << 4 ) ) & 0x00F8 ) * ( int32 )inPtrL[  3 ] );
    171 					if( iL >  4 ) lSumL += ( ( int8 ) (   ( dpL[ 1 ] >>  1 )                       & 0x00F8 ) * ( int32 )inPtrL[  4 ] );
    172 					if( iL >  5 ) lSumL += ( ( int8 ) (   ( dpL[ 1 ] >>  6 )                       & 0x00F8 ) * ( int32 )inPtrL[  5 ] );
    173 					if( iL >  6 ) lSumL += ( ( int8 ) ( ( ( dpL[ 1 ] >> 11 ) | ( dpL[ 2 ] << 5 ) ) & 0x00F8 ) * ( int32 )inPtrL[  6 ] );
    174 					if( iL >  7 ) lSumL += ( ( int8 ) (   ( dpL[ 2 ]       )                       & 0x00F8 ) * ( int32 )inPtrL[  7 ] );
    175 					if( iL >  8 ) lSumL += ( ( int8 ) (   ( dpL[ 2 ] >>  5 )                       & 0x00F8 ) * ( int32 )inPtrL[  8 ] );
    176 					if( iL >  9 ) lSumL += ( ( int8 ) ( ( ( dpL[ 2 ] >> 10 ) | ( dpL[ 3 ] << 6 ) ) & 0x00F8 ) * ( int32 )inPtrL[  9 ] );
    177 					if( iL > 10 ) lSumL += ( ( int8 ) (   ( dpL[ 3 ] <<  1 )                       & 0x00F8 ) * ( int32 )inPtrL[ 10 ] );
    178 					if( iL > 11 ) lSumL += ( ( int8 ) (   ( dpL[ 3 ] >>  4 )                       & 0x00F8 ) * ( int32 )inPtrL[ 11 ] );
    179 					if( iL > 12 ) lSumL += ( ( int8 ) ( ( ( dpL[ 3 ] >>  9 ) | ( dpL[ 4 ] << 7 ) ) & 0x00F8 ) * ( int32 )inPtrL[ 12 ] );
    180 					if( iL > 13 ) lSumL += ( ( int8 ) (   ( dpL[ 4 ] <<  2 )                       & 0x00F8 ) * ( int32 )inPtrL[ 13 ] );
    181 					if( iL > 14 ) lSumL += ( ( int8 ) (   ( dpL[ 4 ] >>  3 )                       & 0x00F8 ) * ( int32 )inPtrL[ 14 ] );
    182 					sumL += ( lSumL >> 3 );
    183 				}
    184 			}
    185 			break;
    186 
    187 			case 4:
    188 			{
    189 				for( iL = sizeL; iL >= 4; iL -= 4 )
    190 				{
    191 					uint16 v1L = *rowPtrL++;
    192 					int32 lSumL = 0;
    193 					lSumL += ( ( int8 )( ( v1L << 4 )        ) * ( int32 )inPtrL[ 0 ] );
    194 					lSumL += ( ( int8 )( ( v1L      ) & 0xF0 ) * ( int32 )inPtrL[ 1 ] );
    195 					lSumL += ( ( int8 )( ( v1L >> 4 ) & 0xF0 ) * ( int32 )inPtrL[ 2 ] );
    196 					lSumL += ( ( int8 )( ( v1L >> 8 ) & 0xF0 ) * ( int32 )inPtrL[ 3 ] );
    197 					inPtrL += 4;
    198 					sumL += ( lSumL >> 4 );
    199 				}
    200 				{
    201 					uint16 v1L = *rowPtrL++;
    202 					int32 lSumL = 0;
    203 					if( iL-- > 0 ) lSumL += ( ( int8 )( ( v1L << 4 )        ) * ( int32 )inPtrL[ 0 ] );
    204 					if( iL-- > 0 ) lSumL += ( ( int8 )( ( v1L      ) & 0xF0 ) * ( int32 )inPtrL[ 1 ] );
    205 					if( iL-- > 0 ) lSumL += ( ( int8 )( ( v1L >> 4 ) & 0xF0 ) * ( int32 )inPtrL[ 2 ] );
    206 					sumL += ( lSumL >> 4 );
    207 				}
    208 			}
    209 			break;
    210 
    211 			#endif /*ifndef HW_TMS320C5x*/
    212 
    213 			/* The default case can process all bit sizes including those that are explicitly encoded above
    214 			 * Use the default for all bit sizes when the platform cannot handle the int8 data type (e.g. HW_TMS320C5x)
    215 			 */
    216 			default:
    217 			{
    218 				uint32 bfL = ( ( uint32 )*rowPtrL++ ) << 16;
    219 				uint32 bitsL = ptrA->bitsPerValueE;
    220 				uint16 adjL = 16 - bitsL;
    221 				uint32 mkL = ( ( 1 << bitsL ) - 1 ) << adjL;
    222 				uint32 srL = bitsL;
    223 				for( iL = 0; iL < sizeL; iL++ )
    224 				{
    225 					if( srL > 16 )
    226 					{
    227 						bfL = ( ( ( uint32 )*rowPtrL++ ) << 16 ) | ( bfL >> 16 );
    228 						srL -= 16;
    229 					}
    230 					sumL += ( ( int16 )( ( bfL >> srL ) & mkL ) * ( int32 )inPtrL[ iL ] ) >> adjL;
    231 					srL += bitsL;
    232 				}
    233 			}
    234 		}
    235 	}
    236 	else /* raw dot product does not fit in int32 */
    237 	{
    238 		int32 roundL = 1 << ( overflowBitsL - 1 );
    239 		switch( ptrA->bitsPerValueE )
    240 		{
    241 			case 16:
    242 			{
    243 				for( iL = sizeL; iL > 0; iL-- ) sumL += ( ( ( int32 )*rowPtrL++ * ( int32 )*inPtrL++ ) + roundL ) >> overflowBitsL;
    244 			}
    245 			break;
    246 
    247 			case 8:
    248 			{
    249 				for( iL = sizeL; iL >= 2; iL -= 2 )
    250 				{
    251 					uint16 v1L = *rowPtrL++;
    252 					int32 lSumL =   ( ( int8 )  v1L         * ( int32 )inPtrL[ 0 ] )
    253 						          + ( ( int8 )( v1L >>  8 ) * ( int32 )inPtrL[ 1 ] );
    254 					sumL += ( lSumL + roundL ) >> overflowBitsL;
    255 					inPtrL += 2;
    256 				}
    257 				if( iL > 0 )
    258 				{
    259 					sumL += ( ( ( int8 )*rowPtrL++ * ( int32 )inPtrL[ 0 ] ) + roundL ) >> overflowBitsL;
    260 				}
    261 			}
    262 			break;
    263 
    264 			case 4:
    265 			{
    266 				for( iL = sizeL; iL >= 4; iL -= 4 )
    267 				{
    268 					uint16 v1L = *rowPtrL++;
    269 					int32 lSumL = 0;
    270 					lSumL += ( ( int8 )( ( v1L << 4 )        ) * ( int32 )inPtrL[ 0 ] );
    271 					lSumL += ( ( int8 )( ( v1L      ) & 0xF0 ) * ( int32 )inPtrL[ 1 ] );
    272 					lSumL += ( ( int8 )( ( v1L >> 4 ) & 0xF0 ) * ( int32 )inPtrL[ 2 ] );
    273 					lSumL += ( ( int8 )( ( v1L >> 8 ) & 0xF0 ) * ( int32 )inPtrL[ 3 ] );
    274 					inPtrL += 4;
    275 					sumL += ( ( lSumL >> 4 ) + roundL ) >> overflowBitsL;
    276 				}
    277 				{
    278 					uint16 v1L = *rowPtrL++;
    279 					int32 lSumL = 0;
    280 					if( iL-- > 0 ) lSumL += ( ( int8 )( ( v1L << 4 )        ) * ( int32 )inPtrL[ 0 ] );
    281 					if( iL-- > 0 ) lSumL += ( ( int8 )( ( v1L      ) & 0xF0 ) * ( int32 )inPtrL[ 1 ] );
    282 					if( iL-- > 0 ) lSumL += ( ( int8 )( ( v1L >> 4 ) & 0xF0 ) * ( int32 )inPtrL[ 2 ] );
    283 					sumL += ( ( lSumL >> 4 ) + roundL ) >> overflowBitsL;
    284 				}
    285 			}
    286 			break;
    287 
    288 			default:
    289 			{
    290 				uint32 bfL = ( ( uint32 )*rowPtrL++ ) << 16;
    291 				uint32 bitsL = ptrA->bitsPerValueE;
    292 				uint16 adjL = 16 - bitsL;
    293 				uint32 mkL = ( ( 1 << bitsL ) - 1 ) << adjL;
    294 				uint32 srL = bitsL;
    295 				int32 lRoundL = roundL << adjL;
    296 				int32 lAdjL = overflowBitsL + adjL;
    297 				for( iL = 0; iL < sizeL; iL++ )
    298 				{
    299 					if( srL > 16 )
    300 					{
    301 						bfL = ( ( ( uint32 )*rowPtrL++ ) << 16 ) | ( bfL >> 16 );
    302 						srL -= 16;
    303 					}
    304 					sumL += ( ( int16 )( ( bfL >> srL ) & mkL ) * ( int32 )inPtrL[ iL ] + lRoundL ) >> lAdjL;
    305 					srL += bitsL;
    306 				}
    307 			}
    308 		}
    309 	}
    310 
    311 	/* compute result */
    312 	{
    313 		int32 resultManL;
    314 		int32 resultExpL;
    315 		int32 resultLogL;
    316 		bbs_mulS32( sumL, factorManL, &resultManL, &resultExpL );
    317 		resultExpL += factorExpL + overflowBitsL;
    318 		resultLogL = bbs_intLog2( resultManL > 0 ? resultManL : -resultManL );
    319 		if( resultLogL < 30 )
    320 		{
    321 			resultManL <<= 30 - resultLogL;
    322 			resultExpL  -= 30 - resultLogL;
    323 		}
    324 
    325 		resultManL = ( ( resultManL >> 15 ) + 1 ) >> 1;
    326 		resultExpL = resultExpL + 16;
    327 
    328 		return ( ( resultManL & 0x0000FFFF ) << 16 ) | ( resultExpL & 0x0000FFFF );
    329 	}
    330 }
    331 
    332 /* ------------------------------------------------------------------------- */
    333 
    334 /* ========================================================================= */
    335 /*                                                                           */
    336 /* ---- \ghd{ constructor / destructor } ----------------------------------- */
    337 /*                                                                           */
    338 /* ========================================================================= */
    339 
    340 /* ------------------------------------------------------------------------- */
    341 
    342 void bts_CompactMat_init( struct bbs_Context* cpA,
    343 					      struct bts_CompactMat* ptrA )
    344 {
    345 	ptrA->widthE = 0;
    346 	ptrA->heightE = 0;
    347 	ptrA->bitsPerValueE = 0;
    348 	ptrA->wordsPerRowE = 0;
    349 	ptrA->maxRowBitsE = 0;
    350 	bbs_Int16Arr_init( cpA, &ptrA->cpsArrE );
    351 	bbs_Int16Arr_init( cpA, &ptrA->expArrE );
    352 
    353 }
    354 
    355 /* ------------------------------------------------------------------------- */
    356 
    357 void bts_CompactMat_exit( struct bbs_Context* cpA,
    358 					    struct bts_CompactMat* ptrA )
    359 {
    360 	ptrA->widthE = 0;
    361 	ptrA->heightE = 0;
    362 	ptrA->bitsPerValueE = 0;
    363 	ptrA->wordsPerRowE = 0;
    364 	ptrA->maxRowBitsE = 0;
    365 	bbs_Int16Arr_exit( cpA, &ptrA->cpsArrE );
    366 	bbs_Int16Arr_exit( cpA, &ptrA->expArrE );
    367 }
    368 /* ------------------------------------------------------------------------- */
    369 
    370 /* ========================================================================= */
    371 /*                                                                           */
    372 /* ---- \ghd{ operators } -------------------------------------------------- */
    373 /*                                                                           */
    374 /* ========================================================================= */
    375 
    376 /* ------------------------------------------------------------------------- */
    377 
    378 /* ========================================================================= */
    379 /*                                                                           */
    380 /* ---- \ghd{ query functions } -------------------------------------------- */
    381 /*                                                                           */
    382 /* ========================================================================= */
    383 
    384 /* ------------------------------------------------------------------------- */
    385 
    386 /* ========================================================================= */
    387 /*                                                                           */
    388 /* ---- \ghd{ modify functions } ------------------------------------------- */
    389 /*                                                                           */
    390 /* ========================================================================= */
    391 
    392 /* ------------------------------------------------------------------------- */
    393 
    394 void bts_CompactMat_create( struct bbs_Context* cpA,
    395 						    struct bts_CompactMat* ptrA,
    396 						    uint32 widthA,
    397 						    uint32 heightA,
    398 						    uint32 bitsA,
    399 							uint32 maxRowSizeA,
    400 				            struct bbs_MemSeg* mspA )
    401 {
    402 	if( bbs_Context_error( cpA ) ) return;
    403 	if( bitsA < 2 || bitsA > 16 )
    404 	{
    405 		bbs_ERROR0( "bts_CompactMat_create:\nbitsA must be between 2 and 16" );
    406 		return;
    407 	}
    408 
    409 	ptrA->widthE = widthA;
    410 	ptrA->heightE = heightA;
    411 	ptrA->bitsPerValueE = bitsA;
    412 	ptrA->wordsPerRowE = 6 /*header + 1*/ + ( ( maxRowSizeA * bitsA ) / ( 8 * sizeof( short ) ) );
    413 	ptrA->maxRowBitsE = 0;
    414 	if( ( ptrA->wordsPerRowE & 1 ) != 0 ) ptrA->wordsPerRowE++;
    415 	bbs_Int16Arr_create( cpA, &ptrA->cpsArrE, heightA * ptrA->wordsPerRowE, mspA );
    416 	bbs_Int16Arr_fill( cpA, &ptrA->cpsArrE, 0 );
    417 	bbs_Int16Arr_create( cpA, &ptrA->expArrE, ptrA->heightE, mspA );
    418 	bbs_Int16Arr_fill( cpA, &ptrA->expArrE, 0 );
    419 }
    420 
    421 /* ------------------------------------------------------------------------- */
    422 
    423 void bts_CompactMat_copy( struct bbs_Context* cpA,
    424 					      struct bts_CompactMat* ptrA,
    425 						  const struct bts_CompactMat* srcPtrA )
    426 {
    427 	ptrA->widthE = srcPtrA->widthE;
    428 	ptrA->heightE = srcPtrA->heightE;
    429 	ptrA->bitsPerValueE = srcPtrA->bitsPerValueE;
    430 	ptrA->wordsPerRowE = srcPtrA->wordsPerRowE;
    431 	ptrA->maxRowBitsE = srcPtrA->maxRowBitsE;
    432 	bbs_Int16Arr_copy( cpA, &ptrA->cpsArrE, &srcPtrA->cpsArrE );
    433 	bbs_Int16Arr_size( cpA, &ptrA->expArrE, ptrA->heightE );
    434 }
    435 
    436 /* ------------------------------------------------------------------------- */
    437 
    438 /* ========================================================================= */
    439 /*                                                                           */
    440 /* ---- \ghd{ I/O } -------------------------------------------------------- */
    441 /*                                                                           */
    442 /* ========================================================================= */
    443 
    444 /* ------------------------------------------------------------------------- */
    445 
    446 uint32 bts_CompactMat_memSize( struct bbs_Context* cpA,
    447 							 const struct bts_CompactMat *ptrA )
    448 {
    449 	return  bbs_SIZEOF16( uint32 )
    450 		  + bbs_SIZEOF16( uint32 ) /* version */
    451 		  + bbs_SIZEOF16( ptrA->widthE )
    452 		  + bbs_SIZEOF16( ptrA->heightE )
    453 		  + bbs_SIZEOF16( ptrA->bitsPerValueE )
    454 		  + bbs_SIZEOF16( ptrA->wordsPerRowE )
    455 		  + bbs_SIZEOF16( ptrA->maxRowBitsE )
    456 		  + bbs_Int16Arr_memSize( cpA, &ptrA->cpsArrE );
    457 }
    458 
    459 /* ------------------------------------------------------------------------- */
    460 
    461 uint32 bts_CompactMat_memWrite( struct bbs_Context* cpA,
    462 							  const struct bts_CompactMat* ptrA,
    463 							  uint16* memPtrA )
    464 {
    465 	uint32 memSizeL = bts_CompactMat_memSize( cpA, ptrA );
    466 	memPtrA += bbs_memWrite32( &memSizeL, memPtrA );
    467 	memPtrA += bbs_memWriteUInt32( bts_COMPACT_MAT_VERSION, memPtrA );
    468 	memPtrA += bbs_memWrite32( &ptrA->widthE, memPtrA );
    469 	memPtrA += bbs_memWrite32( &ptrA->heightE, memPtrA );
    470 	memPtrA += bbs_memWrite32( &ptrA->bitsPerValueE, memPtrA );
    471 	memPtrA += bbs_memWrite32( &ptrA->wordsPerRowE, memPtrA );
    472 	memPtrA += bbs_memWrite32( &ptrA->maxRowBitsE, memPtrA );
    473 	memPtrA += bbs_Int16Arr_memWrite( cpA, &ptrA->cpsArrE, memPtrA );
    474 	return memSizeL;
    475 }
    476 
    477 /* ------------------------------------------------------------------------- */
    478 
    479 uint32 bts_CompactMat_memRead( struct bbs_Context* cpA,
    480 							 struct bts_CompactMat* ptrA,
    481 							 const uint16* memPtrA,
    482 				             struct bbs_MemSeg* mspA )
    483 {
    484 	uint32 memSizeL, versionL;
    485 	if( bbs_Context_error( cpA ) ) return 0;
    486 	memPtrA += bbs_memRead32( &memSizeL, memPtrA );
    487 	memPtrA += bbs_memReadVersion32( cpA, &versionL, bts_COMPACT_MAT_VERSION, memPtrA );
    488 	memPtrA += bbs_memRead32( &ptrA->widthE, memPtrA );
    489 	memPtrA += bbs_memRead32( &ptrA->heightE, memPtrA );
    490 	memPtrA += bbs_memRead32( &ptrA->bitsPerValueE, memPtrA );
    491 	memPtrA += bbs_memRead32( &ptrA->wordsPerRowE, memPtrA );
    492 	memPtrA += bbs_memRead32( &ptrA->maxRowBitsE, memPtrA );
    493 	memPtrA += bbs_Int16Arr_memRead( cpA, &ptrA->cpsArrE, memPtrA, mspA );
    494 
    495 	if( memSizeL != bts_CompactMat_memSize( cpA, ptrA ) )
    496 	{
    497 		bbs_ERR0( bbs_ERR_CORRUPT_DATA, "uint32 bts_CompactMat_memRead( const struct bts_CompactMat* ptrA, const void* memPtrA ):\n"
    498                   "size mismatch" );
    499 	}
    500 
    501 	bbs_Int16Arr_create( cpA, &ptrA->expArrE, ptrA->heightE, mspA );
    502 	bbs_Int16Arr_fill( cpA, &ptrA->expArrE, 0 );
    503 
    504 	return memSizeL;
    505 }
    506 
    507 /* ------------------------------------------------------------------------- */
    508 
    509 /* ========================================================================= */
    510 /*                                                                           */
    511 /* ---- \ghd{ exec functions } --------------------------------------------- */
    512 /*                                                                           */
    513 /* ========================================================================= */
    514 
    515 /* ------------------------------------------------------------------------- */
    516 
    517 void bts_CompactMat_map( struct bbs_Context* cpA,
    518 						 const struct bts_CompactMat* ptrA,
    519 						 const int16* inVecA,
    520 						 int16* outVecA,
    521 						 int16* outExpPtrA )
    522 {
    523 	uint32 inNormBitsL = bbs_intLog2( bbs_vecNorm16( inVecA, ptrA->widthE ) ) + 1;
    524 	uint32 iL;
    525 
    526 	int16* expArrL = ( ( struct bts_CompactMat* )ptrA )->expArrE.arrPtrE;
    527 	int16 maxExpL = -32767;
    528 
    529 	for( iL = 0; iL < ptrA->heightE; iL++ )
    530 	{
    531 		int32 fltL = bts_CompactMat_fltDotPrdRow( cpA, ( struct bts_CompactMat* )ptrA, inVecA, inNormBitsL, iL );
    532 		outVecA[ iL ] = fltL >> 16;
    533 		expArrL[ iL ] = fltL & 0x0000FFFF;
    534 
    535 		maxExpL = ( expArrL[ iL ] > maxExpL ) ? expArrL[ iL ] : maxExpL;
    536 	}
    537 
    538 	if( outExpPtrA != NULL ) *outExpPtrA = maxExpL;
    539 
    540 	for( iL = 0; iL < ptrA->heightE; iL++ )
    541 	{
    542 		int32 shrL = maxExpL - expArrL[ iL ];
    543 		if( shrL > 0 )
    544 		{
    545 			outVecA[ iL ] = ( ( outVecA[ iL ] >> ( shrL - 1 ) ) + 1 ) >> 1;
    546 		}
    547 	}
    548 }
    549 
    550 /* ------------------------------------------------------------------------- */
    551 
    552 /* ========================================================================= */
    553 
    554