Home | History | Annotate | Download | only in transport
      1 /*
      2  *
      3  * Copyright 2014 gRPC authors.
      4  *
      5  * Licensed under the Apache License, Version 2.0 (the "License");
      6  * you may not use this file except in compliance with the License.
      7  * You may obtain a copy of the License at
      8  *
      9  *     http://www.apache.org/licenses/LICENSE-2.0
     10  *
     11  * Unless required by applicable law or agreed to in writing, software
     12  * distributed under the License is distributed on an "AS IS" BASIS,
     13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14  * See the License for the specific language governing permissions and
     15  * limitations under the License.
     16  *
     17  */
     18 
     19 package transport
     20 
     21 import (
     22 	"bufio"
     23 	"bytes"
     24 	"encoding/base64"
     25 	"fmt"
     26 	"net"
     27 	"net/http"
     28 	"strconv"
     29 	"strings"
     30 	"time"
     31 	"unicode/utf8"
     32 
     33 	"github.com/golang/protobuf/proto"
     34 	"golang.org/x/net/http2"
     35 	"golang.org/x/net/http2/hpack"
     36 	spb "google.golang.org/genproto/googleapis/rpc/status"
     37 	"google.golang.org/grpc/codes"
     38 	"google.golang.org/grpc/status"
     39 )
     40 
     41 const (
     42 	// http2MaxFrameLen specifies the max length of a HTTP2 frame.
     43 	http2MaxFrameLen = 16384 // 16KB frame
     44 	// http://http2.github.io/http2-spec/#SettingValues
     45 	http2InitHeaderTableSize = 4096
     46 	// http2IOBufSize specifies the buffer size for sending frames.
     47 	defaultWriteBufSize = 32 * 1024
     48 	defaultReadBufSize  = 32 * 1024
     49 	// baseContentType is the base content-type for gRPC.  This is a valid
     50 	// content-type on it's own, but can also include a content-subtype such as
     51 	// "proto" as a suffix after "+" or ";".  See
     52 	// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests
     53 	// for more details.
     54 	baseContentType = "application/grpc"
     55 )
     56 
     57 var (
     58 	clientPreface   = []byte(http2.ClientPreface)
     59 	http2ErrConvTab = map[http2.ErrCode]codes.Code{
     60 		http2.ErrCodeNo:                 codes.Internal,
     61 		http2.ErrCodeProtocol:           codes.Internal,
     62 		http2.ErrCodeInternal:           codes.Internal,
     63 		http2.ErrCodeFlowControl:        codes.ResourceExhausted,
     64 		http2.ErrCodeSettingsTimeout:    codes.Internal,
     65 		http2.ErrCodeStreamClosed:       codes.Internal,
     66 		http2.ErrCodeFrameSize:          codes.Internal,
     67 		http2.ErrCodeRefusedStream:      codes.Unavailable,
     68 		http2.ErrCodeCancel:             codes.Canceled,
     69 		http2.ErrCodeCompression:        codes.Internal,
     70 		http2.ErrCodeConnect:            codes.Internal,
     71 		http2.ErrCodeEnhanceYourCalm:    codes.ResourceExhausted,
     72 		http2.ErrCodeInadequateSecurity: codes.PermissionDenied,
     73 		http2.ErrCodeHTTP11Required:     codes.Internal,
     74 	}
     75 	statusCodeConvTab = map[codes.Code]http2.ErrCode{
     76 		codes.Internal:          http2.ErrCodeInternal,
     77 		codes.Canceled:          http2.ErrCodeCancel,
     78 		codes.Unavailable:       http2.ErrCodeRefusedStream,
     79 		codes.ResourceExhausted: http2.ErrCodeEnhanceYourCalm,
     80 		codes.PermissionDenied:  http2.ErrCodeInadequateSecurity,
     81 	}
     82 	httpStatusConvTab = map[int]codes.Code{
     83 		// 400 Bad Request - INTERNAL.
     84 		http.StatusBadRequest: codes.Internal,
     85 		// 401 Unauthorized  - UNAUTHENTICATED.
     86 		http.StatusUnauthorized: codes.Unauthenticated,
     87 		// 403 Forbidden - PERMISSION_DENIED.
     88 		http.StatusForbidden: codes.PermissionDenied,
     89 		// 404 Not Found - UNIMPLEMENTED.
     90 		http.StatusNotFound: codes.Unimplemented,
     91 		// 429 Too Many Requests - UNAVAILABLE.
     92 		http.StatusTooManyRequests: codes.Unavailable,
     93 		// 502 Bad Gateway - UNAVAILABLE.
     94 		http.StatusBadGateway: codes.Unavailable,
     95 		// 503 Service Unavailable - UNAVAILABLE.
     96 		http.StatusServiceUnavailable: codes.Unavailable,
     97 		// 504 Gateway timeout - UNAVAILABLE.
     98 		http.StatusGatewayTimeout: codes.Unavailable,
     99 	}
    100 )
    101 
    102 // Records the states during HPACK decoding. Must be reset once the
    103 // decoding of the entire headers are finished.
    104 type decodeState struct {
    105 	encoding string
    106 	// statusGen caches the stream status received from the trailer the server
    107 	// sent.  Client side only.  Do not access directly.  After all trailers are
    108 	// parsed, use the status method to retrieve the status.
    109 	statusGen *status.Status
    110 	// rawStatusCode and rawStatusMsg are set from the raw trailer fields and are not
    111 	// intended for direct access outside of parsing.
    112 	rawStatusCode *int
    113 	rawStatusMsg  string
    114 	httpStatus    *int
    115 	// Server side only fields.
    116 	timeoutSet bool
    117 	timeout    time.Duration
    118 	method     string
    119 	// key-value metadata map from the peer.
    120 	mdata          map[string][]string
    121 	statsTags      []byte
    122 	statsTrace     []byte
    123 	contentSubtype string
    124 }
    125 
    126 // isReservedHeader checks whether hdr belongs to HTTP2 headers
    127 // reserved by gRPC protocol. Any other headers are classified as the
    128 // user-specified metadata.
    129 func isReservedHeader(hdr string) bool {
    130 	if hdr != "" && hdr[0] == ':' {
    131 		return true
    132 	}
    133 	switch hdr {
    134 	case "content-type",
    135 		"user-agent",
    136 		"grpc-message-type",
    137 		"grpc-encoding",
    138 		"grpc-message",
    139 		"grpc-status",
    140 		"grpc-timeout",
    141 		"grpc-status-details-bin",
    142 		"te":
    143 		return true
    144 	default:
    145 		return false
    146 	}
    147 }
    148 
    149 // isWhitelistedHeader checks whether hdr should be propagated
    150 // into metadata visible to users.
    151 func isWhitelistedHeader(hdr string) bool {
    152 	switch hdr {
    153 	case ":authority", "user-agent":
    154 		return true
    155 	default:
    156 		return false
    157 	}
    158 }
    159 
    160 // contentSubtype returns the content-subtype for the given content-type.  The
    161 // given content-type must be a valid content-type that starts with
    162 // "application/grpc". A content-subtype will follow "application/grpc" after a
    163 // "+" or ";". See
    164 // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
    165 // more details.
    166 //
    167 // If contentType is not a valid content-type for gRPC, the boolean
    168 // will be false, otherwise true. If content-type == "application/grpc",
    169 // "application/grpc+", or "application/grpc;", the boolean will be true,
    170 // but no content-subtype will be returned.
    171 //
    172 // contentType is assumed to be lowercase already.
    173 func contentSubtype(contentType string) (string, bool) {
    174 	if contentType == baseContentType {
    175 		return "", true
    176 	}
    177 	if !strings.HasPrefix(contentType, baseContentType) {
    178 		return "", false
    179 	}
    180 	// guaranteed since != baseContentType and has baseContentType prefix
    181 	switch contentType[len(baseContentType)] {
    182 	case '+', ';':
    183 		// this will return true for "application/grpc+" or "application/grpc;"
    184 		// which the previous validContentType function tested to be valid, so we
    185 		// just say that no content-subtype is specified in this case
    186 		return contentType[len(baseContentType)+1:], true
    187 	default:
    188 		return "", false
    189 	}
    190 }
    191 
    192 // contentSubtype is assumed to be lowercase
    193 func contentType(contentSubtype string) string {
    194 	if contentSubtype == "" {
    195 		return baseContentType
    196 	}
    197 	return baseContentType + "+" + contentSubtype
    198 }
    199 
    200 func (d *decodeState) status() *status.Status {
    201 	if d.statusGen == nil {
    202 		// No status-details were provided; generate status using code/msg.
    203 		d.statusGen = status.New(codes.Code(int32(*(d.rawStatusCode))), d.rawStatusMsg)
    204 	}
    205 	return d.statusGen
    206 }
    207 
    208 const binHdrSuffix = "-bin"
    209 
    210 func encodeBinHeader(v []byte) string {
    211 	return base64.RawStdEncoding.EncodeToString(v)
    212 }
    213 
    214 func decodeBinHeader(v string) ([]byte, error) {
    215 	if len(v)%4 == 0 {
    216 		// Input was padded, or padding was not necessary.
    217 		return base64.StdEncoding.DecodeString(v)
    218 	}
    219 	return base64.RawStdEncoding.DecodeString(v)
    220 }
    221 
    222 func encodeMetadataHeader(k, v string) string {
    223 	if strings.HasSuffix(k, binHdrSuffix) {
    224 		return encodeBinHeader(([]byte)(v))
    225 	}
    226 	return v
    227 }
    228 
    229 func decodeMetadataHeader(k, v string) (string, error) {
    230 	if strings.HasSuffix(k, binHdrSuffix) {
    231 		b, err := decodeBinHeader(v)
    232 		return string(b), err
    233 	}
    234 	return v, nil
    235 }
    236 
    237 func (d *decodeState) decodeResponseHeader(frame *http2.MetaHeadersFrame) error {
    238 	for _, hf := range frame.Fields {
    239 		if err := d.processHeaderField(hf); err != nil {
    240 			return err
    241 		}
    242 	}
    243 
    244 	// If grpc status exists, no need to check further.
    245 	if d.rawStatusCode != nil || d.statusGen != nil {
    246 		return nil
    247 	}
    248 
    249 	// If grpc status doesn't exist and http status doesn't exist,
    250 	// then it's a malformed header.
    251 	if d.httpStatus == nil {
    252 		return streamErrorf(codes.Internal, "malformed header: doesn't contain status(gRPC or HTTP)")
    253 	}
    254 
    255 	if *(d.httpStatus) != http.StatusOK {
    256 		code, ok := httpStatusConvTab[*(d.httpStatus)]
    257 		if !ok {
    258 			code = codes.Unknown
    259 		}
    260 		return streamErrorf(code, http.StatusText(*(d.httpStatus)))
    261 	}
    262 
    263 	// gRPC status doesn't exist and http status is OK.
    264 	// Set rawStatusCode to be unknown and return nil error.
    265 	// So that, if the stream has ended this Unknown status
    266 	// will be propagated to the user.
    267 	// Otherwise, it will be ignored. In which case, status from
    268 	// a later trailer, that has StreamEnded flag set, is propagated.
    269 	code := int(codes.Unknown)
    270 	d.rawStatusCode = &code
    271 	return nil
    272 
    273 }
    274 
    275 func (d *decodeState) addMetadata(k, v string) {
    276 	if d.mdata == nil {
    277 		d.mdata = make(map[string][]string)
    278 	}
    279 	d.mdata[k] = append(d.mdata[k], v)
    280 }
    281 
    282 func (d *decodeState) processHeaderField(f hpack.HeaderField) error {
    283 	switch f.Name {
    284 	case "content-type":
    285 		contentSubtype, validContentType := contentSubtype(f.Value)
    286 		if !validContentType {
    287 			return streamErrorf(codes.Internal, "transport: received the unexpected content-type %q", f.Value)
    288 		}
    289 		d.contentSubtype = contentSubtype
    290 		// TODO: do we want to propagate the whole content-type in the metadata,
    291 		// or come up with a way to just propagate the content-subtype if it was set?
    292 		// ie {"content-type": "application/grpc+proto"} or {"content-subtype": "proto"}
    293 		// in the metadata?
    294 		d.addMetadata(f.Name, f.Value)
    295 	case "grpc-encoding":
    296 		d.encoding = f.Value
    297 	case "grpc-status":
    298 		code, err := strconv.Atoi(f.Value)
    299 		if err != nil {
    300 			return streamErrorf(codes.Internal, "transport: malformed grpc-status: %v", err)
    301 		}
    302 		d.rawStatusCode = &code
    303 	case "grpc-message":
    304 		d.rawStatusMsg = decodeGrpcMessage(f.Value)
    305 	case "grpc-status-details-bin":
    306 		v, err := decodeBinHeader(f.Value)
    307 		if err != nil {
    308 			return streamErrorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err)
    309 		}
    310 		s := &spb.Status{}
    311 		if err := proto.Unmarshal(v, s); err != nil {
    312 			return streamErrorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err)
    313 		}
    314 		d.statusGen = status.FromProto(s)
    315 	case "grpc-timeout":
    316 		d.timeoutSet = true
    317 		var err error
    318 		if d.timeout, err = decodeTimeout(f.Value); err != nil {
    319 			return streamErrorf(codes.Internal, "transport: malformed time-out: %v", err)
    320 		}
    321 	case ":path":
    322 		d.method = f.Value
    323 	case ":status":
    324 		code, err := strconv.Atoi(f.Value)
    325 		if err != nil {
    326 			return streamErrorf(codes.Internal, "transport: malformed http-status: %v", err)
    327 		}
    328 		d.httpStatus = &code
    329 	case "grpc-tags-bin":
    330 		v, err := decodeBinHeader(f.Value)
    331 		if err != nil {
    332 			return streamErrorf(codes.Internal, "transport: malformed grpc-tags-bin: %v", err)
    333 		}
    334 		d.statsTags = v
    335 		d.addMetadata(f.Name, string(v))
    336 	case "grpc-trace-bin":
    337 		v, err := decodeBinHeader(f.Value)
    338 		if err != nil {
    339 			return streamErrorf(codes.Internal, "transport: malformed grpc-trace-bin: %v", err)
    340 		}
    341 		d.statsTrace = v
    342 		d.addMetadata(f.Name, string(v))
    343 	default:
    344 		if isReservedHeader(f.Name) && !isWhitelistedHeader(f.Name) {
    345 			break
    346 		}
    347 		v, err := decodeMetadataHeader(f.Name, f.Value)
    348 		if err != nil {
    349 			errorf("Failed to decode metadata header (%q, %q): %v", f.Name, f.Value, err)
    350 			return nil
    351 		}
    352 		d.addMetadata(f.Name, v)
    353 	}
    354 	return nil
    355 }
    356 
    357 type timeoutUnit uint8
    358 
    359 const (
    360 	hour        timeoutUnit = 'H'
    361 	minute      timeoutUnit = 'M'
    362 	second      timeoutUnit = 'S'
    363 	millisecond timeoutUnit = 'm'
    364 	microsecond timeoutUnit = 'u'
    365 	nanosecond  timeoutUnit = 'n'
    366 )
    367 
    368 func timeoutUnitToDuration(u timeoutUnit) (d time.Duration, ok bool) {
    369 	switch u {
    370 	case hour:
    371 		return time.Hour, true
    372 	case minute:
    373 		return time.Minute, true
    374 	case second:
    375 		return time.Second, true
    376 	case millisecond:
    377 		return time.Millisecond, true
    378 	case microsecond:
    379 		return time.Microsecond, true
    380 	case nanosecond:
    381 		return time.Nanosecond, true
    382 	default:
    383 	}
    384 	return
    385 }
    386 
    387 const maxTimeoutValue int64 = 100000000 - 1
    388 
    389 // div does integer division and round-up the result. Note that this is
    390 // equivalent to (d+r-1)/r but has less chance to overflow.
    391 func div(d, r time.Duration) int64 {
    392 	if m := d % r; m > 0 {
    393 		return int64(d/r + 1)
    394 	}
    395 	return int64(d / r)
    396 }
    397 
    398 // TODO(zhaoq): It is the simplistic and not bandwidth efficient. Improve it.
    399 func encodeTimeout(t time.Duration) string {
    400 	if t <= 0 {
    401 		return "0n"
    402 	}
    403 	if d := div(t, time.Nanosecond); d <= maxTimeoutValue {
    404 		return strconv.FormatInt(d, 10) + "n"
    405 	}
    406 	if d := div(t, time.Microsecond); d <= maxTimeoutValue {
    407 		return strconv.FormatInt(d, 10) + "u"
    408 	}
    409 	if d := div(t, time.Millisecond); d <= maxTimeoutValue {
    410 		return strconv.FormatInt(d, 10) + "m"
    411 	}
    412 	if d := div(t, time.Second); d <= maxTimeoutValue {
    413 		return strconv.FormatInt(d, 10) + "S"
    414 	}
    415 	if d := div(t, time.Minute); d <= maxTimeoutValue {
    416 		return strconv.FormatInt(d, 10) + "M"
    417 	}
    418 	// Note that maxTimeoutValue * time.Hour > MaxInt64.
    419 	return strconv.FormatInt(div(t, time.Hour), 10) + "H"
    420 }
    421 
    422 func decodeTimeout(s string) (time.Duration, error) {
    423 	size := len(s)
    424 	if size < 2 {
    425 		return 0, fmt.Errorf("transport: timeout string is too short: %q", s)
    426 	}
    427 	unit := timeoutUnit(s[size-1])
    428 	d, ok := timeoutUnitToDuration(unit)
    429 	if !ok {
    430 		return 0, fmt.Errorf("transport: timeout unit is not recognized: %q", s)
    431 	}
    432 	t, err := strconv.ParseInt(s[:size-1], 10, 64)
    433 	if err != nil {
    434 		return 0, err
    435 	}
    436 	return d * time.Duration(t), nil
    437 }
    438 
    439 const (
    440 	spaceByte   = ' '
    441 	tildeByte   = '~'
    442 	percentByte = '%'
    443 )
    444 
    445 // encodeGrpcMessage is used to encode status code in header field
    446 // "grpc-message". It does percent encoding and also replaces invalid utf-8
    447 // characters with Unicode replacement character.
    448 //
    449 // It checks to see if each individual byte in msg is an allowable byte, and
    450 // then either percent encoding or passing it through. When percent encoding,
    451 // the byte is converted into hexadecimal notation with a '%' prepended.
    452 func encodeGrpcMessage(msg string) string {
    453 	if msg == "" {
    454 		return ""
    455 	}
    456 	lenMsg := len(msg)
    457 	for i := 0; i < lenMsg; i++ {
    458 		c := msg[i]
    459 		if !(c >= spaceByte && c <= tildeByte && c != percentByte) {
    460 			return encodeGrpcMessageUnchecked(msg)
    461 		}
    462 	}
    463 	return msg
    464 }
    465 
    466 func encodeGrpcMessageUnchecked(msg string) string {
    467 	var buf bytes.Buffer
    468 	for len(msg) > 0 {
    469 		r, size := utf8.DecodeRuneInString(msg)
    470 		for _, b := range []byte(string(r)) {
    471 			if size > 1 {
    472 				// If size > 1, r is not ascii. Always do percent encoding.
    473 				buf.WriteString(fmt.Sprintf("%%%02X", b))
    474 				continue
    475 			}
    476 
    477 			// The for loop is necessary even if size == 1. r could be
    478 			// utf8.RuneError.
    479 			//
    480 			// fmt.Sprintf("%%%02X", utf8.RuneError) gives "%FFFD".
    481 			if b >= spaceByte && b <= tildeByte && b != percentByte {
    482 				buf.WriteByte(b)
    483 			} else {
    484 				buf.WriteString(fmt.Sprintf("%%%02X", b))
    485 			}
    486 		}
    487 		msg = msg[size:]
    488 	}
    489 	return buf.String()
    490 }
    491 
    492 // decodeGrpcMessage decodes the msg encoded by encodeGrpcMessage.
    493 func decodeGrpcMessage(msg string) string {
    494 	if msg == "" {
    495 		return ""
    496 	}
    497 	lenMsg := len(msg)
    498 	for i := 0; i < lenMsg; i++ {
    499 		if msg[i] == percentByte && i+2 < lenMsg {
    500 			return decodeGrpcMessageUnchecked(msg)
    501 		}
    502 	}
    503 	return msg
    504 }
    505 
    506 func decodeGrpcMessageUnchecked(msg string) string {
    507 	var buf bytes.Buffer
    508 	lenMsg := len(msg)
    509 	for i := 0; i < lenMsg; i++ {
    510 		c := msg[i]
    511 		if c == percentByte && i+2 < lenMsg {
    512 			parsed, err := strconv.ParseUint(msg[i+1:i+3], 16, 8)
    513 			if err != nil {
    514 				buf.WriteByte(c)
    515 			} else {
    516 				buf.WriteByte(byte(parsed))
    517 				i += 2
    518 			}
    519 		} else {
    520 			buf.WriteByte(c)
    521 		}
    522 	}
    523 	return buf.String()
    524 }
    525 
    526 type bufWriter struct {
    527 	buf       []byte
    528 	offset    int
    529 	batchSize int
    530 	conn      net.Conn
    531 	err       error
    532 
    533 	onFlush func()
    534 }
    535 
    536 func newBufWriter(conn net.Conn, batchSize int) *bufWriter {
    537 	return &bufWriter{
    538 		buf:       make([]byte, batchSize*2),
    539 		batchSize: batchSize,
    540 		conn:      conn,
    541 	}
    542 }
    543 
    544 func (w *bufWriter) Write(b []byte) (n int, err error) {
    545 	if w.err != nil {
    546 		return 0, w.err
    547 	}
    548 	for len(b) > 0 {
    549 		nn := copy(w.buf[w.offset:], b)
    550 		b = b[nn:]
    551 		w.offset += nn
    552 		n += nn
    553 		if w.offset >= w.batchSize {
    554 			err = w.Flush()
    555 		}
    556 	}
    557 	return n, err
    558 }
    559 
    560 func (w *bufWriter) Flush() error {
    561 	if w.err != nil {
    562 		return w.err
    563 	}
    564 	if w.offset == 0 {
    565 		return nil
    566 	}
    567 	if w.onFlush != nil {
    568 		w.onFlush()
    569 	}
    570 	_, w.err = w.conn.Write(w.buf[:w.offset])
    571 	w.offset = 0
    572 	return w.err
    573 }
    574 
    575 type framer struct {
    576 	writer *bufWriter
    577 	fr     *http2.Framer
    578 }
    579 
    580 func newFramer(conn net.Conn, writeBufferSize, readBufferSize int) *framer {
    581 	r := bufio.NewReaderSize(conn, readBufferSize)
    582 	w := newBufWriter(conn, writeBufferSize)
    583 	f := &framer{
    584 		writer: w,
    585 		fr:     http2.NewFramer(w, r),
    586 	}
    587 	// Opt-in to Frame reuse API on framer to reduce garbage.
    588 	// Frames aren't safe to read from after a subsequent call to ReadFrame.
    589 	f.fr.SetReuseFrames()
    590 	f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil)
    591 	return f
    592 }
    593