Home | History | Annotate | Download | only in oauth2
      1 // Copyright 2014 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package oauth2
      6 
      7 import (
      8 	"errors"
      9 	"io"
     10 	"net/http"
     11 	"sync"
     12 )
     13 
     14 // Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests,
     15 // wrapping a base RoundTripper and adding an Authorization header
     16 // with a token from the supplied Sources.
     17 //
     18 // Transport is a low-level mechanism. Most code will use the
     19 // higher-level Config.Client method instead.
     20 type Transport struct {
     21 	// Source supplies the token to add to outgoing requests'
     22 	// Authorization headers.
     23 	Source TokenSource
     24 
     25 	// Base is the base RoundTripper used to make HTTP requests.
     26 	// If nil, http.DefaultTransport is used.
     27 	Base http.RoundTripper
     28 
     29 	mu     sync.Mutex                      // guards modReq
     30 	modReq map[*http.Request]*http.Request // original -> modified
     31 }
     32 
     33 // RoundTrip authorizes and authenticates the request with an
     34 // access token. If no token exists or token is expired,
     35 // tries to refresh/fetch a new token.
     36 func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
     37 	if t.Source == nil {
     38 		return nil, errors.New("oauth2: Transport's Source is nil")
     39 	}
     40 	token, err := t.Source.Token()
     41 	if err != nil {
     42 		return nil, err
     43 	}
     44 
     45 	req2 := cloneRequest(req) // per RoundTripper contract
     46 	token.SetAuthHeader(req2)
     47 	t.setModReq(req, req2)
     48 	res, err := t.base().RoundTrip(req2)
     49 	if err != nil {
     50 		t.setModReq(req, nil)
     51 		return nil, err
     52 	}
     53 	res.Body = &onEOFReader{
     54 		rc: res.Body,
     55 		fn: func() { t.setModReq(req, nil) },
     56 	}
     57 	return res, nil
     58 }
     59 
     60 // CancelRequest cancels an in-flight request by closing its connection.
     61 func (t *Transport) CancelRequest(req *http.Request) {
     62 	type canceler interface {
     63 		CancelRequest(*http.Request)
     64 	}
     65 	if cr, ok := t.base().(canceler); ok {
     66 		t.mu.Lock()
     67 		modReq := t.modReq[req]
     68 		delete(t.modReq, req)
     69 		t.mu.Unlock()
     70 		cr.CancelRequest(modReq)
     71 	}
     72 }
     73 
     74 func (t *Transport) base() http.RoundTripper {
     75 	if t.Base != nil {
     76 		return t.Base
     77 	}
     78 	return http.DefaultTransport
     79 }
     80 
     81 func (t *Transport) setModReq(orig, mod *http.Request) {
     82 	t.mu.Lock()
     83 	defer t.mu.Unlock()
     84 	if t.modReq == nil {
     85 		t.modReq = make(map[*http.Request]*http.Request)
     86 	}
     87 	if mod == nil {
     88 		delete(t.modReq, orig)
     89 	} else {
     90 		t.modReq[orig] = mod
     91 	}
     92 }
     93 
     94 // cloneRequest returns a clone of the provided *http.Request.
     95 // The clone is a shallow copy of the struct and its Header map.
     96 func cloneRequest(r *http.Request) *http.Request {
     97 	// shallow copy of the struct
     98 	r2 := new(http.Request)
     99 	*r2 = *r
    100 	// deep copy of the Header
    101 	r2.Header = make(http.Header, len(r.Header))
    102 	for k, s := range r.Header {
    103 		r2.Header[k] = append([]string(nil), s...)
    104 	}
    105 	return r2
    106 }
    107 
    108 type onEOFReader struct {
    109 	rc io.ReadCloser
    110 	fn func()
    111 }
    112 
    113 func (r *onEOFReader) Read(p []byte) (n int, err error) {
    114 	n, err = r.rc.Read(p)
    115 	if err == io.EOF {
    116 		r.runFunc()
    117 	}
    118 	return
    119 }
    120 
    121 func (r *onEOFReader) Close() error {
    122 	err := r.rc.Close()
    123 	r.runFunc()
    124 	return err
    125 }
    126 
    127 func (r *onEOFReader) runFunc() {
    128 	if fn := r.fn; fn != nil {
    129 		fn()
    130 		r.fn = nil
    131 	}
    132 }
    133