Home | History | Annotate | Download | only in jpeg
      1 // Copyright 2009 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package jpeg
      6 
      7 import (
      8 	"io"
      9 )
     10 
     11 // maxCodeLength is the maximum (inclusive) number of bits in a Huffman code.
     12 const maxCodeLength = 16
     13 
     14 // maxNCodes is the maximum (inclusive) number of codes in a Huffman tree.
     15 const maxNCodes = 256
     16 
     17 // lutSize is the log-2 size of the Huffman decoder's look-up table.
     18 const lutSize = 8
     19 
// huffman is a Huffman decoder, specified in section C.
//
// Arrays indexed by code length use index i for codes that are i+1 bits
// long, matching how processDHT fills them in.
type huffman struct {
	// nCodes is the number of codes in the tree. processDHT refuses tables
	// with zero codes, and decodeHuffman treats a zero nCodes as an
	// uninitialized table.
	nCodes int32
	// lut is the look-up table for the next lutSize bits in the bit-stream.
	// The high 8 bits of the uint16 are the decoded value (an element of
	// vals). The low 8 bits are 1 plus the code length, or 0 if no code of
	// at most lutSize bits matches those bits.
	lut [1 << lutSize]uint16
	// vals are the decoded values, sorted by their encoding.
	vals [maxNCodes]uint8
	// minCodes[i] is the minimum code of length i+1, or -1 if there are no
	// codes of that length.
	minCodes [maxCodeLength]int32
	// maxCodes[i] is the maximum code of length i+1, or -1 if there are no
	// codes of that length.
	maxCodes [maxCodeLength]int32
	// valsIndices[i] is the index into vals of minCodes[i], or -1 if there
	// are no codes of length i+1.
	valsIndices [maxCodeLength]int32
}
     40 
