Home | History | Annotate | Download | only in zlib
      1 /* infback.c -- inflate using a call-back interface
      2  * Copyright (C) 1995-2009 Mark Adler
      3  * For conditions of distribution and use, see copyright notice in zlib.h
      4  */
      5 
      6 /*
      7    This code is largely copied from inflate.c.  Normally either infback.o or
      8    inflate.o would be linked into an application--not both.  The interface
      9    with inffast.c is retained so that optimized assembler-coded versions of
     10    inflate_fast() can be used with either inflate.c or infback.c.
     11  */
     12 
     13 #include "zutil.h"
     14 #include "inftrees.h"
     15 #include "inflate.h"
     16 #include "inffast.h"
     17 
     18 /* function prototypes */
     19 local void fixedtables OF((struct inflate_state FAR *state));
     20 
     21 /*
     22    strm provides memory allocation functions in zalloc and zfree, or
     23    Z_NULL to use the library memory allocation functions.
     24 
     25    windowBits is in the range 8..15, and window is a user-supplied
     26    window and output buffer that is 2**windowBits bytes.
     27  */
     28 int ZEXPORT inflateBackInit_(strm, windowBits, window, version, stream_size)
     29 z_streamp strm;
     30 int windowBits;
     31 unsigned char FAR *window;
     32 const char *version;
     33 int stream_size;
     34 {
     35     struct inflate_state FAR *state;
     36 
     37     if (version == Z_NULL || version[0] != ZLIB_VERSION[0] ||
     38         stream_size != (int)(sizeof(z_stream)))
     39         return Z_VERSION_ERROR;
     40     if (strm == Z_NULL || window == Z_NULL ||
     41         windowBits < 8 || windowBits > 15)
     42         return Z_STREAM_ERROR;
     43     strm->msg = Z_NULL;                 /* in case we return an error */
     44     if (strm->zalloc == (alloc_func)0) {
     45         strm->zalloc = zcalloc;
     46         strm->opaque = (voidpf)0;
     47     }
     48     if (strm->zfree == (free_func)0) strm->zfree = zcfree;
     49     state = (struct inflate_state FAR *)ZALLOC(strm, 1,
     50                                                sizeof(struct inflate_state));
     51     if (state == Z_NULL) return Z_MEM_ERROR;
     52     Tracev((stderr, "inflate: allocated\n"));
     53     strm->state = (struct internal_state FAR *)state;
     54     state->dmax = 32768U;
     55     state->wbits = windowBits;
     56     state->wsize = 1U << windowBits;
     57     state->window = window;
     58     state->wnext = 0;
     59     state->whave = 0;
     60     return Z_OK;
     61 }
     62 
     63 /*
     64    Return state with length and distance decoding tables and index sizes set to
     65    fixed code decoding.  Normally this returns fixed tables from inffixed.h.
     66    If BUILDFIXED is defined, then instead this routine builds the tables the
     67    first time it's called, and returns those tables the first time and
     68    thereafter.  This reduces the size of the code by about 2K bytes, in
     69    exchange for a little execution time.  However, BUILDFIXED should not be
     70    used for threaded applications, since the rewriting of the tables and virgin
     71    may not be thread-safe.
     72  */
     73 local void fixedtables(state)
     74 struct inflate_state FAR *state;
     75 {
     76 #ifdef BUILDFIXED
     77     static int virgin = 1;
     78     static code *lenfix, *distfix;
     79     static code fixed[544];
     80 
     81     /* build fixed huffman tables if first call (may not be thread safe) */
     82     if (virgin) {
     83         unsigned sym, bits;
     84         static code *next;
     85 
     86         /* literal/length table */
     87         sym = 0;
     88         while (sym < 144) state->lens[sym++] = 8;
     89         while (sym < 256) state->lens[sym++] = 9;
     90         while (sym < 280) state->lens[sym++] = 7;
     91         while (sym < 288) state->lens[sym++] = 8;
     92         next = fixed;
     93         lenfix = next;
     94         bits = 9;
     95         inflate_table(LENS, state->lens, 288, &(next), &(bits), state->work);
     96 
     97         /* distance table */
     98         sym = 0;
     99         while (sym < 32) state->lens[sym++] = 5;
    100         distfix = next;
    101         bits = 5;
    102         inflate_table(DISTS, state->lens, 32, &(next), &(bits), state->work);
    103 
    104         /* do this just once */
    105         virgin = 0;
    106     }
    107 #else /* !BUILDFIXED */
    108 #   include "inffixed.h"
    109 #endif /* BUILDFIXED */
    110     state->lencode = lenfix;
    111     state->lenbits = 9;
    112     state->distcode = distfix;
    113     state->distbits = 5;
    114 }
    115 
    116 /* Macros for inflateBack(): */
    117 
    118 /* Load returned state from inflate_fast() */
    119 #define LOAD() \
    120     do { \
    121         put = strm->next_out; \
    122         left = strm->avail_out; \
    123         next = strm->next_in; \
    124         have = strm->avail_in; \
    125         hold = state->hold; \
    126         bits = state->bits; \
    127     } while (0)
    128 
    129 /* Set state from registers for inflate_fast() */
    130 #define RESTORE() \
    131     do { \
    132         strm->next_out = put; \
    133         strm->avail_out = left; \
    134         strm->next_in = next; \
    135         strm->avail_in = have; \
    136         state->hold = hold; \
    137         state->bits = bits; \
    138     } while (0)
    139 
    140 /* Clear the input bit accumulator */
    141 #define INITBITS() \
    142     do { \
    143         hold = 0; \
    144         bits = 0; \
    145     } while (0)
    146 
    147 /* Assure that some input is available.  If input is requested, but denied,
    148    then return a Z_BUF_ERROR from inflateBack(). */
    149 #define PULL() \
    150     do { \
    151         if (have == 0) { \
    152             have = in(in_desc, &next); \
    153             if (have == 0) { \
    154                 next = Z_NULL; \
    155                 ret = Z_BUF_ERROR; \
    156                 goto inf_leave; \
    157             } \
    158         } \
    159     } while (0)
    160 
    161 /* Get a byte of input into the bit accumulator, or return from inflateBack()
    162    with an error if there is no input available. */
    163 #define PULLBYTE() \
    164     do { \
    165         PULL(); \
    166         have--; \
    167         hold += (unsigned long)(*next++) << bits; \
    168         bits += 8; \
    169     } while (0)
    170 
    171 /* Assure that there are at least n bits in the bit accumulator.  If there is
    172    not enough available input to do that, then return from inflateBack() with
    173    an error. */
    174 #define NEEDBITS(n) \
    175     do { \
    176         while (bits < (unsigned)(n)) \
    177             PULLBYTE(); \
    178     } while (0)
    179 
    180 /* Return the low n bits of the bit accumulator (n < 16) */
    181 #define BITS(n) \
    182     ((unsigned)hold & ((1U << (n)) - 1))
    183 
    184 /* Remove n bits from the bit accumulator */
    185 #define DROPBITS(n) \
    186     do { \
    187         hold >>= (n); \
    188         bits -= (unsigned)(n); \
    189     } while (0)
    190 
    191 /* Remove zero to seven bits as needed to go to a byte boundary */
    192 #define BYTEBITS() \
    193     do { \
    194         hold >>= bits & 7; \
    195         bits -= bits & 7; \
    196     } while (0)
    197 
    198 /* Assure that some output space is available, by writing out the window
    199    if it's full.  If the write fails, return from inflateBack() with a
    200    Z_BUF_ERROR. */
    201 #define ROOM() \
    202     do { \
    203         if (left == 0) { \
    204             put = state->window; \
    205             left = state->wsize; \
    206             state->whave = left; \
    207             if (out(out_desc, put, left)) { \
    208                 ret = Z_BUF_ERROR; \
    209                 goto inf_leave; \
    210             } \
    211         } \
    212     } while (0)
    213 
    214 /*
    215    strm provides the memory allocation functions and window buffer on input,
    216    and provides information on the unused input on return.  For Z_DATA_ERROR
    217    returns, strm will also provide an error message.
    218 
    219    in() and out() are the call-back input and output functions.  When
    220    inflateBack() needs more input, it calls in().  When inflateBack() has
    221    filled the window with output, or when it completes with data in the
    222    window, it calls out() to write out the data.  The application must not
    223    change the provided input until in() is called again or inflateBack()
    224    returns.  The application must not change the window/output buffer until
    225    inflateBack() returns.
    226 
    227    in() and out() are called with a descriptor parameter provided in the
    228    inflateBack() call.  This parameter can be a structure that provides the
    229    information required to do the read or write, as well as accumulated
    230    information on the input and output such as totals and check values.
    231 
    232    in() should return zero on failure.  out() should return non-zero on
    233    failure.  If either in() or out() fails, than inflateBack() returns a
    234    Z_BUF_ERROR.  strm->next_in can be checked for Z_NULL to see whether it
    235    was in() or out() that caused in the error.  Otherwise,  inflateBack()
    236    returns Z_STREAM_END on success, Z_DATA_ERROR for an deflate format
    237    error, or Z_MEM_ERROR if it could not allocate memory for the state.
    238    inflateBack() can also return Z_STREAM_ERROR if the input parameters
    239    are not correct, i.e. strm is Z_NULL or the state was not initialized.
    240  */
    241 int ZEXPORT inflateBack(strm, in, in_desc, out, out_desc)
    242 z_streamp strm;
    243 in_func in;
    244 void FAR *in_desc;
    245 out_func out;
    246 void FAR *out_desc;
    247 {
    248     struct inflate_state FAR *state;
    249     unsigned char FAR *next;    /* next input */
    250     unsigned char FAR *put;     /* next output */
    251     unsigned have, left;        /* available input and output */
    252     unsigned long hold;         /* bit buffer */
    253     unsigned bits;              /* bits in bit buffer */
    254     unsigned copy;              /* number of stored or match bytes to copy */
    255     unsigned char FAR *from;    /* where to copy match bytes from */
    256     code here;                  /* current decoding table entry */
    257     code last;                  /* parent table entry */
    258     unsigned len;               /* length to copy for repeats, bits to drop */
    259     int ret;                    /* return code */
    260     static const unsigned short order[19] = /* permutation of code lengths */
    261         {16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15};
    262 
    263     /* Check that the strm exists and that the state was initialized */
    264     if (strm == Z_NULL || strm->state == Z_NULL)
    265         return Z_STREAM_ERROR;
    266     state = (struct inflate_state FAR *)strm->state;
    267 
    268     /* Reset the state */
    269     strm->msg = Z_NULL;
    270     state->mode = TYPE;
    271     state->last = 0;
    272     state->whave = 0;
    273     next = strm->next_in;
    274     have = next != Z_NULL ? strm->avail_in : 0;
    275     hold = 0;
    276     bits = 0;
    277     put = state->window;
    278     left = state->wsize;
    279 
    280     /* Inflate until end of block marked as last */
    281     for (;;)
    282         switch (state->mode) {
    283         case TYPE:
    284             /* determine and dispatch block type */
    285             if (state->last) {
    286                 BYTEBITS();
    287                 state->mode = DONE;
    288                 break;
    289             }
    290             NEEDBITS(3);
    291             state->last = BITS(1);
    292             DROPBITS(1);
    293             switch (BITS(2)) {
    294             case 0:                             /* stored block */
    295                 Tracev((stderr, "inflate:     stored block%s\n",
    296                         state->last ? " (last)" : ""));
    297                 state->mode = STORED;
    298                 break;
    299             case 1:                             /* fixed block */
    300                 fixedtables(state);
    301                 Tracev((stderr, "inflate:     fixed codes block%s\n",
    302                         state->last ? " (last)" : ""));
    303                 state->mode = LEN;              /* decode codes */
    304                 break;
    305             case 2:                             /* dynamic block */
    306                 Tracev((stderr, "inflate:     dynamic codes block%s\n",
    307                         state->last ? " (last)" : ""));
    308                 state->mode = TABLE;
    309                 break;
    310             case 3:
    311                 strm->msg = (char *)"invalid block type";
    312                 state->mode = BAD;
    313             }
    314             DROPBITS(2);
    315             break;
    316 
    317         case STORED:
    318             /* get and verify stored block length */
    319             BYTEBITS();                         /* go to byte boundary */
    320             NEEDBITS(32);
    321             if ((hold & 0xffff) != ((hold >> 16) ^ 0xffff)) {
    322                 strm->msg = (char *)"invalid stored block lengths";
    323                 state->mode = BAD;
    324                 break;
    325             }
    326             state->length = (unsigned)hold & 0xffff;
    327             Tracev((stderr, "inflate:       stored length %u\n",
    328                     state->length));
    329             INITBITS();
    330 
    331             /* copy stored block from input to output */
    332             while (state->length != 0) {
    333                 copy = state->length;
    334                 PULL();
    335                 ROOM();
    336                 if (copy > have) copy = have;
    337                 if (copy > left) copy = left;
    338                 zmemcpy(put, next, copy);
    339                 have -= copy;
    340                 next += copy;
    341                 left -= copy;
    342                 put += copy;
    343                 state->length -= copy;
    344             }
    345             Tracev((stderr, "inflate:       stored end\n"));
    346             state->mode = TYPE;
    347             break;
    348 
    349         case TABLE:
    350             /* get dynamic table entries descriptor */
    351             NEEDBITS(14);
    352             state->nlen = BITS(5) + 257;
    353             DROPBITS(5);
    354             state->ndist = BITS(5) + 1;
    355             DROPBITS(5);
    356             state->ncode = BITS(4) + 4;
    357             DROPBITS(4);
    358 #ifndef PKZIP_BUG_WORKAROUND
    359             if (state->nlen > 286 || state->ndist > 30) {
    360                 strm->msg = (char *)"too many length or distance symbols";
    361                 state->mode = BAD;
    362                 break;
    363             }
    364 #endif
    365             Tracev((stderr, "inflate:       table sizes ok\n"));
    366 
    367             /* get code length code lengths (not a typo) */
    368             state->have = 0;
    369             while (state->have < state->ncode) {
    370                 NEEDBITS(3);
    371                 state->lens[order[state->have++]] = (unsigned short)BITS(3);
    372                 DROPBITS(3);
    373             }
    374             while (state->have < 19)
    375                 state->lens[order[state->have++]] = 0;
    376             state->next = state->codes;
    377             state->lencode = (code const FAR *)(state->next);
    378             state->lenbits = 7;
    379             ret = inflate_table(CODES, state->lens, 19, &(state->next),
    380                                 &(state->lenbits), state->work);
    381             if (ret) {
    382                 strm->msg = (char *)"invalid code lengths set";
    383                 state->mode = BAD;
    384                 break;
    385             }
    386             Tracev((stderr, "inflate:       code lengths ok\n"));
    387 
    388             /* get length and distance code code lengths */
    389             state->have = 0;
    390             while (state->have < state->nlen + state->ndist) {
    391                 for (;;) {
    392                     here = state->lencode[BITS(state->lenbits)];
    393                     if ((unsigned)(here.bits) <= bits) break;
    394                     PULLBYTE();
    395                 }
    396                 if (here.val < 16) {
    397                     NEEDBITS(here.bits);
    398                     DROPBITS(here.bits);
    399                     state->lens[state->have++] = here.val;
    400                 }
    401                 else {
    402                     if (here.val == 16) {
    403                         NEEDBITS(here.bits + 2);
    404                         DROPBITS(here.bits);
    405                         if (state->have == 0) {
    406                             strm->msg = (char *)"invalid bit length repeat";
    407                             state->mode = BAD;
    408                             break;
    409                         }
    410                         len = (unsigned)(state->lens[state->have - 1]);
    411                         copy = 3 + BITS(2);
    412                         DROPBITS(2);
    413                     }
    414                     else if (here.val == 17) {
    415                         NEEDBITS(here.bits + 3);
    416                         DROPBITS(here.bits);
    417                         len = 0;
    418                         copy = 3 + BITS(3);
    419                         DROPBITS(3);
    420                     }
    421                     else {
    422                         NEEDBITS(here.bits + 7);
    423                         DROPBITS(here.bits);
    424                         len = 0;
    425                         copy = 11 + BITS(7);
    426                         DROPBITS(7);
    427                     }
    428                     if (state->have + copy > state->nlen + state->ndist) {
    429                         strm->msg = (char *)"invalid bit length repeat";
    430                         state->mode = BAD;
    431                         break;
    432                     }
    433                     while (copy--)
    434                         state->lens[state->have++] = (unsigned short)len;
    435                 }
    436             }
    437 
    438             /* handle error breaks in while */
    439             if (state->mode == BAD) break;
    440 
    441             /* check for end-of-block code (better have one) */
    442             if (state->lens[256] == 0) {
    443                 strm->msg = (char *)"invalid code -- missing end-of-block";
    444                 state->mode = BAD;
    445                 break;
    446             }
    447 
    448             /* build code tables -- note: do not change the lenbits or distbits
    449                values here (9 and 6) without reading the comments in inftrees.h
    450                concerning the ENOUGH constants, which depend on those values */
    451             state->next = state->codes;
    452             state->lencode = (code const FAR *)(state->next);
    453             state->lenbits = 9;
    454             ret = inflate_table(LENS, state->lens, state->nlen, &(state->next),
    455                                 &(state->lenbits), state->work);
    456             if (ret) {
    457                 strm->msg = (char *)"invalid literal/lengths set";
    458                 state->mode = BAD;
    459                 break;
    460             }
    461             state->distcode = (code const FAR *)(state->next);
    462             state->distbits = 6;
    463             ret = inflate_table(DISTS, state->lens + state->nlen, state->ndist,
    464                             &(state->next), &(state->distbits), state->work);
    465             if (ret) {
    466                 strm->msg = (char *)"invalid distances set";
    467                 state->mode = BAD;
    468                 break;
    469             }
    470             Tracev((stderr, "inflate:       codes ok\n"));
    471             state->mode = LEN;
    472 
    473         case LEN:
    474             /* use inflate_fast() if we have enough input and output */
    475             if (have >= 6 && left >= 258) {
    476                 RESTORE();
    477                 if (state->whave < state->wsize)
    478                     state->whave = state->wsize - left;
    479                 inflate_fast(strm, state->wsize);
    480                 LOAD();
    481                 break;
    482             }
    483 
    484             /* get a literal, length, or end-of-block code */
    485             for (;;) {
    486                 here = state->lencode[BITS(state->lenbits)];
    487                 if ((unsigned)(here.bits) <= bits) break;
    488                 PULLBYTE();
    489             }
    490             if (here.op && (here.op & 0xf0) == 0) {
    491                 last = here;
    492                 for (;;) {
    493                     here = state->lencode[last.val +
    494                             (BITS(last.bits + last.op) >> last.bits)];
    495                     if ((unsigned)(last.bits + here.bits) <= bits) break;
    496                     PULLBYTE();
    497                 }
    498                 DROPBITS(last.bits);
    499             }
    500             DROPBITS(here.bits);
    501             state->length = (unsigned)here.val;
    502 
    503             /* process literal */
    504             if (here.op == 0) {
    505                 Tracevv((stderr, here.val >= 0x20 && here.val < 0x7f ?
    506                         "inflate:         literal '%c'\n" :
    507                         "inflate:         literal 0x%02x\n", here.val));
    508                 ROOM();
    509                 *put++ = (unsigned char)(state->length);
    510                 left--;
    511                 state->mode = LEN;
    512                 break;
    513             }
    514 
    515             /* process end of block */
    516             if (here.op & 32) {
    517                 Tracevv((stderr, "inflate:         end of block\n"));
    518                 state->mode = TYPE;
    519                 break;
    520             }
    521 
    522             /* invalid code */
    523             if (here.op & 64) {
    524                 strm->msg = (char *)"invalid literal/length code";
    525                 state->mode = BAD;
    526                 break;
    527             }
    528 
    529             /* length code -- get extra bits, if any */
    530             state->extra = (unsigned)(here.op) & 15;
    531             if (state->extra != 0) {
    532                 NEEDBITS(state->extra);
    533                 state->length += BITS(state->extra);
    534                 DROPBITS(state->extra);
    535             }
    536             Tracevv((stderr, "inflate:         length %u\n", state->length));
    537 
    538             /* get distance code */
    539             for (;;) {
    540                 here = state->distcode[BITS(state->distbits)];
    541                 if ((unsigned)(here.bits) <= bits) break;
    542                 PULLBYTE();
    543             }
    544             if ((here.op & 0xf0) == 0) {
    545                 last = here;
    546                 for (;;) {
    547                     here = state->distcode[last.val +
    548                             (BITS(last.bits + last.op) >> last.bits)];
    549                     if ((unsigned)(last.bits + here.bits) <= bits) break;
    550                     PULLBYTE();
    551                 }
    552                 DROPBITS(last.bits);
    553             }
    554             DROPBITS(here.bits);
    555             if (here.op & 64) {
    556                 strm->msg = (char *)"invalid distance code";
    557                 state->mode = BAD;
    558                 break;
    559             }
    560             state->offset = (unsigned)here.val;
    561 
    562             /* get distance extra bits, if any */
    563             state->extra = (unsigned)(here.op) & 15;
    564             if (state->extra != 0) {
    565                 NEEDBITS(state->extra);
    566                 state->offset += BITS(state->extra);
    567                 DROPBITS(state->extra);
    568             }
    569             if (state->offset > state->wsize - (state->whave < state->wsize ?
    570                                                 left : 0)) {
    571                 strm->msg = (char *)"invalid distance too far back";
    572                 state->mode = BAD;
    573                 break;
    574             }
    575             Tracevv((stderr, "inflate:         distance %u\n", state->offset));
    576 
    577             /* copy match from window to output */
    578             do {
    579                 ROOM();
    580                 copy = state->wsize - state->offset;
    581                 if (copy < left) {
    582                     from = put + copy;
    583                     copy = left - copy;
    584                 }
    585                 else {
    586                     from = put - state->offset;
    587                     copy = left;
    588                 }
    589                 if (copy > state->length) copy = state->length;
    590                 state->length -= copy;
    591                 left -= copy;
    592                 do {
    593                     *put++ = *from++;
    594                 } while (--copy);
    595             } while (state->length != 0);
    596             break;
    597 
    598         case DONE:
    599             /* inflate stream terminated properly -- write leftover output */
    600             ret = Z_STREAM_END;
    601             if (left < state->wsize) {
    602                 if (out(out_desc, state->window, state->wsize - left))
    603                     ret = Z_BUF_ERROR;
    604             }
    605             goto inf_leave;
    606 
    607         case BAD:
    608             ret = Z_DATA_ERROR;
    609             goto inf_leave;
    610 
    611         default:                /* can't happen, but makes compilers happy */
    612             ret = Z_STREAM_ERROR;
    613             goto inf_leave;
    614         }
    615 
    616     /* Return unused input */
    617   inf_leave:
    618     strm->next_in = next;
    619     strm->avail_in = have;
    620     return ret;
    621 }
    622 
    623 int ZEXPORT inflateBackEnd(strm)
    624 z_streamp strm;
    625 {
    626     if (strm == Z_NULL || strm->state == Z_NULL || strm->zfree == (free_func)0)
    627         return Z_STREAM_ERROR;
    628     ZFREE(strm, strm->state);
    629     strm->state = Z_NULL;
    630     Tracev((stderr, "inflate: end\n"));
    631     return Z_OK;
    632 }
    633