client.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. package h2quic
  2. import (
  3. "crypto/tls"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net"
  8. "net/http"
  9. "strings"
  10. "sync"
  11. "golang.org/x/net/http2"
  12. "golang.org/x/net/http2/hpack"
  13. "golang.org/x/net/idna"
  14. quic "github.com/lucas-clemente/quic-go"
  15. "github.com/lucas-clemente/quic-go/internal/protocol"
  16. "github.com/lucas-clemente/quic-go/internal/utils"
  17. "github.com/lucas-clemente/quic-go/qerr"
  18. )
  // roundTripperOpts holds per-client options that RoundTrip consults.
type roundTripperOpts struct {
	// DisableCompression stops RoundTrip from requesting gzip on the caller's
	// behalf and from transparently decompressing gzip-encoded responses.
	DisableCompression bool
}
  22. var dialAddr = quic.DialAddr
  23. // client is a HTTP2 client doing QUIC requests
  24. type client struct {
  25. mutex sync.RWMutex
  26. tlsConf *tls.Config
  27. config *quic.Config
  28. opts *roundTripperOpts
  29. hostname string
  30. handshakeErr error
  31. dialOnce sync.Once
  32. dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
  33. session quic.Session
  34. headerStream quic.Stream
  35. headerErr *qerr.QuicError
  36. headerErrored chan struct{} // this channel is closed if an error occurs on the header stream
  37. requestWriter *requestWriter
  38. responses map[protocol.StreamID]chan *http.Response
  39. logger utils.Logger
  40. }
  41. var _ http.RoundTripper = &client{}
  42. var defaultQuicConfig = &quic.Config{
  43. RequestConnectionIDOmission: true,
  44. KeepAlive: true,
  45. }
  46. // newClient creates a new client
  47. func newClient(
  48. hostname string,
  49. tlsConfig *tls.Config,
  50. opts *roundTripperOpts,
  51. quicConfig *quic.Config,
  52. dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
  53. ) *client {
  54. config := defaultQuicConfig
  55. if quicConfig != nil {
  56. config = quicConfig
  57. }
  58. return &client{
  59. hostname: authorityAddr("https", hostname),
  60. responses: make(map[protocol.StreamID]chan *http.Response),
  61. tlsConf: tlsConfig,
  62. config: config,
  63. opts: opts,
  64. headerErrored: make(chan struct{}),
  65. dialer: dialer,
  66. logger: utils.DefaultLogger.WithPrefix("client"),
  67. }
  68. }
  69. // dial dials the connection
  70. func (c *client) dial() error {
  71. var err error
  72. if c.dialer != nil {
  73. c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
  74. } else {
  75. c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
  76. }
  77. if err != nil {
  78. return err
  79. }
  80. // once the version has been negotiated, open the header stream
  81. c.headerStream, err = c.session.OpenStream()
  82. if err != nil {
  83. return err
  84. }
  85. c.requestWriter = newRequestWriter(c.headerStream, c.logger)
  86. go c.handleHeaderStream()
  87. return nil
  88. }
  89. func (c *client) handleHeaderStream() {
  90. decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
  91. h2framer := http2.NewFramer(nil, c.headerStream)
  92. var err error
  93. for err == nil {
  94. err = c.readResponse(h2framer, decoder)
  95. }
  96. if quicErr, ok := err.(*qerr.QuicError); !ok || quicErr.ErrorCode != qerr.PeerGoingAway {
  97. c.logger.Debugf("Error handling header stream: %s", err)
  98. }
  99. c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error())
  100. // stop all running request
  101. close(c.headerErrored)
  102. }
  103. func (c *client) readResponse(h2framer *http2.Framer, decoder *hpack.Decoder) error {
  104. frame, err := h2framer.ReadFrame()
  105. if err != nil {
  106. return err
  107. }
  108. hframe, ok := frame.(*http2.HeadersFrame)
  109. if !ok {
  110. return errors.New("not a headers frame")
  111. }
  112. mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
  113. mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
  114. if err != nil {
  115. return fmt.Errorf("cannot read header fields: %s", err.Error())
  116. }
  117. c.mutex.RLock()
  118. responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
  119. c.mutex.RUnlock()
  120. if !ok {
  121. return fmt.Errorf("response channel for stream %d not found", hframe.StreamID)
  122. }
  123. rsp, err := responseFromHeaders(mhframe)
  124. if err != nil {
  125. return err
  126. }
  127. responseChan <- rsp
  128. return nil
  129. }
  130. // Roundtrip executes a request and returns a response
  131. func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
  132. // TODO: add port to address, if it doesn't have one
  133. if req.URL.Scheme != "https" {
  134. return nil, errors.New("quic http2: unsupported scheme")
  135. }
  136. if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
  137. return nil, fmt.Errorf("h2quic Client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
  138. }
  139. c.dialOnce.Do(func() {
  140. c.handshakeErr = c.dial()
  141. })
  142. if c.handshakeErr != nil {
  143. return nil, c.handshakeErr
  144. }
  145. hasBody := (req.Body != nil)
  146. responseChan := make(chan *http.Response)
  147. dataStream, err := c.session.OpenStreamSync()
  148. if err != nil {
  149. _ = c.closeWithError(err)
  150. return nil, err
  151. }
  152. c.mutex.Lock()
  153. c.responses[dataStream.StreamID()] = responseChan
  154. c.mutex.Unlock()
  155. var requestedGzip bool
  156. if !c.opts.DisableCompression && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" {
  157. requestedGzip = true
  158. }
  159. // TODO: add support for trailers
  160. endStream := !hasBody
  161. err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip)
  162. if err != nil {
  163. _ = c.closeWithError(err)
  164. return nil, err
  165. }
  166. resc := make(chan error, 1)
  167. if hasBody {
  168. go func() {
  169. resc <- c.writeRequestBody(dataStream, req.Body)
  170. }()
  171. }
  172. var res *http.Response
  173. var receivedResponse bool
  174. var bodySent bool
  175. if !hasBody {
  176. bodySent = true
  177. }
  178. ctx := req.Context()
  179. for !(bodySent && receivedResponse) {
  180. select {
  181. case res = <-responseChan:
  182. receivedResponse = true
  183. c.mutex.Lock()
  184. delete(c.responses, dataStream.StreamID())
  185. c.mutex.Unlock()
  186. case err := <-resc:
  187. bodySent = true
  188. if err != nil {
  189. return nil, err
  190. }
  191. case <-ctx.Done():
  192. // error code 6 signals that stream was canceled
  193. dataStream.CancelRead(6)
  194. dataStream.CancelWrite(6)
  195. c.mutex.Lock()
  196. delete(c.responses, dataStream.StreamID())
  197. c.mutex.Unlock()
  198. return nil, ctx.Err()
  199. case <-c.headerErrored:
  200. // an error occurred on the header stream
  201. _ = c.closeWithError(c.headerErr)
  202. return nil, c.headerErr
  203. }
  204. }
  205. // TODO: correctly set this variable
  206. var streamEnded bool
  207. isHead := (req.Method == "HEAD")
  208. res = setLength(res, isHead, streamEnded)
  209. if streamEnded || isHead {
  210. res.Body = noBody
  211. } else {
  212. res.Body = dataStream
  213. if requestedGzip && res.Header.Get("Content-Encoding") == "gzip" {
  214. res.Header.Del("Content-Encoding")
  215. res.Header.Del("Content-Length")
  216. res.ContentLength = -1
  217. res.Body = &gzipReader{body: res.Body}
  218. res.Uncompressed = true
  219. }
  220. }
  221. res.Request = req
  222. return res, nil
  223. }
  224. func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (err error) {
  225. defer func() {
  226. cerr := body.Close()
  227. if err == nil {
  228. // TODO: what to do with dataStream here? Maybe reset it?
  229. err = cerr
  230. }
  231. }()
  232. _, err = io.Copy(dataStream, body)
  233. if err != nil {
  234. // TODO: what to do with dataStream here? Maybe reset it?
  235. return err
  236. }
  237. return dataStream.Close()
  238. }
  239. func (c *client) closeWithError(e error) error {
  240. if c.session == nil {
  241. return nil
  242. }
  243. return c.session.CloseWithError(quic.ErrorCode(qerr.InternalError), e)
  244. }
  245. // Close closes the client
  246. func (c *client) Close() error {
  247. if c.session == nil {
  248. return nil
  249. }
  250. return c.session.Close()
  251. }
  252. // copied from net/transport.go
  253. // authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
  254. // and returns a host:port. The port 443 is added if needed.
  255. func authorityAddr(scheme string, authority string) (addr string) {
  256. host, port, err := net.SplitHostPort(authority)
  257. if err != nil { // authority didn't have a port
  258. port = "443"
  259. if scheme == "http" {
  260. port = "80"
  261. }
  262. host = authority
  263. }
  264. if a, err := idna.ToASCII(host); err == nil {
  265. host = a
  266. }
  267. // IPv6 address literal, without a port:
  268. if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
  269. return host + ":" + port
  270. }
  271. return net.JoinHostPort(host, port)
  272. }