roundtrip.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. package h2quic
  2. import (
  3. "crypto/tls"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "strings"
  9. "sync"
  10. quic "github.com/lucas-clemente/quic-go"
  11. "golang.org/x/net/http/httpguts"
  12. )
  // roundTripCloser is the behaviour required of a per-host client: it can
  // execute HTTP requests and can be shut down when no longer needed.
  type roundTripCloser interface {
  	io.Closer
  	http.RoundTripper
  }
  17. // RoundTripper implements the http.RoundTripper interface
  18. type RoundTripper struct {
  19. mutex sync.Mutex
  20. // DisableCompression, if true, prevents the Transport from
  21. // requesting compression with an "Accept-Encoding: gzip"
  22. // request header when the Request contains no existing
  23. // Accept-Encoding value. If the Transport requests gzip on
  24. // its own and gets a gzipped response, it's transparently
  25. // decoded in the Response.Body. However, if the user
  26. // explicitly requested gzip it is not automatically
  27. // uncompressed.
  28. DisableCompression bool
  29. // TLSClientConfig specifies the TLS configuration to use with
  30. // tls.Client. If nil, the default configuration is used.
  31. TLSClientConfig *tls.Config
  32. // QuicConfig is the quic.Config used for dialing new connections.
  33. // If nil, reasonable default values will be used.
  34. QuicConfig *quic.Config
  35. // Dial specifies an optional dial function for creating QUIC
  36. // connections for requests.
  37. // If Dial is nil, quic.DialAddr will be used.
  38. Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
  39. clients map[string]roundTripCloser
  40. }
  // RoundTripOpt are options for the RoundTripper.RoundTripOpt method.
  // The zero value gives the same behaviour as RoundTripper.RoundTrip.
  type RoundTripOpt struct {
  	// OnlyCachedConn controls whether the RoundTripper may
  	// create a new QUIC connection. If set true and
  	// no cached connection is available, RoundTripOpt
  	// will return ErrNoCachedConn.
  	OnlyCachedConn bool
  }
  49. var _ roundTripCloser = &RoundTripper{}
  // ErrNoCachedConn is returned by RoundTripper.RoundTripOpt when
  // RoundTripOpt.OnlyCachedConn is set and no client for the request's
  // authority has been cached yet.
  var ErrNoCachedConn = errors.New("h2quic: no cached connection was available")
  52. // RoundTripOpt is like RoundTrip, but takes options.
  53. func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
  54. if req.URL == nil {
  55. closeRequestBody(req)
  56. return nil, errors.New("quic: nil Request.URL")
  57. }
  58. if req.URL.Host == "" {
  59. closeRequestBody(req)
  60. return nil, errors.New("quic: no Host in request URL")
  61. }
  62. if req.Header == nil {
  63. closeRequestBody(req)
  64. return nil, errors.New("quic: nil Request.Header")
  65. }
  66. if req.URL.Scheme == "https" {
  67. for k, vv := range req.Header {
  68. if !httpguts.ValidHeaderFieldName(k) {
  69. return nil, fmt.Errorf("quic: invalid http header field name %q", k)
  70. }
  71. for _, v := range vv {
  72. if !httpguts.ValidHeaderFieldValue(v) {
  73. return nil, fmt.Errorf("quic: invalid http header field value %q for key %v", v, k)
  74. }
  75. }
  76. }
  77. } else {
  78. closeRequestBody(req)
  79. return nil, fmt.Errorf("quic: unsupported protocol scheme: %s", req.URL.Scheme)
  80. }
  81. if req.Method != "" && !validMethod(req.Method) {
  82. closeRequestBody(req)
  83. return nil, fmt.Errorf("quic: invalid method %q", req.Method)
  84. }
  85. hostname := authorityAddr("https", hostnameFromRequest(req))
  86. cl, err := r.getClient(hostname, opt.OnlyCachedConn)
  87. if err != nil {
  88. return nil, err
  89. }
  90. return cl.RoundTrip(req)
  91. }
  92. // RoundTrip does a round trip.
  93. func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
  94. return r.RoundTripOpt(req, RoundTripOpt{})
  95. }
  96. func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTripper, error) {
  97. r.mutex.Lock()
  98. defer r.mutex.Unlock()
  99. if r.clients == nil {
  100. r.clients = make(map[string]roundTripCloser)
  101. }
  102. client, ok := r.clients[hostname]
  103. if !ok {
  104. if onlyCached {
  105. return nil, ErrNoCachedConn
  106. }
  107. client = newClient(
  108. hostname,
  109. r.TLSClientConfig,
  110. &roundTripperOpts{DisableCompression: r.DisableCompression},
  111. r.QuicConfig,
  112. r.Dial,
  113. )
  114. r.clients[hostname] = client
  115. }
  116. return client, nil
  117. }
  118. // Close closes the QUIC connections that this RoundTripper has used
  119. func (r *RoundTripper) Close() error {
  120. r.mutex.Lock()
  121. defer r.mutex.Unlock()
  122. for _, client := range r.clients {
  123. if err := client.Close(); err != nil {
  124. return err
  125. }
  126. }
  127. r.clients = nil
  128. return nil
  129. }
  // closeRequestBody closes req.Body when one is present. It is best-effort:
  // the error from Close is deliberately discarded, because in this file it
  // is only called on paths that are already returning an error.
  func closeRequestBody(req *http.Request) {
  	if body := req.Body; body != nil {
  		_ = body.Close() // best-effort; see function comment
  	}
  }
  135. func validMethod(method string) bool {
  136. /*
  137. Method = "OPTIONS" ; Section 9.2
  138. | "GET" ; Section 9.3
  139. | "HEAD" ; Section 9.4
  140. | "POST" ; Section 9.5
  141. | "PUT" ; Section 9.6
  142. | "DELETE" ; Section 9.7
  143. | "TRACE" ; Section 9.8
  144. | "CONNECT" ; Section 9.9
  145. | extension-method
  146. extension-method = token
  147. token = 1*<any CHAR except CTLs or separators>
  148. */
  149. return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
  150. }
  151. // copied from net/http/http.go
  152. func isNotToken(r rune) bool {
  153. return !httpguts.IsTokenRune(r)
  154. }