1 /* 2 * LZMAEncoder 3 * 4 * Authors: Lasse Collin <lasse.collin (at) tukaani.org> 5 * Igor Pavlov <http://7-zip.org/> 6 * 7 * This file has been put into the public domain. 8 * You can do whatever you want with this file. 9 */ 10 11 package org.tukaani.xz.lzma; 12 13 import org.tukaani.xz.lz.LZEncoder; 14 import org.tukaani.xz.lz.Matches; 15 import org.tukaani.xz.rangecoder.RangeEncoder; 16 17 public abstract class LZMAEncoder extends LZMACoder { 18 public static final int MODE_FAST = 1; 19 public static final int MODE_NORMAL = 2; 20 21 /** 22 * LZMA2 chunk is considered full when its uncompressed size exceeds 23 * <code>LZMA2_UNCOMPRESSED_LIMIT</code>. 24 * <p> 25 * A compressed LZMA2 chunk can hold 2 MiB of uncompressed data. 26 * A single LZMA symbol may indicate up to MATCH_LEN_MAX bytes 27 * of data, so the LZMA2 chunk is considered full when there is 28 * less space than MATCH_LEN_MAX bytes. 29 */ 30 private static final int LZMA2_UNCOMPRESSED_LIMIT 31 = (2 << 20) - MATCH_LEN_MAX; 32 33 /** 34 * LZMA2 chunk is considered full when its compressed size exceeds 35 * <code>LZMA2_COMPRESSED_LIMIT</code>. 36 * <p> 37 * The maximum compressed size of a LZMA2 chunk is 64 KiB. 38 * A single LZMA symbol might use 20 bytes of space even though 39 * it usually takes just one byte or so. Two more bytes are needed 40 * for LZMA2 uncompressed chunks (see LZMA2OutputStream.writeChunk). 41 * Leave a little safety margin and use 26 bytes. 42 */ 43 private static final int LZMA2_COMPRESSED_LIMIT = (64 << 10) - 26; 44 45 private static final int DIST_PRICE_UPDATE_INTERVAL = FULL_DISTANCES; 46 private static final int ALIGN_PRICE_UPDATE_INTERVAL = ALIGN_SIZE; 47 48 private final RangeEncoder rc; 49 final LZEncoder lz; 50 final LiteralEncoder literalEncoder; 51 final LengthEncoder matchLenEncoder; 52 final LengthEncoder repLenEncoder; 53 final int niceLen; 54 55 private int distPriceCount = 0; 56 private int alignPriceCount = 0; 57 58 private final int distSlotPricesSize; 59 private final int[][] distSlotPrices; 60 private final int[][] fullDistPrices 61 = new int[DIST_STATES][FULL_DISTANCES]; 62 private final int[] alignPrices = new int[ALIGN_SIZE]; 63 64 int back = 0; 65 int readAhead = -1; 66 private int uncompressedSize = 0; 67 68 public static int getMemoryUsage(int mode, int dictSize, 69 int extraSizeBefore, int mf) { 70 int m = 80; 71 72 switch (mode) { 73 case MODE_FAST: 74 m += LZMAEncoderFast.getMemoryUsage( 75 dictSize, extraSizeBefore, mf); 76 break; 77 78 case MODE_NORMAL: 79 m += LZMAEncoderNormal.getMemoryUsage( 80 dictSize, extraSizeBefore, mf); 81 break; 82 83 default: 84 throw new IllegalArgumentException(); 85 } 86 87 return m; 88 } 89 90 public static LZMAEncoder getInstance( 91 RangeEncoder rc, int lc, int lp, int pb, int mode, 92 int dictSize, int extraSizeBefore, 93 int niceLen, int mf, int depthLimit) { 94 switch (mode) { 95 case MODE_FAST: 96 return new LZMAEncoderFast(rc, lc, lp, pb, 97 dictSize, extraSizeBefore, 98 niceLen, mf, depthLimit); 99 100 case MODE_NORMAL: 101 return new LZMAEncoderNormal(rc, lc, lp, pb, 102 dictSize, extraSizeBefore, 103 niceLen, mf, depthLimit); 104 } 105 106 throw new IllegalArgumentException(); 107 } 108 109 /** 110 * Gets an integer [0, 63] matching the highest two bits of an integer. 111 * This is like bit scan reverse (BSR) on x86 except that this also 112 * cares about the second highest bit. 113 */ 114 public static int getDistSlot(int dist) { 115 if (dist <= DIST_MODEL_START) 116 return dist; 117 118 int n = dist; 119 int i = 31; 120 121 if ((n & 0xFFFF0000) == 0) { 122 n <<= 16; 123 i = 15; 124 } 125 126 if ((n & 0xFF000000) == 0) { 127 n <<= 8; 128 i -= 8; 129 } 130 131 if ((n & 0xF0000000) == 0) { 132 n <<= 4; 133 i -= 4; 134 } 135 136 if ((n & 0xC0000000) == 0) { 137 n <<= 2; 138 i -= 2; 139 } 140 141 if ((n & 0x80000000) == 0) 142 --i; 143 144 return (i << 1) + ((dist >>> (i - 1)) & 1); 145 } 146 147 /** 148 * Gets the next LZMA symbol. 149 * <p> 150 * There are three types of symbols: literal (a single byte), 151 * repeated match, and normal match. The symbol is indicated 152 * by the return value and by the variable <code>back</code>. 153 * <p> 154 * Literal: <code>back == -1</code> and return value is <code>1</code>. 155 * The literal itself needs to be read from <code>lz</code> separately. 156 * <p> 157 * Repeated match: <code>back</code> is in the range [0, 3] and 158 * the return value is the length of the repeated match. 159 * <p> 160 * Normal match: <code>back - REPS<code> (<code>back - 4</code>) 161 * is the distance of the match and the return value is the length 162 * of the match. 163 */ 164 abstract int getNextSymbol(); 165 166 LZMAEncoder(RangeEncoder rc, LZEncoder lz, 167 int lc, int lp, int pb, int dictSize, int niceLen) { 168 super(pb); 169 this.rc = rc; 170 this.lz = lz; 171 this.niceLen = niceLen; 172 173 literalEncoder = new LiteralEncoder(lc, lp); 174 matchLenEncoder = new LengthEncoder(pb, niceLen); 175 repLenEncoder = new LengthEncoder(pb, niceLen); 176 177 distSlotPricesSize = getDistSlot(dictSize - 1) + 1; 178 distSlotPrices = new int[DIST_STATES][distSlotPricesSize]; 179 180 reset(); 181 } 182 183 public LZEncoder getLZEncoder() { 184 return lz; 185 } 186 187 public void reset() { 188 super.reset(); 189 literalEncoder.reset(); 190 matchLenEncoder.reset(); 191 repLenEncoder.reset(); 192 distPriceCount = 0; 193 alignPriceCount = 0; 194 195 uncompressedSize += readAhead + 1; 196 readAhead = -1; 197 } 198 199 public int getUncompressedSize() { 200 return uncompressedSize; 201 } 202 203 public void resetUncompressedSize() { 204 uncompressedSize = 0; 205 } 206 207 /** 208 * Compresses for LZMA2. 209 * 210 * @return true if the LZMA2 chunk became full, false otherwise 211 */ 212 public boolean encodeForLZMA2() { 213 if (!lz.isStarted() && !encodeInit()) 214 return false; 215 216 while (uncompressedSize <= LZMA2_UNCOMPRESSED_LIMIT 217 && rc.getPendingSize() <= LZMA2_COMPRESSED_LIMIT) 218 if (!encodeSymbol()) 219 return false; 220 221 return true; 222 } 223 224 private boolean encodeInit() { 225 assert readAhead == -1; 226 if (!lz.hasEnoughData(0)) 227 return false; 228 229 // The first symbol must be a literal unless using 230 // a preset dictionary. This code isn't run if using 231 // a preset dictionary. 232 skip(1); 233 rc.encodeBit(isMatch[state.get()], 0, 0); 234 literalEncoder.encodeInit(); 235 236 --readAhead; 237 assert readAhead == -1; 238 239 ++uncompressedSize; 240 assert uncompressedSize == 1; 241 242 return true; 243 } 244 245 private boolean encodeSymbol() { 246 if (!lz.hasEnoughData(readAhead + 1)) 247 return false; 248 249 int len = getNextSymbol(); 250 251 assert readAhead >= 0; 252 int posState = (lz.getPos() - readAhead) & posMask; 253 254 if (back == -1) { 255 // Literal i.e. eight-bit byte 256 assert len == 1; 257 rc.encodeBit(isMatch[state.get()], posState, 0); 258 literalEncoder.encode(); 259 } else { 260 // Some type of match 261 rc.encodeBit(isMatch[state.get()], posState, 1); 262 if (back < REPS) { 263 // Repeated match i.e. the same distance 264 // has been used earlier. 265 assert lz.getMatchLen(-readAhead, reps[back], len) == len; 266 rc.encodeBit(isRep, state.get(), 1); 267 encodeRepMatch(back, len, posState); 268 } else { 269 // Normal match 270 assert lz.getMatchLen(-readAhead, back - REPS, len) == len; 271 rc.encodeBit(isRep, state.get(), 0); 272 encodeMatch(back - REPS, len, posState); 273 } 274 } 275 276 readAhead -= len; 277 uncompressedSize += len; 278 279 return true; 280 } 281 282 private void encodeMatch(int dist, int len, int posState) { 283 state.updateMatch(); 284 matchLenEncoder.encode(len, posState); 285 286 int distSlot = getDistSlot(dist); 287 rc.encodeBitTree(distSlots[getDistState(len)], distSlot); 288 289 if (distSlot >= DIST_MODEL_START) { 290 int footerBits = (distSlot >>> 1) - 1; 291 int base = (2 | (distSlot & 1)) << footerBits; 292 int distReduced = dist - base; 293 294 if (distSlot < DIST_MODEL_END) { 295 rc.encodeReverseBitTree( 296 distSpecial[distSlot - DIST_MODEL_START], 297 distReduced); 298 } else { 299 rc.encodeDirectBits(distReduced >>> ALIGN_BITS, 300 footerBits - ALIGN_BITS); 301 rc.encodeReverseBitTree(distAlign, distReduced & ALIGN_MASK); 302 --alignPriceCount; 303 } 304 } 305 306 reps[3] = reps[2]; 307 reps[2] = reps[1]; 308 reps[1] = reps[0]; 309 reps[0] = dist; 310 311 --distPriceCount; 312 } 313 314 private void encodeRepMatch(int rep, int len, int posState) { 315 if (rep == 0) { 316 rc.encodeBit(isRep0, state.get(), 0); 317 rc.encodeBit(isRep0Long[state.get()], posState, len == 1 ? 0 : 1); 318 } else { 319 int dist = reps[rep]; 320 rc.encodeBit(isRep0, state.get(), 1); 321 322 if (rep == 1) { 323 rc.encodeBit(isRep1, state.get(), 0); 324 } else { 325 rc.encodeBit(isRep1, state.get(), 1); 326 rc.encodeBit(isRep2, state.get(), rep - 2); 327 328 if (rep == 3) 329 reps[3] = reps[2]; 330 331 reps[2] = reps[1]; 332 } 333 334 reps[1] = reps[0]; 335 reps[0] = dist; 336 } 337 338 if (len == 1) { 339 state.updateShortRep(); 340 } else { 341 repLenEncoder.encode(len, posState); 342 state.updateLongRep(); 343 } 344 } 345 346 Matches getMatches() { 347 ++readAhead; 348 Matches matches = lz.getMatches(); 349 assert lz.verifyMatches(matches); 350 return matches; 351 } 352 353 void skip(int len) { 354 readAhead += len; 355 lz.skip(len); 356 } 357 358 int getAnyMatchPrice(State state, int posState) { 359 return RangeEncoder.getBitPrice(isMatch[state.get()][posState], 1); 360 } 361 362 int getNormalMatchPrice(int anyMatchPrice, State state) { 363 return anyMatchPrice 364 + RangeEncoder.getBitPrice(isRep[state.get()], 0); 365 } 366 367 int getAnyRepPrice(int anyMatchPrice, State state) { 368 return anyMatchPrice 369 + RangeEncoder.getBitPrice(isRep[state.get()], 1); 370 } 371 372 int getShortRepPrice(int anyRepPrice, State state, int posState) { 373 return anyRepPrice 374 + RangeEncoder.getBitPrice(isRep0[state.get()], 0) 375 + RangeEncoder.getBitPrice(isRep0Long[state.get()][posState], 376 0); 377 } 378 379 int getLongRepPrice(int anyRepPrice, int rep, State state, int posState) { 380 int price = anyRepPrice; 381 382 if (rep == 0) { 383 price += RangeEncoder.getBitPrice(isRep0[state.get()], 0) 384 + RangeEncoder.getBitPrice( 385 isRep0Long[state.get()][posState], 1); 386 } else { 387 price += RangeEncoder.getBitPrice(isRep0[state.get()], 1); 388 389 if (rep == 1) 390 price += RangeEncoder.getBitPrice(isRep1[state.get()], 0); 391 else 392 price += RangeEncoder.getBitPrice(isRep1[state.get()], 1) 393 + RangeEncoder.getBitPrice(isRep2[state.get()], 394 rep - 2); 395 } 396 397 return price; 398 } 399 400 int getLongRepAndLenPrice(int rep, int len, State state, int posState) { 401 int anyMatchPrice = getAnyMatchPrice(state, posState); 402 int anyRepPrice = getAnyRepPrice(anyMatchPrice, state); 403 int longRepPrice = getLongRepPrice(anyRepPrice, rep, state, posState); 404 return longRepPrice + repLenEncoder.getPrice(len, posState); 405 } 406 407 int getMatchAndLenPrice(int normalMatchPrice, 408 int dist, int len, int posState) { 409 int price = normalMatchPrice 410 + matchLenEncoder.getPrice(len, posState); 411 int distState = getDistState(len); 412 413 if (dist < FULL_DISTANCES) { 414 price += fullDistPrices[distState][dist]; 415 } else { 416 // Note that distSlotPrices includes also 417 // the price of direct bits. 418 int distSlot = getDistSlot(dist); 419 price += distSlotPrices[distState][distSlot] 420 + alignPrices[dist & ALIGN_MASK]; 421 } 422 423 return price; 424 } 425 426 private void updateDistPrices() { 427 distPriceCount = DIST_PRICE_UPDATE_INTERVAL; 428 429 for (int distState = 0; distState < DIST_STATES; ++distState) { 430 for (int distSlot = 0; distSlot < distSlotPricesSize; ++distSlot) 431 distSlotPrices[distState][distSlot] 432 = RangeEncoder.getBitTreePrice( 433 distSlots[distState], distSlot); 434 435 for (int distSlot = DIST_MODEL_END; distSlot < distSlotPricesSize; 436 ++distSlot) { 437 int count = (distSlot >>> 1) - 1 - ALIGN_BITS; 438 distSlotPrices[distState][distSlot] 439 += RangeEncoder.getDirectBitsPrice(count); 440 } 441 442 for (int dist = 0; dist < DIST_MODEL_START; ++dist) 443 fullDistPrices[distState][dist] 444 = distSlotPrices[distState][dist]; 445 } 446 447 int dist = DIST_MODEL_START; 448 for (int distSlot = DIST_MODEL_START; distSlot < DIST_MODEL_END; 449 ++distSlot) { 450 int footerBits = (distSlot >>> 1) - 1; 451 int base = (2 | (distSlot & 1)) << footerBits; 452 453 int limit = distSpecial[distSlot - DIST_MODEL_START].length; 454 for (int i = 0; i < limit; ++i) { 455 int distReduced = dist - base; 456 int price = RangeEncoder.getReverseBitTreePrice( 457 distSpecial[distSlot - DIST_MODEL_START], 458 distReduced); 459 460 for (int distState = 0; distState < DIST_STATES; ++distState) 461 fullDistPrices[distState][dist] 462 = distSlotPrices[distState][distSlot] + price; 463 464 ++dist; 465 } 466 } 467 468 assert dist == FULL_DISTANCES; 469 } 470 471 private void updateAlignPrices() { 472 alignPriceCount = ALIGN_PRICE_UPDATE_INTERVAL; 473 474 for (int i = 0; i < ALIGN_SIZE; ++i) 475 alignPrices[i] = RangeEncoder.getReverseBitTreePrice(distAlign, 476 i); 477 } 478 479 /** 480 * Updates the lookup tables used for calculating match distance 481 * and length prices. The updating is skipped for performance reasons 482 * if the tables haven't changed much since the previous update. 483 */ 484 void updatePrices() { 485 if (distPriceCount <= 0) 486 updateDistPrices(); 487 488 if (alignPriceCount <= 0) 489 updateAlignPrices(); 490 491 matchLenEncoder.updatePrices(); 492 repLenEncoder.updatePrices(); 493 } 494 495 496 class LiteralEncoder extends LiteralCoder { 497 private final LiteralSubencoder[] subencoders; 498 499 LiteralEncoder(int lc, int lp) { 500 super(lc, lp); 501 502 subencoders = new LiteralSubencoder[1 << (lc + lp)]; 503 for (int i = 0; i < subencoders.length; ++i) 504 subencoders[i] = new LiteralSubencoder(); 505 } 506 507 void reset() { 508 for (int i = 0; i < subencoders.length; ++i) 509 subencoders[i].reset(); 510 } 511 512 void encodeInit() { 513 // When encoding the first byte of the stream, there is 514 // no previous byte in the dictionary so the encode function 515 // wouldn't work. 516 assert readAhead >= 0; 517 subencoders[0].encode(); 518 } 519 520 void encode() { 521 assert readAhead >= 0; 522 int i = getSubcoderIndex(lz.getByte(1 + readAhead), 523 lz.getPos() - readAhead); 524 subencoders[i].encode(); 525 } 526 527 int getPrice(int curByte, int matchByte, 528 int prevByte, int pos, State state) { 529 int price = RangeEncoder.getBitPrice( 530 isMatch[state.get()][pos & posMask], 0); 531 532 int i = getSubcoderIndex(prevByte, pos); 533 price += state.isLiteral() 534 ? subencoders[i].getNormalPrice(curByte) 535 : subencoders[i].getMatchedPrice(curByte, matchByte); 536 537 return price; 538 } 539 540 private class LiteralSubencoder extends LiteralSubcoder { 541 void encode() { 542 int symbol = lz.getByte(readAhead) | 0x100; 543 544 if (state.isLiteral()) { 545 int subencoderIndex; 546 int bit; 547 548 do { 549 subencoderIndex = symbol >>> 8; 550 bit = (symbol >>> 7) & 1; 551 rc.encodeBit(probs, subencoderIndex, bit); 552 symbol <<= 1; 553 } while (symbol < 0x10000); 554 555 } else { 556 int matchByte = lz.getByte(reps[0] + 1 + readAhead); 557 int offset = 0x100; 558 int subencoderIndex; 559 int matchBit; 560 int bit; 561 562 do { 563 matchByte <<= 1; 564 matchBit = matchByte & offset; 565 subencoderIndex = offset + matchBit + (symbol >>> 8); 566 bit = (symbol >>> 7) & 1; 567 rc.encodeBit(probs, subencoderIndex, bit); 568 symbol <<= 1; 569 offset &= ~(matchByte ^ symbol); 570 } while (symbol < 0x10000); 571 } 572 573 state.updateLiteral(); 574 } 575 576 int getNormalPrice(int symbol) { 577 int price = 0; 578 int subencoderIndex; 579 int bit; 580 581 symbol |= 0x100; 582 583 do { 584 subencoderIndex = symbol >>> 8; 585 bit = (symbol >>> 7) & 1; 586 price += RangeEncoder.getBitPrice(probs[subencoderIndex], 587 bit); 588 symbol <<= 1; 589 } while (symbol < (0x100 << 8)); 590 591 return price; 592 } 593 594 int getMatchedPrice(int symbol, int matchByte) { 595 int price = 0; 596 int offset = 0x100; 597 int subencoderIndex; 598 int matchBit; 599 int bit; 600 601 symbol |= 0x100; 602 603 do { 604 matchByte <<= 1; 605 matchBit = matchByte & offset; 606 subencoderIndex = offset + matchBit + (symbol >>> 8); 607 bit = (symbol >>> 7) & 1; 608 price += RangeEncoder.getBitPrice(probs[subencoderIndex], 609 bit); 610 symbol <<= 1; 611 offset &= ~(matchByte ^ symbol); 612 } while (symbol < (0x100 << 8)); 613 614 return price; 615 } 616 } 617 } 618 619 620 class LengthEncoder extends LengthCoder { 621 /** 622 * The prices are updated after at least 623 * <code>PRICE_UPDATE_INTERVAL</code> many lengths 624 * have been encoded with the same posState. 625 */ 626 private static final int PRICE_UPDATE_INTERVAL = 32; // FIXME? 627 628 private final int[] counters; 629 private final int[][] prices; 630 631 LengthEncoder(int pb, int niceLen) { 632 int posStates = 1 << pb; 633 counters = new int[posStates]; 634 635 // Always allocate at least LOW_SYMBOLS + MID_SYMBOLS because 636 // it makes updatePrices slightly simpler. The prices aren't 637 // usually needed anyway if niceLen < 18. 638 int lenSymbols = Math.max(niceLen - MATCH_LEN_MIN + 1, 639 LOW_SYMBOLS + MID_SYMBOLS); 640 prices = new int[posStates][lenSymbols]; 641 } 642 643 void reset() { 644 super.reset(); 645 646 // Reset counters to zero to force price update before 647 // the prices are needed. 648 for (int i = 0; i < counters.length; ++i) 649 counters[i] = 0; 650 } 651 652 void encode(int len, int posState) { 653 len -= MATCH_LEN_MIN; 654 655 if (len < LOW_SYMBOLS) { 656 rc.encodeBit(choice, 0, 0); 657 rc.encodeBitTree(low[posState], len); 658 } else { 659 rc.encodeBit(choice, 0, 1); 660 len -= LOW_SYMBOLS; 661 662 if (len < MID_SYMBOLS) { 663 rc.encodeBit(choice, 1, 0); 664 rc.encodeBitTree(mid[posState], len); 665 } else { 666 rc.encodeBit(choice, 1, 1); 667 rc.encodeBitTree(high, len - MID_SYMBOLS); 668 } 669 } 670 671 --counters[posState]; 672 } 673 674 int getPrice(int len, int posState) { 675 return prices[posState][len - MATCH_LEN_MIN]; 676 } 677 678 void updatePrices() { 679 for (int posState = 0; posState < counters.length; ++posState) { 680 if (counters[posState] <= 0) { 681 counters[posState] = PRICE_UPDATE_INTERVAL; 682 updatePrices(posState); 683 } 684 } 685 } 686 687 private void updatePrices(int posState) { 688 int choice0Price = RangeEncoder.getBitPrice(choice[0], 0); 689 690 int i = 0; 691 for (; i < LOW_SYMBOLS; ++i) 692 prices[posState][i] = choice0Price 693 + RangeEncoder.getBitTreePrice(low[posState], i); 694 695 choice0Price = RangeEncoder.getBitPrice(choice[0], 1); 696 int choice1Price = RangeEncoder.getBitPrice(choice[1], 0); 697 698 for (; i < LOW_SYMBOLS + MID_SYMBOLS; ++i) 699 prices[posState][i] = choice0Price + choice1Price 700 + RangeEncoder.getBitTreePrice(mid[posState], 701 i - LOW_SYMBOLS); 702 703 choice1Price = RangeEncoder.getBitPrice(choice[1], 1); 704 705 for (; i < prices[posState].length; ++i) 706 prices[posState][i] = choice0Price + choice1Price 707 + RangeEncoder.getBitTreePrice(high, i - LOW_SYMBOLS 708 - MID_SYMBOLS); 709 } 710 } 711 } 712