// errShortHuffmanData means that an unexpected EOF occurred while decoding
// Huffman data. ensureNBits returns it when d.readByteStuffedByte reports
// io.EOF.
var errShortHuffmanData = FormatError("short Huffman data")
     44 
     45 // ensureNBits reads bytes from the byte buffer to ensure that d.bits.n is at
     46 // least n. For best performance (avoiding function calls inside hot loops),
     47 // the caller is the one responsible for first checking that d.bits.n < n.
     48 func (d *decoder) ensureNBits(n int32) error {
     49 	for {
     50 		c, err := d.readByteStuffedByte()
     51 		if err != nil {
     52 			if err == io.EOF {
     53 				return errShortHuffmanData
     54 			}
     55 			return err
     56 		}
     57 		d.bits.a = d.bits.a<<8 | uint32(c)
     58 		d.bits.n += 8
     59 		if d.bits.m == 0 {
     60 			d.bits.m = 1 << 7
     61 		} else {
     62 			d.bits.m <<= 8
     63 		}
     64 		if d.bits.n >= n {
     65 			break
     66 		}
     67 	}
     68 	return nil
     69 }
     70 
     71 // receiveExtend is the composition of RECEIVE and EXTEND, specified in section
     72 // F.2.2.1.
     73 func (d *decoder) receiveExtend(t uint8) (int32, error) {
     74 	if d.bits.n < int32(t) {
     75 		if err := d.ensureNBits(int32(t)); err != nil {
     76 			return 0, err
     77 		}
     78 	}
     79 	d.bits.n -= int32(t)
     80 	d.bits.m >>= t
     81 	s := int32(1) << t
     82 	x := int32(d.bits.a>>uint8(d.bits.n)) & (s - 1)
     83 	if x < s>>1 {
     84 		x += ((-1) << t) + 1
     85 	}
     86 	return x, nil
     87 }
     88 
     89 // processDHT processes a Define Huffman Table marker, and initializes a huffman
     90 // struct from its contents. Specified in section B.2.4.2.
     91 func (d *decoder) processDHT(n int) error {
     92 	for n > 0 {
     93 		if n < 17 {
     94 			return FormatError("DHT has wrong length")
     95 		}
     96 		if err := d.readFull(d.tmp[:17]); err != nil {
     97 			return err
     98 		}
     99 		tc := d.tmp[0] >> 4
    100 		if tc > maxTc {
    101 			return FormatError("bad Tc value")
    102 		}
    103 		th := d.tmp[0] & 0x0f
    104 		// The baseline th <= 1 restriction is specified in table B.5.
    105 		if th > maxTh || (d.baseline && th > 1) {
    106 			return FormatError("bad Th value")
    107 		}
    108 		h := &d.huff[tc][th]
    109 
    110 		// Read nCodes and h.vals (and derive h.nCodes).
    111 		// nCodes[i] is the number of codes with code length i.
    112 		// h.nCodes is the total number of codes.
    113 		h.nCodes = 0
    114 		var nCodes [maxCodeLength]int32
    115 		for i := range nCodes {
    116 			nCodes[i] = int32(d.tmp[i+1])
    117 			h.nCodes += nCodes[i]
    118 		}
    119 		if h.nCodes == 0 {
    120 			return FormatError("Huffman table has zero length")
    121 		}
    122 		if h.nCodes > maxNCodes {
    123 			return FormatError("Huffman table has excessive length")
    124 		}
    125 		n -= int(h.nCodes) + 17
    126 		if n < 0 {
    127 			return FormatError("DHT has wrong length")
    128 		}
    129 		if err := d.readFull(h.vals[:h.nCodes]); err != nil {
    130 			return err
    131 		}
    132 
    133 		// Derive the look-up table.
    134 		for i := range h.lut {
    135 			h.lut[i] = 0
    136 		}
    137 		var x, code uint32
    138 		for i := uint32(0); i < lutSize; i++ {
    139 			code <<= 1
    140 			for j := int32(0); j < nCodes[i]; j++ {
    141 				// The codeLength is 1+i, so shift code by 8-(1+i) to
    142 				// calculate the high bits for every 8-bit sequence
    143 				// whose codeLength's high bits matches code.
    144 				// The high 8 bits of lutValue are the encoded value.
    145 				// The low 8 bits are 1 plus the codeLength.
    146 				base := uint8(code << (7 - i))
    147 				lutValue := uint16(h.vals[x])<<8 | uint16(2+i)
    148 				for k := uint8(0); k < 1<<(7-i); k++ {
    149 					h.lut[base|k] = lutValue
    150 				}
    151 				code++
    152 				x++
    153 			}
    154 		}
    155 
    156 		// Derive minCodes, maxCodes, and valsIndices.
    157 		var c, index int32
    158 		for i, n := range nCodes {
    159 			if n == 0 {
    160 				h.minCodes[i] = -1
    161 				h.maxCodes[i] = -1
    162 				h.valsIndices[i] = -1
    163 			} else {
    164 				h.minCodes[i] = c
    165 				h.maxCodes[i] = c + n - 1
    166 				h.valsIndices[i] = index
    167 				c += n
    168 				index += n
    169 			}
    170 			c <<= 1
    171 		}
    172 	}
    173 	return nil
    174 }
    175 
    176 // decodeHuffman returns the next Huffman-coded value from the bit-stream,
    177 // decoded according to h.
    178 func (d *decoder) decodeHuffman(h *huffman) (uint8, error) {
    179 	if h.nCodes == 0 {
    180 		return 0, FormatError("uninitialized Huffman table")
    181 	}
    182 
    183 	if d.bits.n < 8 {
    184 		if err := d.ensureNBits(8); err != nil {
    185 			if err != errMissingFF00 && err != errShortHuffmanData {
    186 				return 0, err
    187 			}
    188 			// There are no more bytes of data in this segment, but we may still
    189 			// be able to read the next symbol out of the previously read bits.
    190 			// First, undo the readByte that the ensureNBits call made.
    191 			if d.bytes.nUnreadable != 0 {
    192 				d.unreadByteStuffedByte()
    193 			}
    194 			goto slowPath
    195 		}
    196 	}
    197 	if v := h.lut[(d.bits.a>>uint32(d.bits.n-lutSize))&0xff]; v != 0 {
    198 		n := (v & 0xff) - 1
    199 		d.bits.n -= int32(n)
    200 		d.bits.m >>= n
    201 		return uint8(v >> 8), nil
    202 	}
    203 
    204 slowPath:
    205 	for i, code := 0, int32(0); i < maxCodeLength; i++ {
    206 		if d.bits.n == 0 {
    207 			if err := d.ensureNBits(1); err != nil {
    208 				return 0, err
    209 			}
    210 		}
    211 		if d.bits.a&d.bits.m != 0 {
    212 			code |= 1
    213 		}
    214 		d.bits.n--
    215 		d.bits.m >>= 1
    216 		if code <= h.maxCodes[i] {
    217 			return h.vals[h.valsIndices[i]+code-h.minCodes[i]], nil
    218 		}
    219 		code <<= 1
    220 	}
    221 	return 0, FormatError("bad Huffman code")
    222 }
    223 
    224 func (d *decoder) decodeBit() (bool, error) {
    225 	if d.bits.n == 0 {
    226 		if err := d.ensureNBits(1); err != nil {
    227 			return false, err
    228 		}
    229 	}
    230 	ret := d.bits.a&d.bits.m != 0
    231 	d.bits.n--
    232 	d.bits.m >>= 1
    233 	return ret, nil
    234 }
    235 
    236 func (d *decoder) decodeBits(n int32) (uint32, error) {
    237 	if d.bits.n < n {
    238 		if err := d.ensureNBits(n); err != nil {
    239 			return 0, err
    240 		}
    241 	}
    242 	ret := d.bits.a >> uint32(d.bits.n-n)
    243 	ret &= (1 << uint32(n)) - 1
    244 	d.bits.n -= n
    245 	d.bits.m >>= uint32(n)
    246 	return ret, nil
    247 }
    248