Home | History | Annotate | Download | only in base64
      1 // Copyright 2009 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 // Package base64 implements base64 encoding as specified by RFC 4648.
      6 package base64
      7 
      8 import (
      9 	"io"
     10 	"strconv"
     11 )
     12 
     13 /*
     14  * Encodings
     15  */
     16 
     17 // An Encoding is a radix 64 encoding/decoding scheme, defined by a
     18 // 64-character alphabet.  The most common encoding is the "base64"
     19 // encoding defined in RFC 4648 and used in MIME (RFC 2045) and PEM
     20 // (RFC 1421).  RFC 4648 also defines an alternate encoding, which is
     21 // the standard encoding with - and _ substituted for + and /.
type Encoding struct {
	encode    [64]byte  // alphabet: encode[v] is the output character for the 6-bit value v
	decodeMap [256]byte // inverse of encode, indexed by input byte; 0xFF marks a byte outside the alphabet
	padChar   rune      // padding character, or NoPadding when no padding is used
}
     27 
// Padding settings that can be passed to WithPadding.
const (
	StdPadding rune = '=' // Standard padding character
	NoPadding  rune = -1  // No padding
)
     32 
// encodeStd is the 64-character alphabet used by StdEncoding.
const encodeStd = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"

// encodeURL is the alphabet used by URLEncoding: encodeStd with
// '-' and '_' in place of '+' and '/'.
const encodeURL = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
     35 
     36 // NewEncoding returns a new padded Encoding defined by the given alphabet,
     37 // which must be a 64-byte string.
     38 // The resulting Encoding uses the default padding character ('='),
     39 // which may be changed or disabled via WithPadding.
     40 func NewEncoding(encoder string) *Encoding {
     41 	if len(encoder) != 64 {
     42 		panic("encoding alphabet is not 64-bytes long")
     43 	}
     44 
     45 	e := new(Encoding)
     46 	e.padChar = StdPadding
     47 	copy(e.encode[:], encoder)
     48 
     49 	for i := 0; i < len(e.decodeMap); i++ {
     50 		e.decodeMap[i] = 0xFF
     51 	}
     52 	for i := 0; i < len(encoder); i++ {
     53 		e.decodeMap[encoder[i]] = byte(i)
     54 	}
     55 	return e
     56 }
     57 
     58 // WithPadding creates a new encoding identical to enc except
     59 // with a specified padding character, or NoPadding to disable padding.
     60 func (enc Encoding) WithPadding(padding rune) *Encoding {
     61 	enc.padChar = padding
     62 	return &enc
     63 }
     64 
// StdEncoding is the standard base64 encoding, as defined in
// RFC 4648.
var StdEncoding = NewEncoding(encodeStd)

// URLEncoding is the alternate base64 encoding defined in RFC 4648.
// It is typically used in URLs and file names.
var URLEncoding = NewEncoding(encodeURL)

// RawStdEncoding is the standard raw, unpadded base64 encoding,
// as defined in RFC 4648 section 3.2.
// This is the same as StdEncoding but omits padding characters.
var RawStdEncoding = StdEncoding.WithPadding(NoPadding)

