roundtrip_test.go 6.0 KB


  1. package h2quic
  2. import (
  3. "bytes"
  4. "crypto/tls"
  5. "errors"
  6. "io"
  7. "net/http"
  8. "time"
  9. quic "github.com/lucas-clemente/quic-go"
  10. . "github.com/onsi/ginkgo"
  11. . "github.com/onsi/gomega"
  12. )
  // mockClient is a test double for the roundTripCloser interface. It never
  // touches the network: RoundTrip echoes the request back inside an otherwise
  // empty response, and Close only records that it was called.
  type mockClient struct {
  	closed bool // set to true by Close
  }

  // RoundTrip returns a response whose Request field is req. It never fails.
  func (m *mockClient) RoundTrip(req *http.Request) (*http.Response, error) {
  	rsp := new(http.Response)
  	rsp.Request = req
  	return rsp, nil
  }

  // Close flags the client as closed and always reports success.
  func (m *mockClient) Close() error {
  	m.closed = true
  	return nil
  }
  23. var _ roundTripCloser = &mockClient{}
  // mockBody is a request body whose behaviour can be scripted by a test:
  // the data it yields, the error returned from Read, and the error returned
  // from Close are all configurable, and it remembers whether Close was called.
  type mockBody struct {
  	reader   bytes.Reader // source of the bytes handed out by Read
  	readErr  error        // if non-nil, every Read fails with this error
  	closeErr error        // value returned by Close
  	closed   bool         // set to true by Close
  }

  // Read fails with readErr when one is configured, without consuming any
  // data; otherwise it reads from the data installed via SetData.
  func (m *mockBody) Read(p []byte) (int, error) {
  	if err := m.readErr; err != nil {
  		return 0, err
  	}
  	return m.reader.Read(p)
  }

  // SetData replaces the body's content; subsequent reads start at the
  // beginning of data.
  func (m *mockBody) SetData(data []byte) {
  	m.reader.Reset(data)
  }

  // Close records that the body was closed and returns closeErr.
  func (m *mockBody) Close() error {
  	m.closed = true
  	return m.closeErr
  }
  43. // make sure the mockBody can be used as a http.Request.Body
  44. var _ io.ReadCloser = &mockBody{}
  45. var _ = Describe("RoundTripper", func() {
  46. var (
  47. rt *RoundTripper
  48. req1 *http.Request
  49. )
  50. BeforeEach(func() {
  51. rt = &RoundTripper{}
  52. var err error
  53. req1, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil)
  54. Expect(err).ToNot(HaveOccurred())
  55. })
  56. Context("dialing hosts", func() {
  57. origDialAddr := dialAddr
  58. streamOpenErr := errors.New("error opening stream")
  59. BeforeEach(func() {
  60. origDialAddr = dialAddr
  61. dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) {
  62. // return an error when trying to open a stream
  63. // we don't want to test all the dial logic here, just that dialing happens at all
  64. return &mockSession{streamOpenErr: streamOpenErr}, nil
  65. }
  66. })
  67. AfterEach(func() {
  68. dialAddr = origDialAddr
  69. })
  70. It("creates new clients", func() {
  71. req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
  72. Expect(err).ToNot(HaveOccurred())
  73. _, err = rt.RoundTrip(req)
  74. Expect(err).To(MatchError(streamOpenErr))
  75. Expect(rt.clients).To(HaveLen(1))
  76. })
  77. It("uses the quic.Config, if provided", func() {
  78. config := &quic.Config{HandshakeTimeout: time.Millisecond}
  79. var receivedConfig *quic.Config
  80. dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) {
  81. receivedConfig = config
  82. return nil, errors.New("err")
  83. }
  84. rt.QuicConfig = config
  85. rt.RoundTrip(req1)
  86. Expect(receivedConfig).To(Equal(config))
  87. })
  88. It("uses the custom dialer, if provided", func() {
  89. var dialed bool
  90. dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.Session, error) {
  91. dialed = true
  92. return nil, errors.New("err")
  93. }
  94. rt.Dial = dialer
  95. rt.RoundTrip(req1)
  96. Expect(dialed).To(BeTrue())
  97. })
  98. It("reuses existing clients", func() {
  99. req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
  100. Expect(err).ToNot(HaveOccurred())
  101. _, err = rt.RoundTrip(req)
  102. Expect(err).To(MatchError(streamOpenErr))
  103. Expect(rt.clients).To(HaveLen(1))
  104. req2, err := http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil)
  105. Expect(err).ToNot(HaveOccurred())
  106. _, err = rt.RoundTrip(req2)
  107. Expect(err).To(MatchError(streamOpenErr))
  108. Expect(rt.clients).To(HaveLen(1))
  109. })
  110. It("doesn't create new clients if RoundTripOpt.OnlyCachedConn is set", func() {
  111. req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
  112. Expect(err).ToNot(HaveOccurred())
  113. _, err = rt.RoundTripOpt(req, RoundTripOpt{OnlyCachedConn: true})
  114. Expect(err).To(MatchError(ErrNoCachedConn))
  115. })
  116. })
  117. Context("validating request", func() {
  118. It("rejects plain HTTP requests", func() {
  119. req, err := http.NewRequest("GET", "http://www.example.org/", nil)
  120. req.Body = &mockBody{}
  121. Expect(err).ToNot(HaveOccurred())
  122. _, err = rt.RoundTrip(req)
  123. Expect(err).To(MatchError("quic: unsupported protocol scheme: http"))
  124. Expect(req.Body.(*mockBody).closed).To(BeTrue())
  125. })
  126. It("rejects requests without a URL", func() {
  127. req1.URL = nil
  128. req1.Body = &mockBody{}
  129. _, err := rt.RoundTrip(req1)
  130. Expect(err).To(MatchError("quic: nil Request.URL"))
  131. Expect(req1.Body.(*mockBody).closed).To(BeTrue())
  132. })
  133. It("rejects request without a URL Host", func() {
  134. req1.URL.Host = ""
  135. req1.Body = &mockBody{}
  136. _, err := rt.RoundTrip(req1)
  137. Expect(err).To(MatchError("quic: no Host in request URL"))
  138. Expect(req1.Body.(*mockBody).closed).To(BeTrue())
  139. })
  140. It("doesn't try to close the body if the request doesn't have one", func() {
  141. req1.URL = nil
  142. Expect(req1.Body).To(BeNil())
  143. _, err := rt.RoundTrip(req1)
  144. Expect(err).To(MatchError("quic: nil Request.URL"))
  145. })
  146. It("rejects requests without a header", func() {
  147. req1.Header = nil
  148. req1.Body = &mockBody{}
  149. _, err := rt.RoundTrip(req1)
  150. Expect(err).To(MatchError("quic: nil Request.Header"))
  151. Expect(req1.Body.(*mockBody).closed).To(BeTrue())
  152. })
  153. It("rejects requests with invalid header name fields", func() {
  154. req1.Header.Add("foobär", "value")
  155. _, err := rt.RoundTrip(req1)
  156. Expect(err).To(MatchError("quic: invalid http header field name \"foobär\""))
  157. })
  158. It("rejects requests with invalid header name values", func() {
  159. req1.Header.Add("foo", string([]byte{0x7}))
  160. _, err := rt.RoundTrip(req1)
  161. Expect(err.Error()).To(ContainSubstring("quic: invalid http header field value"))
  162. })
  163. It("rejects requests with an invalid request method", func() {
  164. req1.Method = "foobär"
  165. req1.Body = &mockBody{}
  166. _, err := rt.RoundTrip(req1)
  167. Expect(err).To(MatchError("quic: invalid method \"foobär\""))
  168. Expect(req1.Body.(*mockBody).closed).To(BeTrue())
  169. })
  170. })
  171. Context("closing", func() {
  172. It("closes", func() {
  173. rt.clients = make(map[string]roundTripCloser)
  174. cl := &mockClient{}
  175. rt.clients["foo.bar"] = cl
  176. err := rt.Close()
  177. Expect(err).ToNot(HaveOccurred())
  178. Expect(len(rt.clients)).To(BeZero())
  179. Expect(cl.closed).To(BeTrue())
  180. })
  181. It("closes a RoundTripper that has never been used", func() {
  182. Expect(len(rt.clients)).To(BeZero())
  183. err := rt.Close()
  184. Expect(err).ToNot(HaveOccurred())
  185. Expect(len(rt.clients)).To(BeZero())
  186. })
  187. })
  188. })