request_writer.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. package h2quic
  2. import (
  3. "bytes"
  4. "fmt"
  5. "net/http"
  6. "strconv"
  7. "strings"
  8. "sync"
  9. "golang.org/x/net/http/httpguts"
  10. "golang.org/x/net/http2"
  11. "golang.org/x/net/http2/hpack"
  12. quic "github.com/lucas-clemente/quic-go"
  13. "github.com/lucas-clemente/quic-go/internal/protocol"
  14. "github.com/lucas-clemente/quic-go/internal/utils"
  15. )
  16. type requestWriter struct {
  17. mutex sync.Mutex
  18. headerStream quic.Stream
  19. henc *hpack.Encoder
  20. hbuf bytes.Buffer // HPACK encoder writes into this
  21. logger utils.Logger
  22. }
  23. const defaultUserAgent = "quic-go"
  24. func newRequestWriter(headerStream quic.Stream, logger utils.Logger) *requestWriter {
  25. rw := &requestWriter{
  26. headerStream: headerStream,
  27. logger: logger,
  28. }
  29. rw.henc = hpack.NewEncoder(&rw.hbuf)
  30. return rw
  31. }
  32. func (w *requestWriter) WriteRequest(req *http.Request, dataStreamID protocol.StreamID, endStream, requestGzip bool) error {
  33. // TODO: add support for trailers
  34. // TODO: add support for gzip compression
  35. // TODO: write continuation frames, if the header frame is too long
  36. w.mutex.Lock()
  37. defer w.mutex.Unlock()
  38. w.encodeHeaders(req, requestGzip, "", actualContentLength(req))
  39. h2framer := http2.NewFramer(w.headerStream, nil)
  40. return h2framer.WriteHeaders(http2.HeadersFrameParam{
  41. StreamID: uint32(dataStreamID),
  42. EndHeaders: true,
  43. EndStream: endStream,
  44. BlockFragment: w.hbuf.Bytes(),
  45. Priority: http2.PriorityParam{Weight: 0xff},
  46. })
  47. }
// The rest of this file is copied from http2.Transport (golang.org/x/net/http2).
  49. func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) {
  50. w.hbuf.Reset()
  51. host := req.Host
  52. if host == "" {
  53. host = req.URL.Host
  54. }
  55. host, err := httpguts.PunycodeHostPort(host)
  56. if err != nil {
  57. return nil, err
  58. }
  59. var path string
  60. if req.Method != "CONNECT" {
  61. path = req.URL.RequestURI()
  62. if !validPseudoPath(path) {
  63. orig := path
  64. path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
  65. if !validPseudoPath(path) {
  66. if req.URL.Opaque != "" {
  67. return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
  68. }
  69. return nil, fmt.Errorf("invalid request :path %q", orig)
  70. }
  71. }
  72. }
  73. // Check for any invalid headers and return an error before we
  74. // potentially pollute our hpack state. (We want to be able to
  75. // continue to reuse the hpack encoder for future requests)
  76. for k, vv := range req.Header {
  77. if !httpguts.ValidHeaderFieldName(k) {
  78. return nil, fmt.Errorf("invalid HTTP header name %q", k)
  79. }
  80. for _, v := range vv {
  81. if !httpguts.ValidHeaderFieldValue(v) {
  82. return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k)
  83. }
  84. }
  85. }
  86. // 8.1.2.3 Request Pseudo-Header Fields
  87. // The :path pseudo-header field includes the path and query parts of the
  88. // target URI (the path-absolute production and optionally a '?' character
  89. // followed by the query production (see Sections 3.3 and 3.4 of
  90. // [RFC3986]).
  91. w.writeHeader(":authority", host)
  92. w.writeHeader(":method", req.Method)
  93. if req.Method != "CONNECT" {
  94. w.writeHeader(":path", path)
  95. w.writeHeader(":scheme", req.URL.Scheme)
  96. }
  97. if trailers != "" {
  98. w.writeHeader("trailer", trailers)
  99. }
  100. var didUA bool
  101. for k, vv := range req.Header {
  102. lowKey := strings.ToLower(k)
  103. switch lowKey {
  104. case "host", "content-length":
  105. // Host is :authority, already sent.
  106. // Content-Length is automatic, set below.
  107. continue
  108. case "connection", "proxy-connection", "transfer-encoding", "upgrade", "keep-alive":
  109. // Per 8.1.2.2 Connection-Specific Header
  110. // Fields, don't send connection-specific
  111. // fields. We have already checked if any
  112. // are error-worthy so just ignore the rest.
  113. continue
  114. case "user-agent":
  115. // Match Go's http1 behavior: at most one
  116. // User-Agent. If set to nil or empty string,
  117. // then omit it. Otherwise if not mentioned,
  118. // include the default (below).
  119. didUA = true
  120. if len(vv) < 1 {
  121. continue
  122. }
  123. vv = vv[:1]
  124. if vv[0] == "" {
  125. continue
  126. }
  127. }
  128. for _, v := range vv {
  129. w.writeHeader(lowKey, v)
  130. }
  131. }
  132. if shouldSendReqContentLength(req.Method, contentLength) {
  133. w.writeHeader("content-length", strconv.FormatInt(contentLength, 10))
  134. }
  135. if addGzipHeader {
  136. w.writeHeader("accept-encoding", "gzip")
  137. }
  138. if !didUA {
  139. w.writeHeader("user-agent", defaultUserAgent)
  140. }
  141. return w.hbuf.Bytes(), nil
  142. }
  143. func (w *requestWriter) writeHeader(name, value string) {
  144. w.logger.Debugf("http2: Transport encoding header %q = %q", name, value)
  145. w.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
  146. }
  147. // shouldSendReqContentLength reports whether the http2.Transport should send
  148. // a "content-length" request header. This logic is basically a copy of the net/http
  149. // transferWriter.shouldSendContentLength.
  150. // The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
  151. // -1 means unknown.
  152. func shouldSendReqContentLength(method string, contentLength int64) bool {
  153. if contentLength > 0 {
  154. return true
  155. }
  156. if contentLength < 0 {
  157. return false
  158. }
  159. // For zero bodies, whether we send a content-length depends on the method.
  160. // It also kinda doesn't matter for http2 either way, with END_STREAM.
  161. switch method {
  162. case "POST", "PUT", "PATCH":
  163. return true
  164. default:
  165. return false
  166. }
  167. }
  168. func validPseudoPath(v string) bool {
  169. return (len(v) > 0 && v[0] == '/' && (len(v) == 1 || v[1] != '/')) || v == "*"
  170. }
  171. // actualContentLength returns a sanitized version of
  172. // req.ContentLength, where 0 actually means zero (not unknown) and -1
  173. // means unknown.
  174. func actualContentLength(req *http.Request) int64 {
  175. if req.Body == nil {
  176. return 0
  177. }
  178. if req.ContentLength != 0 {
  179. return req.ContentLength
  180. }
  181. return -1
  182. }