// RawURLEncoding is the unpadded alternate base64 encoding defined in RFC 4648.
// It is typically used in URLs and file names.
// This is the same as URLEncoding but omits padding characters.
var RawURLEncoding = URLEncoding.WithPadding(NoPadding)
     82 
     83 /*
     84  * Encoder
     85  */
     86 
     87 // Encode encodes src using the encoding enc, writing
     88 // EncodedLen(len(src)) bytes to dst.
     89 //
     90 // The encoding pads the output to a multiple of 4 bytes,
     91 // so Encode is not appropriate for use on individual blocks
     92 // of a large data stream.  Use NewEncoder() instead.
     93 func (enc *Encoding) Encode(dst, src []byte) {
     94 	if len(src) == 0 {
     95 		return
     96 	}
     97 
     98 	di, si := 0, 0
     99 	n := (len(src) / 3) * 3
    100 	for si < n {
    101 		// Convert 3x 8bit source bytes into 4 bytes
    102 		val := uint(src[si+0])<<16 | uint(src[si+1])<<8 | uint(src[si+2])
    103 
    104 		dst[di+0] = enc.encode[val>>18&0x3F]
    105 		dst[di+1] = enc.encode[val>>12&0x3F]
    106 		dst[di+2] = enc.encode[val>>6&0x3F]
    107 		dst[di+3] = enc.encode[val&0x3F]
    108 
    109 		si += 3
    110 		di += 4
    111 	}
    112 
    113 	remain := len(src) - si
    114 	if remain == 0 {
    115 		return
    116 	}
    117 	// Add the remaining small block
    118 	val := uint(src[si+0]) << 16
    119 	if remain == 2 {
    120 		val |= uint(src[si+1]) << 8
    121 	}
    122 
    123 	dst[di+0] = enc.encode[val>>18&0x3F]
    124 	dst[di+1] = enc.encode[val>>12&0x3F]
    125 
    126 	switch remain {
    127 	case 2:
    128 		dst[di+2] = enc.encode[val>>6&0x3F]
    129 		if enc.padChar != NoPadding {
    130 			dst[di+3] = byte(enc.padChar)
    131 		}
    132 	case 1:
    133 		if enc.padChar != NoPadding {
    134 			dst[di+2] = byte(enc.padChar)
    135 			dst[di+3] = byte(enc.padChar)
    136 		}
    137 	}
    138 }
    139 
    140 // EncodeToString returns the base64 encoding of src.
    141 func (enc *Encoding) EncodeToString(src []byte) string {
    142 	buf := make([]byte, enc.EncodedLen(len(src)))
    143 	enc.Encode(buf, src)
    144 	return string(buf)
    145 }
    146 
// encoder is the streaming encoder returned by NewEncoder. It encodes
// data passed to Write with enc and forwards the result to w, holding
// back up to two input bytes until a full 3-byte group is available.
type encoder struct {
	err  error      // error from a failed write to w; once set, Write and Close return it
	enc  *Encoding  // encoding applied to the input
	w    io.Writer  // destination for encoded output
	buf  [3]byte    // buffered data waiting to be encoded
	nbuf int        // number of bytes in buf
	out  [1024]byte // output buffer
}
    155 
    156 func (e *encoder) Write(p []byte) (n int, err error) {
    157 	if e.err != nil {
    158 		return 0, e.err
    159 	}
    160 
    161 	// Leading fringe.
    162 	if e.nbuf > 0 {
    163 		var i int
    164 		for i = 0; i < len(p) && e.nbuf < 3; i++ {
    165 			e.buf[e.nbuf] = p[i]
    166 			e.nbuf++
    167 		}
    168 		n += i
    169 		p = p[i:]
    170 		if e.nbuf < 3 {
    171 			return
    172 		}
    173 		e.enc.Encode(e.out[:], e.buf[:])
    174 		if _, e.err = e.w.Write(e.out[:4]); e.err != nil {
    175 			return n, e.err
    176 		}
    177 		e.nbuf = 0
    178 	}
    179 
    180 	// Large interior chunks.
    181 	for len(p) >= 3 {
    182 		nn := len(e.out) / 4 * 3
    183 		if nn > len(p) {
    184 			nn = len(p)
    185 			nn -= nn % 3
    186 		}
    187 		e.enc.Encode(e.out[:], p[:nn])
    188 		if _, e.err = e.w.Write(e.out[0 : nn/3*4]); e.err != nil {
    189 			return n, e.err
    190 		}
    191 		n += nn
    192 		p = p[nn:]
    193 	}
    194 
    195 	// Trailing fringe.
    196 	for i := 0; i < len(p); i++ {
    197 		e.buf[i] = p[i]
    198 	}
    199 	e.nbuf = len(p)
    200 	n += len(p)
    201 	return
    202 }
    203 
    204 // Close flushes any pending output from the encoder.
    205 // It is an error to call Write after calling Close.
    206 func (e *encoder) Close() error {
    207 	// If there's anything left in the buffer, flush it out
    208 	if e.err == nil && e.nbuf > 0 {
    209 		e.enc.Encode(e.out[:], e.buf[:e.nbuf])
    210 		_, e.err = e.w.Write(e.out[:e.enc.EncodedLen(e.nbuf)])
    211 		e.nbuf = 0
    212 	}
    213 	return e.err
    214 }
    215 
    216 // NewEncoder returns a new base64 stream encoder.  Data written to
    217 // the returned writer will be encoded using enc and then written to w.
    218 // Base64 encodings operate in 4-byte blocks; when finished
    219 // writing, the caller must Close the returned encoder to flush any
    220 // partially written blocks.
    221 func NewEncoder(enc *Encoding, w io.Writer) io.WriteCloser {
    222 	return &encoder{enc: enc, w: w}
    223 }
    224 
    225 // EncodedLen returns the length in bytes of the base64 encoding
    226 // of an input buffer of length n.
    227 func (enc *Encoding) EncodedLen(n int) int {
    228 	if enc.padChar == NoPadding {
    229 		return (n*8 + 5) / 6 // minimum # chars at 6 bits per char
    230 	}
    231 	return (n + 2) / 3 * 4 // minimum # 4-char quanta, 3 bytes each
    232 }
    233 
    234 /*
    235  * Decoder
    236  */
    237 
    238 type CorruptInputError int64
    239 
    240 func (e CorruptInputError) Error() string {
    241 	return "illegal base64 data at input byte " + strconv.FormatInt(int64(e), 10)
    242 }
    243 
