Home | History | Annotate | Download | only in src
      1 /* infback.c -- inflate using a call-back interface
      2  * Copyright (C) 1995-2011 Mark Adler
      3  * For conditions of distribution and use, see copyright notice in zlib.h
      4  */
      5 
      6 /*
      7    This code is largely copied from inflate.c.  Normally either infback.o or
      8    inflate.o would be linked into an application--not both.  The interface
      9    with inffast.c is retained so that optimized assembler-coded versions of
     10    inflate_fast() can be used with either inflate.c or infback.c.
     11  */
     12 
     13 #include "zutil.h"
     14 #include "inftrees.h"
     15 #include "inflate.h"
     16 #include "inffast.h"
     17 
     18 /* function prototypes */
     19 local void fixedtables OF((struct inflate_state FAR *state));
     20 
     21 /*
     22    strm provides memory allocation functions in zalloc and zfree, or
     23    Z_NULL to use the library memory allocation functions.
     24 
     25    windowBits is in the range 8..15, and window is a user-supplied
     26    window and output buffer that is 2**windowBits bytes.
     27  */
     28 int ZEXPORT inflateBackInit_(strm, windowBits, window, version, stream_size)
     29 z_streamp strm;
     30 int windowBits;
     31 unsigned char FAR *window;
     32 const char *version;
     33 int stream_size;
     34 {
     35     struct inflate_state FAR *state;
     36 
     37     if (version == Z_NULL || version[0] != ZLIB_VERSION[0] ||
     38         stream_size != (int)(sizeof(z_stream)))
     39         return Z_VERSION_ERROR;
     40     if (strm == Z_NULL || window == Z_NULL ||
     41         windowBits < 8 || windowBits > 15)
     42         return Z_STREAM_ERROR;
     43     strm->msg = Z_NULL;                 /* in case we return an error */
     44     if (strm->zalloc == (alloc_func)0) {
     45 #ifdef Z_SOLO
     46         return Z_STREAM_ERROR;
     47 #else
     48         strm->zalloc = zcalloc;
     49         strm->opaque = (voidpf)0;
     50 #endif
     51     }
     52     if (strm->zfree == (free_func)0)
     53 #ifdef Z_SOLO
     54         return Z_STREAM_ERROR;
     55 #else
     56     strm->zfree = zcfree;
     57 #endif
     58     state = (struct inflate_state FAR *)ZALLOC(strm, 1,
     59                                                sizeof(struct inflate_state));
     60     if (state == Z_NULL) return Z_MEM_ERROR;
     61     Tracev((stderr, "inflate: allocated\n"));
     62     strm->state = (struct internal_state FAR *)state;
     63     state->dmax = 32768U;
     64     state->wbits = windowBits;
     65     state->wsize = 1U << windowBits;
     66     state->window = window;
     67     state->wnext = 0;
     68     state->whave = 0;
     69     return Z_OK;
     70 }
     71 
     72 /*
     73    Return state with length and distance decoding tables and index sizes set to
     74    fixed code decoding.  Normally this returns fixed tables from inffixed.h.
     75    If BUILDFIXED is defined, then instead this routine builds the tables the
     76    first time it's called, and returns those tables the first time and
     77    thereafter.  This reduces the size of the code by about 2K bytes, in
     78    exchange for a little execution time.  However, BUILDFIXED should not be
     79    used for threaded applications, since the rewriting of the tables and virgin
     80    may not be thread-safe.
     81  */
     82 local void fixedtables(state)
     83 struct inflate_state FAR *state;
     84 {
     85 #ifdef BUILDFIXED
     86     static int virgin = 1;
     87     static code *lenfix, *distfix;
     88     static code fixed[544];
     89 
     90     /* build fixed huffman tables if first call (may not be thread safe) */
     91     if (virgin) {
     92         unsigned sym, bits;
     93         static code *next;
     94 
     95         /* literal/length table */
     96         sym = 0;
     97         while (sym < 144) state->lens[sym++] = 8;
     98         while (sym < 256) state->lens[sym++] = 9;
     99         while (sym < 280) state->lens[sym++] = 7;
    100         while (sym < 288) state->lens[sym++] = 8;
    101         next = fixed;
    102         lenfix = next;
    103         bits = 9;
    104         inflate_table(LENS, state->lens, 288, &(next), &(bits), state->work);
    105 
    106         /* distance table */
    107         sym = 0;
    108         while (sym < 32) state->lens[sym++] = 5;
    109         distfix = next;
    110         bits = 5;
    111         inflate_table(DISTS, state->lens, 32, &(next), &(bits), state->work);
    112 
    113         /* do this just once */
    114         virgin = 0;
    115     }
    116 #else /* !BUILDFIXED */
    117 #   include "inffixed.h"
    118 #endif /* BUILDFIXED */
    119     state->lencode = lenfix;
    120     state->lenbits = 9;
    121     state->distcode = distfix;
    122     state->distbits = 5;
    123 }
    124 
    125 /* Macros for inflateBack(): */
    126 
    127 /* Load returned state from inflate_fast() */
    128 #define LOAD() \
    129     do { \
    130         put = strm->next_out; \
    131         left = strm->avail_out; \
    132         next = strm->next_in; \
    133         have = strm->avail_in; \
    134         hold = state->hold; \
    135         bits = state->bits; \
    136     } while (0)
    137 
    138 /* Set state from registers for inflate_fast() */
    139 #define RESTORE() \
    140     do { \
    141         strm->next_out = put; \
    142         strm->avail_out = left; \
    143         strm->next_in = next; \
    144         strm->avail_in = have; \
    145         state->hold = hold; \
    146         state->bits = bits; \
    147     } while (0)
    148 
    149 /* Clear the input bit accumulator */
    150 #define INITBITS() \
    151     do { \
    152         hold = 0; \
    153         bits = 0; \
    154     } while (0)
    155 
    156 /* Assure that some input is available.  If input is requested, but denied,
    157    then return a Z_BUF_ERROR from inflateBack(). */
    158 #define PULL() \
    159     do { \
    160         if (have == 0) { \
    161             have = in(in_desc, &next); \
    162             if (have == 0) { \
    163                 next = Z_NULL; \
    164                 ret = Z_BUF_ERROR; \
    165                 goto inf_leave; \
    166             } \
    167         } \
    168     } while (0)
    169 
    170 /* Get a byte of input into the bit accumulator, or return from inflateBack()
    171    with an error if there is no input available. */
    172 #define PULLBYTE() \
    173     do { \
    174         PULL(); \
    175         have--; \
    176         hold += (unsigned long)(*next++) << bits; \
    177         bits += 8; \
    178     } while (0)
    179 
    180 /* Assure that there are at least n bits in the bit accumulator.  If there is
    181    not enough available input to do that, then return from inflateBack() with
    182    an error. */
    183 #define NEEDBITS(n) \
    184     do { \
    185         while (bits < (unsigned)(n)) \
    186             PULLBYTE(); \
    187     } while (0)
    188 
    189 /* Return the low n bits of the bit accumulator (n < 16) */
    190 #define BITS(n) \
    191     ((unsigned)hold & ((1U << (n)) - 1))
    192 
    193 /* Remove n bits from the bit accumulator */
    194 #define DROPBITS(n) \
    195     do { \
    196         hold >>= (n); \
    197         bits -= (unsigned)(n); \
    198     } while (0)
    199 
    200 /* Remove zero to seven bits as needed to go to a byte boundary */
    201 #define BYTEBITS() \
    202     do { \
    203         hold >>= bits & 7; \
    204         bits -= bits & 7; \
    205     } while (0)
    206 
    207 /* Assure that some output space is available, by writing out the window
    208    if it's full.  If the write fails, return from inflateBack() with a
    209    Z_BUF_ERROR. */
    210 #define ROOM() \
    211     do { \
    212         if (left == 0) { \
    213             put = state->window; \
    214             left = state->wsize; \
    215             state->whave = left; \
    216             if (out(out_desc, put, left)) { \
    217                 ret = Z_BUF_ERROR; \
    218                 goto inf_leave; \
    219             } \
    220         } \
    221     } while (0)
    222 
    223 /*
    224    strm provides the memory allocation functions and window buffer on input,
    225    and provides information on the unused input on return.  For Z_DATA_ERROR
    226    returns, strm will also provide an error message.
    227 
    228    in() and out() are the call-back input and output functions.  When
    229    inflateBack() needs more input, it calls in().  When inflateBack() has
    230    filled the window with output, or when it completes with data in the
    231    window, it calls out() to write out the data.  The application must not
    232    change the provided input until in() is called again or inflateBack()
    233    returns.  The application must not change the window/output buffer until
    234    inflateBack() returns.
    235 
    236    in() and out() are called with a descriptor parameter provided in the
    237    inflateBack() call.  This parameter can be a structure that provides the
    238    information required to do the read or write, as well as accumulated
    239    information on the input and output such as totals and check values.
    240 
    241    in() should return zero on failure.  out() should return non-zero on
    242    failure.  If either in() or out() fails, than inflateBack() returns a
    243    Z_BUF_ERROR.  strm->next_in can be checked for Z_NULL to see whether it
    244    was in() or out() that caused in the error.  Otherwise,  inflateBack()
    245    returns Z_STREAM_END on success, Z_DATA_ERROR for an deflate format
    246    error, or Z_MEM_ERROR if it could not allocate memory for the state.
    247    inflateBack() can also return Z_STREAM_ERROR if the input parameters
    248    are not correct, i.e. strm is Z_NULL or the state was not initialized.
    249  */
    250 int ZEXPORT inflateBack(strm, in, in_desc, out, out_desc)
    251 z_streamp strm;
    252 in_func in;
    253 void FAR *in_desc;
    254 out_func out;
    255 void FAR *out_desc;
    256 {
    257     struct inflate_state FAR *state;
    258     unsigned char FAR *next;    /* next input */
    259     unsigned char FAR *put;     /* next output */
    260     unsigned have, left;        /* available input and output */
    261     unsigned long hold;         /* bit buffer */
    262     unsigned bits;              /* bits in bit buffer */
    263     unsigned copy;              /* number of stored or match bytes to copy */
    264     unsigned char FAR *from;    /* where to copy match bytes from */
    265     code here;                  /* current decoding table entry */
    266     code last;                  /* parent table entry */
    267     unsigned len;               /* length to copy for repeats, bits to drop */
    268     int ret;                    /* return code */
    269     static const unsigned short order[19] = /* permutation of code lengths */
    270         {16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15};
    271 
    272     /* Check that the strm exists and that the state was initialized */
    273     if (strm == Z_NULL || strm->state == Z_NULL)
    274         return Z_STREAM_ERROR;
    275     state = (struct inflate_state FAR *)strm->state;
    276 
    277     /* Reset the state */
    278     strm->msg = Z_NULL;
    279     state->mode = TYPE;
    280     state->last = 0;
    281     state->whave = 0;
    282     next = strm->next_in;
    283     have = next != Z_NULL ? strm->avail_in : 0;
    284     hold = 0;
    285     bits = 0;
    286     put = state->window;
    287     left = state->wsize;
    288 
    289     /* Inflate until end of block marked as last */
    290     for (;;)
    291         switch (state->mode) {
    292         case TYPE:
    293             /* determine and dispatch block type */
    294             if (state->last) {
    295                 BYTEBITS();
    296                 state->mode = DONE;
    297                 break;
    298             }
    299             NEEDBITS(3);
    300             state->last = BITS(1);
    301             DROPBITS(1);
    302             switch (BITS(2)) {
    303             case 0:                             /* stored block */
    304                 Tracev((stderr, "inflate:     stored block%s\n",
    305                         state->last ? " (last)" : ""));
    306                 state->mode = STORED;
    307                 break;
    308             case 1:                             /* fixed block */
    309                 fixedtables(state);
    310                 Tracev((stderr, "inflate:     fixed codes block%s\n",
    311                         state->last ? " (last)" : ""));
    312                 state->mode = LEN;              /* decode codes */
    313                 break;
    314             case 2:                             /* dynamic block */
    315                 Tracev((stderr, "inflate:     dynamic codes block%s\n",
    316                         state->last ? " (last)" : ""));
    317                 state->mode = TABLE;
    318                 break;
    319             case 3:
    320                 strm->msg = (char *)"invalid block type";
    321                 state->mode = BAD;
    322             }
    323             DROPBITS(2);
    324             break;
    325 
    326         case STORED:
    327             /* get and verify stored block length */
    328             BYTEBITS();                         /* go to byte boundary */
    329             NEEDBITS(32);
    330             if ((hold & 0xffff) != ((hold >> 16) ^ 0xffff)) {
    331                 strm->msg = (char *)"invalid stored block lengths";
    332                 state->mode = BAD;
    333                 break;
    334             }
    335             state->length = (unsigned)hold & 0xffff;
    336             Tracev((stderr, "inflate:       stored length %u\n",
    337                     state->length));
    338             INITBITS();
    339 
    340             /* copy stored block from input to output */
    341             while (state->length != 0) {
    342                 copy = state->length;
    343                 PULL();
    344                 ROOM();
    345                 if (copy > have) copy = have;
    346                 if (copy > left) copy = left;
    347                 zmemcpy(put, next, copy);
    348                 have -= copy;
    349                 next += copy;
    350                 left -= copy;
    351                 put += copy;
    352                 state->length -= copy;
    353             }
    354             Tracev((stderr, "inflate:       stored end\n"));
    355             state->mode = TYPE;
    356             break;
    357 
    358         case TABLE:
    359             /* get dynamic table entries descriptor */
    360             NEEDBITS(14);
    361             state->nlen = BITS(5) + 257;
    362             DROPBITS(5);
    363             state->ndist = BITS(5) + 1;
    364             DROPBITS(5);
    365             state->ncode = BITS(4) + 4;
    366             DROPBITS(4);
    367 #ifndef PKZIP_BUG_WORKAROUND
    368             if (state->nlen > 286 || state->ndist > 30) {
    369                 strm->msg = (char *)"too many length or distance symbols";
    370                 state->mode = BAD;
    371                 break;
    372             }
    373 #endif
    374             Tracev((stderr, "inflate:       table sizes ok\n"));
    375 
    376             /* get code length code lengths (not a typo) */
    377             state->have = 0;
    378             while (state->have < state->ncode) {
    379                 NEEDBITS(3);
    380                 state->lens[order[state->have++]] = (unsigned short)BITS(3);
    381                 DROPBITS(3);
    382             }
    383             while (state->have < 19)
    384                 state->lens[order[state->have++]] = 0;
    385             state->next = state->codes;
    386             state->lencode = (code const FAR *)(state->next);
    387             state->lenbits = 7;
    388             ret = inflate_table(CODES, state->lens, 19, &(state->next),
    389                                 &(state->lenbits), state->work);
    390             if (ret) {
    391                 strm->msg = (char *)"invalid code lengths set";
    392                 state->mode = BAD;
    393                 break;
    394             }
    395             Tracev((stderr, "inflate:       code lengths ok\n"));
    396 
    397             /* get length and distance code code lengths */
    398             state->have = 0;
    399             while (state->have < state->nlen + state->ndist) {
    400                 for (;;) {
    401                     here = state->lencode[BITS(state->lenbits)];
    402                     if ((unsigned)(here.bits) <= bits) break;
    403                     PULLBYTE();
    404                 }
    405                 if (here.val < 16) {
    406                     DROPBITS(here.bits);
    407                     state->lens[state->have++] = here.val;
    408                 }
    409                 else {
    410                     if (here.val == 16) {
    411                         NEEDBITS(here.bits + 2);
    412                         DROPBITS(here.bits);
    413                         if (state->have == 0) {
    414                             strm->msg = (char *)"invalid bit length repeat";
    415                             state->mode = BAD;
    416                             break;
    417                         }
    418                         len = (unsigned)(state->lens[state->have - 1]);
    419                         copy = 3 + BITS(2);
    420                         DROPBITS(2);
    421                     }
    422                     else if (here.val == 17) {
    423                         NEEDBITS(here.bits + 3);
    424                         DROPBITS(here.bits);
    425                         len = 0;
    426                         copy = 3 + BITS(3);
    427                         DROPBITS(3);
    428                     }
    429                     else {
    430                         NEEDBITS(here.bits + 7);
    431                         DROPBITS(here.bits);
    432                         len = 0;
    433                         copy = 11 + BITS(7);
    434                         DROPBITS(7);
    435                     }
    436                     if (state->have + copy > state->nlen + state->ndist) {
    437                         strm->msg = (char *)"invalid bit length repeat";
    438                         state->mode = BAD;
    439                         break;
    440                     }
    441                     while (copy--)
    442                         state->lens[state->have++] = (unsigned short)len;
    443                 }
    444             }
    445 
    446             /* handle error breaks in while */
    447             if (state->mode == BAD) break;
    448 
    449             /* check for end-of-block code (better have one) */
    450             if (state->lens[256] == 0) {
    451                 strm->msg = (char *)"invalid code -- missing end-of-block";
    452                 state->mode = BAD;
    453                 break;
    454             }
    455 
    456             /* build code tables -- note: do not change the lenbits or distbits
    457                values here (9 and 6) without reading the comments in inftrees.h
    458                concerning the ENOUGH constants, which depend on those values */
    459             state->next = state->codes;
    460             state->lencode = (code const FAR *)(state->next);
    461             state->lenbits = 9;
    462             ret = inflate_table(LENS, state->lens, state->nlen, &(state->next),
    463                                 &(state->lenbits), state->work);
    464             if (ret) {
    465                 strm->msg = (char *)"invalid literal/lengths set";
    466                 state->mode = BAD;
    467                 break;
    468             }
    469             state->distcode = (code const FAR *)(state->next);
    470             state->distbits = 6;
    471             ret = inflate_table(DISTS, state->lens + state->nlen, state->ndist,
    472                             &(state->next), &(state->distbits), state->work);
    473             if (ret) {
    474                 strm->msg = (char *)"invalid distances set";
    475                 state->mode = BAD;
    476                 break;
    477             }
    478             Tracev((stderr, "inflate:       codes ok\n"));
    479             state->mode = LEN;
    480 
    481         case LEN:
    482             /* use inflate_fast() if we have enough input and output */
    483             if (have >= 6 && left >= 258) {
    484                 RESTORE();
    485                 if (state->whave < state->wsize)
    486                     state->whave = state->wsize - left;
    487                 inflate_fast(strm, state->wsize);
    488                 LOAD();
    489                 break;
    490             }
    491 
    492             /* get a literal, length, or end-of-block code */
    493             for (;;) {
    494                 here = state->lencode[BITS(state->lenbits)];
    495                 if ((unsigned)(here.bits) <= bits) break;
    496                 PULLBYTE();
    497             }
    498             if (here.op && (here.op & 0xf0) == 0) {
    499                 last = here;
    500                 for (;;) {
    501                     here = state->lencode[last.val +
    502                             (BITS(last.bits + last.op) >> last.bits)];
    503                     if ((unsigned)(last.bits + here.bits) <= bits) break;
    504                     PULLBYTE();
    505                 }
    506                 DROPBITS(last.bits);
    507             }
    508             DROPBITS(here.bits);
    509             state->length = (unsigned)here.val;
    510 
    511             /* process literal */
    512             if (here.op == 0) {
    513                 Tracevv((stderr, here.val >= 0x20 && here.val < 0x7f ?
    514                         "inflate:         literal '%c'\n" :
    515                         "inflate:         literal 0x%02x\n", here.val));
    516                 ROOM();
    517                 *put++ = (unsigned char)(state->length);
    518                 left--;
    519                 state->mode = LEN;
    520                 break;
    521             }
    522 
    523             /* process end of block */
    524             if (here.op & 32) {
    525                 Tracevv((stderr, "inflate:         end of block\n"));
    526                 state->mode = TYPE;
    527                 break;
    528             }
    529 
    530             /* invalid code */
    531             if (here.op & 64) {
    532                 strm->msg = (char *)"invalid literal/length code";
    533                 state->mode = BAD;
    534                 break;
    535             }
    536 
    537             /* length code -- get extra bits, if any */
    538             state->extra = (unsigned)(here.op) & 15;
    539             if (state->extra != 0) {
    540                 NEEDBITS(state->extra);
    541                 state->length += BITS(state->extra);
    542                 DROPBITS(state->extra);
    543             }
    544             Tracevv((stderr, "inflate:         length %u\n", state->length));
    545 
    546             /* get distance code */
    547             for (;;) {
    548                 here = state->distcode[BITS(state->distbits)];
    549                 if ((unsigned)(here.bits) <= bits) break;
    550                 PULLBYTE();
    551             }
    552             if ((here.op & 0xf0) == 0) {
    553                 last = here;
    554                 for (;;) {
    555                     here = state->distcode[last.val +
    556                             (BITS(last.bits + last.op) >> last.bits)];
    557                     if ((unsigned)(last.bits + here.bits) <= bits) break;
    558                     PULLBYTE();
    559                 }
    560                 DROPBITS(last.bits);
    561             }
    562             DROPBITS(here.bits);
    563             if (here.op & 64) {
    564                 strm->msg = (char *)"invalid distance code";
    565                 state->mode = BAD;
    566                 break;
    567             }
    568             state->offset = (unsigned)here.val;
    569 
    570             /* get distance extra bits, if any */
    571             state->extra = (unsigned)(here.op) & 15;
    572             if (state->extra != 0) {
    573                 NEEDBITS(state->extra);
    574                 state->offset += BITS(state->extra);
    575                 DROPBITS(state->extra);
    576             }
    577             if (state->offset > state->wsize - (state->whave < state->wsize ?
    578                                                 left : 0)) {
    579                 strm->msg = (char *)"invalid distance too far back";
    580                 state->mode = BAD;
    581                 break;
    582             }
    583             Tracevv((stderr, "inflate:         distance %u\n", state->offset));
    584 
    585             /* copy match from window to output */
    586             do {
    587                 ROOM();
    588                 copy = state->wsize - state->offset;
    589                 if (copy < left) {
    590                     from = put + copy;
    591                     copy = left - copy;
    592                 }
    593                 else {
    594                     from = put - state->offset;
    595                     copy = left;
    596                 }
    597                 if (copy > state->length) copy = state->length;
    598                 state->length -= copy;
    599                 left -= copy;
    600                 do {
    601                     *put++ = *from++;
    602                 } while (--copy);
    603             } while (state->length != 0);
    604             break;
    605 
    606         case DONE:
    607             /* inflate stream terminated properly -- write leftover output */
    608             ret = Z_STREAM_END;
    609             if (left < state->wsize) {
    610                 if (out(out_desc, state->window, state->wsize - left))
    611                     ret = Z_BUF_ERROR;
    612             }
    613             goto inf_leave;
    614 
    615         case BAD:
    616             ret = Z_DATA_ERROR;
    617             goto inf_leave;
    618 
    619         default:                /* can't happen, but makes compilers happy */
    620             ret = Z_STREAM_ERROR;
    621             goto inf_leave;
    622         }
    623 
    624     /* Return unused input */
    625   inf_leave:
    626     strm->next_in = next;
    627     strm->avail_in = have;
    628     return ret;
    629 }
    630 
    631 int ZEXPORT inflateBackEnd(strm)
    632 z_streamp strm;
    633 {
    634     if (strm == Z_NULL || strm->state == Z_NULL || strm->zfree == (free_func)0)
    635         return Z_STREAM_ERROR;
    636     ZFREE(strm, strm->state);
    637     strm->state = Z_NULL;
    638     Tracev((stderr, "inflate: end\n"));
    639     return Z_OK;
    640 }
    641