response_writer.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. package h2quic
  2. import (
  3. "bytes"
  4. "net/http"
  5. "strconv"
  6. "strings"
  7. "sync"
  8. quic "github.com/lucas-clemente/quic-go"
  9. "github.com/lucas-clemente/quic-go/internal/protocol"
  10. "github.com/lucas-clemente/quic-go/internal/utils"
  11. "golang.org/x/net/http2"
  12. "golang.org/x/net/http2/hpack"
  13. )
  14. type responseWriter struct {
  15. dataStreamID protocol.StreamID
  16. dataStream quic.Stream
  17. headerStream quic.Stream
  18. headerStreamMutex *sync.Mutex
  19. header http.Header
  20. status int // status code passed to WriteHeader
  21. headerWritten bool
  22. logger utils.Logger
  23. }
  24. func newResponseWriter(
  25. headerStream quic.Stream,
  26. headerStreamMutex *sync.Mutex,
  27. dataStream quic.Stream,
  28. dataStreamID protocol.StreamID,
  29. logger utils.Logger,
  30. ) *responseWriter {
  31. return &responseWriter{
  32. header: http.Header{},
  33. headerStream: headerStream,
  34. headerStreamMutex: headerStreamMutex,
  35. dataStream: dataStream,
  36. dataStreamID: dataStreamID,
  37. logger: logger,
  38. }
  39. }
  40. func (w *responseWriter) Header() http.Header {
  41. return w.header
  42. }
  43. func (w *responseWriter) WriteHeader(status int) {
  44. if w.headerWritten {
  45. return
  46. }
  47. w.headerWritten = true
  48. w.status = status
  49. var headers bytes.Buffer
  50. enc := hpack.NewEncoder(&headers)
  51. enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)})
  52. for k, v := range w.header {
  53. for index := range v {
  54. enc.WriteField(hpack.HeaderField{Name: strings.ToLower(k), Value: v[index]})
  55. }
  56. }
  57. w.logger.Infof("Responding with %d", status)
  58. w.headerStreamMutex.Lock()
  59. defer w.headerStreamMutex.Unlock()
  60. h2framer := http2.NewFramer(w.headerStream, nil)
  61. err := h2framer.WriteHeaders(http2.HeadersFrameParam{
  62. StreamID: uint32(w.dataStreamID),
  63. EndHeaders: true,
  64. BlockFragment: headers.Bytes(),
  65. })
  66. if err != nil {
  67. w.logger.Errorf("could not write h2 header: %s", err.Error())
  68. }
  69. }
  70. func (w *responseWriter) Write(p []byte) (int, error) {
  71. if !w.headerWritten {
  72. w.WriteHeader(200)
  73. }
  74. if !bodyAllowedForStatus(w.status) {
  75. return 0, http.ErrBodyNotAllowed
  76. }
  77. return w.dataStream.Write(p)
  78. }
  79. func (w *responseWriter) Flush() {}
  80. // This is a NOP. Use http.Request.Context
  81. func (w *responseWriter) CloseNotify() <-chan bool { return make(<-chan bool) }
  82. // test that we implement http.Flusher
  83. var _ http.Flusher = &responseWriter{}
  // bodyAllowedForStatus reports whether a response with the given status code
// may carry a body: informational (1xx), 204 No Content and 304 Not Modified
// responses may not. Copied from http2/http2.go; see RFC 2616, section 4.4.
func bodyAllowedForStatus(status int) bool {
	if status >= 100 && status < 200 {
		return false // informational responses never have a body
	}
	switch status {
	case 204, 304: // No Content, Not Modified
		return false
	default:
		return true
	}
}