// decode is like Decode but returns an additional 'end' value, which
// indicates if end-of-message padding or a partial quantum was encountered
// and thus any additional data is an error.
//
// n is the number of bytes written to dst. dst is indexed directly and
// never grown, so the caller must supply a large enough buffer. A non-nil
// err is a CorruptInputError built from the position in src at which the
// problem was noticed. '\r' and '\n' in src are skipped.
func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
	si := 0

	// skip over newlines
	for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
		si++
	}

	for si < len(src) && !end {
		// Decode quantum using the base64 alphabet
		var dbuf [4]byte
		// dinc is how far dst advances after this quantum; dlen is the
		// number of characters gathered, which yields dlen-1 output bytes.
		dinc, dlen := 3, 4

		for j := range dbuf {
			if len(src) == si {
				// Input ran out inside a quantum. That is only acceptable
				// for an unpadded encoding, and only once at least two
				// characters (one full output byte) have been read.
				if enc.padChar != NoPadding || j < 2 {
					return n, false, CorruptInputError(si - j)
				}
				dinc, dlen, end = j-1, j, true
				break
			}
			in := src[si]

			si++
			// skip over newlines
			for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
				si++
			}

			if rune(in) == enc.padChar {
				// We've reached the end and there's padding
				switch j {
				case 0, 1:
					// incorrect padding
					return n, false, CorruptInputError(si - 1)
				case 2:
					// "==" is expected, the first "=" is already consumed.
					if si == len(src) {
						// not enough padding
						return n, false, CorruptInputError(len(src))
					}
					if rune(src[si]) != enc.padChar {
						// incorrect padding
						return n, false, CorruptInputError(si - 1)
					}

					si++
					// skip over newlines
					for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
						si++
					}
				}
				if si < len(src) {
					// trailing garbage
					// The error is only recorded here: the quantum read so
					// far is still converted below before returning.
					err = CorruptInputError(si)
				}
				dinc, dlen, end = 3, j, true
				break
			}
			// Map the character to its 6-bit value; 0xFF means the byte
			// is not part of the alphabet.
			dbuf[j] = enc.decodeMap[in]
			if dbuf[j] == 0xFF {
				return n, false, CorruptInputError(si - 1)
			}
		}

		// Convert 4x 6bit source bytes into 3 bytes
		val := uint(dbuf[0])<<18 | uint(dbuf[1])<<12 | uint(dbuf[2])<<6 | uint(dbuf[3])
		// Store only the bytes that dlen characters fully determine; each
		// case falls through to emit the more significant bytes as well.
		switch dlen {
		case 4:
			dst[2] = byte(val >> 0)
			fallthrough
		case 3:
			dst[1] = byte(val >> 8)
			fallthrough
		case 2:
			dst[0] = byte(val >> 16)
		}
		dst = dst[dinc:]
		n += dlen - 1
	}

	return n, end, err
}
    330 
    331 // Decode decodes src using the encoding enc.  It writes at most
    332 // DecodedLen(len(src)) bytes to dst and returns the number of bytes
    333 // written.  If src contains invalid base64 data, it will return the
    334 // number of bytes successfully written and CorruptInputError.
    335 // New line characters (\r and \n) are ignored.
    336 func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
    337 	n, _, err = enc.decode(dst, src)
    338 	return
    339 }
    340 
    341 // DecodeString returns the bytes represented by the base64 string s.
    342 func (enc *Encoding) DecodeString(s string) ([]byte, error) {
    343 	dbuf := make([]byte, enc.DecodedLen(len(s)))
    344 	n, _, err := enc.decode(dbuf, []byte(s))
    345 	return dbuf[:n], err
    346 }
    347 
