Home | History | Annotate | Download | only in multipart
      1 // Copyright 2011 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package multipart
      6 
      7 import (
      8 	"bytes"
      9 	"crypto/rand"
     10 	"errors"
     11 	"fmt"
     12 	"io"
     13 	"net/textproto"
     14 	"sort"
     15 	"strings"
     16 )
     17 
     18 // A Writer generates multipart messages.
     19 type Writer struct {
     20 	w        io.Writer
     21 	boundary string
     22 	lastpart *part
     23 }
     24 
     25 // NewWriter returns a new multipart Writer with a random boundary,
     26 // writing to w.
     27 func NewWriter(w io.Writer) *Writer {
     28 	return &Writer{
     29 		w:        w,
     30 		boundary: randomBoundary(),
     31 	}
     32 }
     33 
     34 // Boundary returns the Writer's boundary.
     35 func (w *Writer) Boundary() string {
     36 	return w.boundary
     37 }
     38 
     39 // SetBoundary overrides the Writer's default randomly-generated
     40 // boundary separator with an explicit value.
     41 //
     42 // SetBoundary must be called before any parts are created, may only
     43 // contain certain ASCII characters, and must be non-empty and
     44 // at most 69 bytes long.
     45 func (w *Writer) SetBoundary(boundary string) error {
     46 	if w.lastpart != nil {
     47 		return errors.New("mime: SetBoundary called after write")
     48 	}
     49 	// rfc2046#section-5.1.1
     50 	if len(boundary) < 1 || len(boundary) > 69 {
     51 		return errors.New("mime: invalid boundary length")
     52 	}
     53 	for _, b := range boundary {
     54 		if 'A' <= b && b <= 'Z' || 'a' <= b && b <= 'z' || '0' <= b && b <= '9' {
     55 			continue
     56 		}
     57 		switch b {
     58 		case '\'', '(', ')', '+', '_', ',', '-', '.', '/', ':', '=', '?':
     59 			continue
     60 		}
     61 		return errors.New("mime: invalid boundary character")
     62 	}
     63 	w.boundary = boundary
     64 	return nil
     65 }
     66 
     67 // FormDataContentType returns the Content-Type for an HTTP
     68 // multipart/form-data with this Writer's Boundary.
     69 func (w *Writer) FormDataContentType() string {
     70 	return "multipart/form-data; boundary=" + w.boundary
     71 }
     72 
     73 func randomBoundary() string {
     74 	var buf [30]byte
     75 	_, err := io.ReadFull(rand.Reader, buf[:])
     76 	if err != nil {
     77 		panic(err)
     78 	}
     79 	return fmt.Sprintf("%x", buf[:])
     80 }
     81 
     82 // CreatePart creates a new multipart section with the provided
     83 // header. The body of the part should be written to the returned
     84 // Writer. After calling CreatePart, any previous part may no longer
     85 // be written to.
     86 func (w *Writer) CreatePart(header textproto.MIMEHeader) (io.Writer, error) {
     87 	if w.lastpart != nil {
     88 		if err := w.lastpart.close(); err != nil {
     89 			return nil, err
     90 		}
     91 	}
     92 	var b bytes.Buffer
     93 	if w.lastpart != nil {
     94 		fmt.Fprintf(&b, "\r\n--%s\r\n", w.boundary)
     95 	} else {
     96 		fmt.Fprintf(&b, "--%s\r\n", w.boundary)
     97 	}
     98 
     99 	keys := make([]string, 0, len(header))
    100 	for k := range header {
    101 		keys = append(keys, k)
    102 	}
    103 	sort.Strings(keys)
    104 	for _, k := range keys {
    105 		for _, v := range header[k] {
    106 			fmt.Fprintf(&b, "%s: %s\r\n", k, v)
    107 		}
    108 	}
    109 	fmt.Fprintf(&b, "\r\n")
    110 	_, err := io.Copy(w.w, &b)
    111 	if err != nil {
    112 		return nil, err
    113 	}
    114 	p := &part{
    115 		mw: w,
    116 	}
    117 	w.lastpart = p
    118 	return p, nil
    119 }
    120 
    121 var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"")
    122 
    123 func escapeQuotes(s string) string {
    124 	return quoteEscaper.Replace(s)
    125 }
    126 
    127 // CreateFormFile is a convenience wrapper around CreatePart. It creates
    128 // a new form-data header with the provided field name and file name.
    129 func (w *Writer) CreateFormFile(fieldname, filename string) (io.Writer, error) {
    130 	h := make(textproto.MIMEHeader)
    131 	h.Set("Content-Disposition",
    132 		fmt.Sprintf(`form-data; name="%s"; filename="%s"`,
    133 			escapeQuotes(fieldname), escapeQuotes(filename)))
    134 	h.Set("Content-Type", "application/octet-stream")
    135 	return w.CreatePart(h)
    136 }
    137 
    138 // CreateFormField calls CreatePart with a header using the
    139 // given field name.
    140 func (w *Writer) CreateFormField(fieldname string) (io.Writer, error) {
    141 	h := make(textproto.MIMEHeader)
    142 	h.Set("Content-Disposition",
    143 		fmt.Sprintf(`form-data; name="%s"`, escapeQuotes(fieldname)))
    144 	return w.CreatePart(h)
    145 }
    146 
    147 // WriteField calls CreateFormField and then writes the given value.
    148 func (w *Writer) WriteField(fieldname, value string) error {
    149 	p, err := w.CreateFormField(fieldname)
    150 	if err != nil {
    151 		return err
    152 	}
    153 	_, err = p.Write([]byte(value))
    154 	return err
    155 }
    156 
    157 // Close finishes the multipart message and writes the trailing
    158 // boundary end line to the output.
    159 func (w *Writer) Close() error {
    160 	if w.lastpart != nil {
    161 		if err := w.lastpart.close(); err != nil {
    162 			return err
    163 		}
    164 		w.lastpart = nil
    165 	}
    166 	_, err := fmt.Fprintf(w.w, "\r\n--%s--\r\n", w.boundary)
    167 	return err
    168 }
    169 
    170 type part struct {
    171 	mw     *Writer
    172 	closed bool
    173 	we     error // last error that occurred writing
    174 }
    175 
    176 func (p *part) close() error {
    177 	p.closed = true
    178 	return p.we
    179 }
    180 
    181 func (p *part) Write(d []byte) (n int, err error) {
    182 	if p.closed {
    183 		return 0, errors.New("multipart: can't write to finished part")
    184 	}
    185 	n, err = p.mw.w.Write(d)
    186 	if err != nil {
    187 		p.we = err
    188 	}
    189 	return
    190 }
    191