Home | History | Annotate | Download | only in lzw
      1 // Copyright 2011 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package lzw
      6 
      7 import (
      8 	"bufio"
      9 	"errors"
     10 	"fmt"
     11 	"io"
     12 )
     13 
     14 // A writer is a buffered, flushable writer.
     15 type writer interface {
     16 	io.ByteWriter
     17 	Flush() error
     18 }
     19 
     20 // An errWriteCloser is an io.WriteCloser that always returns a given error.
     21 type errWriteCloser struct {
     22 	err error
     23 }
     24 
     25 func (e *errWriteCloser) Write([]byte) (int, error) {
     26 	return 0, e.err
     27 }
     28 
     29 func (e *errWriteCloser) Close() error {
     30 	return e.err
     31 }
     32 
     33 const (
     34 	// A code is a 12 bit value, stored as a uint32 when encoding to avoid
     35 	// type conversions when shifting bits.
     36 	maxCode     = 1<<12 - 1
     37 	invalidCode = 1<<32 - 1
     38 	// There are 1<<12 possible codes, which is an upper bound on the number of
     39 	// valid hash table entries at any given point in time. tableSize is 4x that.
     40 	tableSize = 4 * 1 << 12
     41 	tableMask = tableSize - 1
     42 	// A hash table entry is a uint32. Zero is an invalid entry since the
     43 	// lower 12 bits of a valid entry must be a non-literal code.
     44 	invalidEntry = 0
     45 )
     46 
     47 // encoder is LZW compressor.
     48 type encoder struct {
     49 	// w is the writer that compressed bytes are written to.
     50 	w writer
     51 	// order, write, bits, nBits and width are the state for
     52 	// converting a code stream into a byte stream.
     53 	order Order
     54 	write func(*encoder, uint32) error
     55 	bits  uint32
     56 	nBits uint
     57 	width uint
     58 	// litWidth is the width in bits of literal codes.
     59 	litWidth uint
     60 	// hi is the code implied by the next code emission.
     61 	// overflow is the code at which hi overflows the code width.
     62 	hi, overflow uint32
     63 	// savedCode is the accumulated code at the end of the most recent Write
     64 	// call. It is equal to invalidCode if there was no such call.
     65 	savedCode uint32
     66 	// err is the first error encountered during writing. Closing the encoder
     67 	// will make any future Write calls return errClosed
     68 	err error
     69 	// table is the hash table from 20-bit keys to 12-bit values. Each table
     70 	// entry contains key<<12|val and collisions resolve by linear probing.
     71 	// The keys consist of a 12-bit code prefix and an 8-bit byte suffix.
     72 	// The values are a 12-bit code.
     73 	table [tableSize]uint32
     74 }
     75 
     76 // writeLSB writes the code c for "Least Significant Bits first" data.
     77 func (e *encoder) writeLSB(c uint32) error {
     78 	e.bits |= c << e.nBits
     79 	e.nBits += e.width
     80 	for e.nBits >= 8 {
     81 		if err := e.w.WriteByte(uint8(e.bits)); err != nil {
     82 			return err
     83 		}
     84 		e.bits >>= 8
     85 		e.nBits -= 8
     86 	}
     87 	return nil
     88 }
     89 
     90 // writeMSB writes the code c for "Most Significant Bits first" data.
     91 func (e *encoder) writeMSB(c uint32) error {
     92 	e.bits |= c << (32 - e.width - e.nBits)
     93 	e.nBits += e.width
     94 	for e.nBits >= 8 {
     95 		if err := e.w.WriteByte(uint8(e.bits >> 24)); err != nil {
     96 			return err
     97 		}
     98 		e.bits <<= 8
     99 		e.nBits -= 8
    100 	}
    101 	return nil
    102 }
    103 
    104 // errOutOfCodes is an internal error that means that the encoder has run out
    105 // of unused codes and a clear code needs to be sent next.
    106 var errOutOfCodes = errors.New("lzw: out of codes")
    107 
    108 // incHi increments e.hi and checks for both overflow and running out of
    109 // unused codes. In the latter case, incHi sends a clear code, resets the
    110 // encoder state and returns errOutOfCodes.
    111 func (e *encoder) incHi() error {
    112 	e.hi++
    113 	if e.hi == e.overflow {
    114 		e.width++
    115 		e.overflow <<= 1
    116 	}
    117 	if e.hi == maxCode {
    118 		clear := uint32(1) << e.litWidth
    119 		if err := e.write(e, clear); err != nil {
    120 			return err
    121 		}
    122 		e.width = e.litWidth + 1
    123 		e.hi = clear + 1
    124 		e.overflow = clear << 1
    125 		for i := range e.table {
    126 			e.table[i] = invalidEntry
    127 		}
    128 		return errOutOfCodes
    129 	}
    130 	return nil
    131 }
    132 
    133 // Write writes a compressed representation of p to e's underlying writer.
    134 func (e *encoder) Write(p []byte) (n int, err error) {
    135 	if e.err != nil {
    136 		return 0, e.err
    137 	}
    138 	if len(p) == 0 {
    139 		return 0, nil
    140 	}
    141 	if maxLit := uint8(1<<e.litWidth - 1); maxLit != 0xff {
    142 		for _, x := range p {
    143 			if x > maxLit {
    144 				e.err = errors.New("lzw: input byte too large for the litWidth")
    145 				return 0, e.err
    146 			}
    147 		}
    148 	}
    149 	n = len(p)
    150 	code := e.savedCode
    151 	if code == invalidCode {
    152 		// The first code sent is always a literal code.
    153 		code, p = uint32(p[0]), p[1:]
    154 	}
    155 loop:
    156 	for _, x := range p {
    157 		literal := uint32(x)
    158 		key := code<<8 | literal
    159 		// If there is a hash table hit for this key then we continue the loop
    160 		// and do not emit a code yet.
    161 		hash := (key>>12 ^ key) & tableMask
    162 		for h, t := hash, e.table[hash]; t != invalidEntry; {
    163 			if key == t>>12 {
    164 				code = t & maxCode
    165 				continue loop
    166 			}
    167 			h = (h + 1) & tableMask
    168 			t = e.table[h]
    169 		}
    170 		// Otherwise, write the current code, and literal becomes the start of
    171 		// the next emitted code.
    172 		if e.err = e.write(e, code); e.err != nil {
    173 			return 0, e.err
    174 		}
    175 		code = literal
    176 		// Increment e.hi, the next implied code. If we run out of codes, reset
    177 		// the encoder state (including clearing the hash table) and continue.
    178 		if err1 := e.incHi(); err1 != nil {
    179 			if err1 == errOutOfCodes {
    180 				continue
    181 			}
    182 			e.err = err1
    183 			return 0, e.err
    184 		}
    185 		// Otherwise, insert key -> e.hi into the map that e.table represents.
    186 		for {
    187 			if e.table[hash] == invalidEntry {
    188 				e.table[hash] = (key << 12) | e.hi
    189 				break
    190 			}
    191 			hash = (hash + 1) & tableMask
    192 		}
    193 	}
    194 	e.savedCode = code
    195 	return n, nil
    196 }
    197 
    198 // Close closes the encoder, flushing any pending output. It does not close or
    199 // flush e's underlying writer.
    200 func (e *encoder) Close() error {
    201 	if e.err != nil {
    202 		if e.err == errClosed {
    203 			return nil
    204 		}
    205 		return e.err
    206 	}
    207 	// Make any future calls to Write return errClosed.
    208 	e.err = errClosed
    209 	// Write the savedCode if valid.
    210 	if e.savedCode != invalidCode {
    211 		if err := e.write(e, e.savedCode); err != nil {
    212 			return err
    213 		}
    214 		if err := e.incHi(); err != nil && err != errOutOfCodes {
    215 			return err
    216 		}
    217 	}
    218 	// Write the eof code.
    219 	eof := uint32(1)<<e.litWidth + 1
    220 	if err := e.write(e, eof); err != nil {
    221 		return err
    222 	}
    223 	// Write the final bits.
    224 	if e.nBits > 0 {
    225 		if e.order == MSB {
    226 			e.bits >>= 24
    227 		}
    228 		if err := e.w.WriteByte(uint8(e.bits)); err != nil {
    229 			return err
    230 		}
    231 	}
    232 	return e.w.Flush()
    233 }
    234 
    235 // NewWriter creates a new io.WriteCloser.
    236 // Writes to the returned io.WriteCloser are compressed and written to w.
    237 // It is the caller's responsibility to call Close on the WriteCloser when
    238 // finished writing.
    239 // The number of bits to use for literal codes, litWidth, must be in the
    240 // range [2,8] and is typically 8. Input bytes must be less than 1<<litWidth.
    241 func NewWriter(w io.Writer, order Order, litWidth int) io.WriteCloser {
    242 	var write func(*encoder, uint32) error
    243 	switch order {
    244 	case LSB:
    245 		write = (*encoder).writeLSB
    246 	case MSB:
    247 		write = (*encoder).writeMSB
    248 	default:
    249 		return &errWriteCloser{errors.New("lzw: unknown order")}
    250 	}
    251 	if litWidth < 2 || 8 < litWidth {
    252 		return &errWriteCloser{fmt.Errorf("lzw: litWidth %d out of range", litWidth)}
    253 	}
    254 	bw, ok := w.(writer)
    255 	if !ok {
    256 		bw = bufio.NewWriter(w)
    257 	}
    258 	lw := uint(litWidth)
    259 	return &encoder{
    260 		w:         bw,
    261 		order:     order,
    262 		write:     write,
    263 		width:     1 + lw,
    264 		litWidth:  lw,
    265 		hi:        1<<lw + 1,
    266 		overflow:  1 << (lw + 1),
    267 		savedCode: invalidCode,
    268 	}
    269 }
    270