// decoder is the streaming decoder returned by NewDecoder. It reads
// encoded input from r into buf, decodes whole 4-character quanta with
// enc, and keeps any decoded bytes that did not fit the caller's buffer
// in out.
type decoder struct {
	err    error              // sticky error; once set, Read returns it
	enc    *Encoding          // encoding used to decode the input
	r      io.Reader          // source of encoded input
	end    bool               // saw end of message
	buf    [1024]byte         // leftover input
	nbuf   int                // number of bytes held in buf
	out    []byte             // leftover decoded output
	outbuf [1024 / 4 * 3]byte // backing storage for out
}
    358 
    359 func (d *decoder) Read(p []byte) (n int, err error) {
    360 	if d.err != nil {
    361 		return 0, d.err
    362 	}
    363 
    364 	// Use leftover decoded output from last read.
    365 	if len(d.out) > 0 {
    366 		n = copy(p, d.out)
    367 		d.out = d.out[n:]
    368 		return n, nil
    369 	}
    370 
    371 	// This code assumes that d.r strips supported whitespace ('\r' and '\n').
    372 
    373 	// Read a chunk.
    374 	nn := len(p) / 3 * 4
    375 	if nn < 4 {
    376 		nn = 4
    377 	}
    378 	if nn > len(d.buf) {
    379 		nn = len(d.buf)
    380 	}
    381 	nn, d.err = io.ReadAtLeast(d.r, d.buf[d.nbuf:nn], 4-d.nbuf)
    382 	d.nbuf += nn
    383 	if d.err != nil || d.nbuf < 4 {
    384 		return 0, d.err
    385 	}
    386 
    387 	// Decode chunk into p, or d.out and then p if p is too small.
    388 	nr := d.nbuf / 4 * 4
    389 	nw := d.nbuf / 4 * 3
    390 	if nw > len(p) {
    391 		nw, d.end, d.err = d.enc.decode(d.outbuf[:], d.buf[:nr])
    392 		d.out = d.outbuf[:nw]
    393 		n = copy(p, d.out)
    394 		d.out = d.out[n:]
    395 	} else {
    396 		n, d.end, d.err = d.enc.decode(p, d.buf[:nr])
    397 	}
    398 	d.nbuf -= nr
    399 	for i := 0; i < d.nbuf; i++ {
    400 		d.buf[i] = d.buf[i+nr]
    401 	}
    402 
    403 	if d.err == nil {
    404 		d.err = err
    405 	}
    406 	return n, d.err
    407 }
    408 
    409 type newlineFilteringReader struct {
    410 	wrapped io.Reader
    411 }
    412 
    413 func (r *newlineFilteringReader) Read(p []byte) (int, error) {
    414 	n, err := r.wrapped.Read(p)
    415 	for n > 0 {
    416 		offset := 0
    417 		for i, b := range p[:n] {
    418 			if b != '\r' && b != '\n' {
    419 				if i != offset {
    420 					p[offset] = b
    421 				}
    422 				offset++
    423 			}
    424 		}
    425 		if offset > 0 {
    426 			return offset, err
    427 		}
    428 		// Previous buffer entirely whitespace, read again
    429 		n, err = r.wrapped.Read(p)
    430 	}
    431 	return n, err
    432 }
    433 
    434 // NewDecoder constructs a new base64 stream decoder.
    435 func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
    436 	return &decoder{enc: enc, r: &newlineFilteringReader{r}}
    437 }
    438 
    439 // DecodedLen returns the maximum length in bytes of the decoded data
    440 // corresponding to n bytes of base64-encoded data.
    441 func (enc *Encoding) DecodedLen(n int) int {
    442 	if enc.padChar == NoPadding {
    443 		// Unpadded data may end with partial block of 2-3 characters.
    444 		return (n*6 + 7) / 8
    445 	}
    446 	// Padded base64 should always be a multiple of 4 characters in length.
    447 	return n / 4 * 3
    448 }
    449