|
|
@@ -3,9 +3,13 @@
|
|
|
package websocket
|
|
|
|
|
|
import (
|
|
|
+ "bytes"
|
|
|
"context"
|
|
|
"crypto/tls"
|
|
|
+ "encoding/base64"
|
|
|
+ "io"
|
|
|
"net/http"
|
|
|
+ "strings"
|
|
|
"sync"
|
|
|
"time"
|
|
|
|
|
|
@@ -20,8 +24,9 @@ import (
|
|
|
)
|
|
|
|
|
|
type requestHandler struct {
|
|
|
- path string
|
|
|
- ln *Listener
|
|
|
+ path string
|
|
|
+ ln *Listener
|
|
|
+ earlyDataEnabled bool
|
|
|
}
|
|
|
|
|
|
var upgrader = &websocket.Upgrader{
|
|
|
@@ -34,10 +39,22 @@ var upgrader = &websocket.Upgrader{
|
|
|
}
|
|
|
|
|
|
func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
|
|
- if request.URL.Path != h.path {
|
|
|
- writer.WriteHeader(http.StatusNotFound)
|
|
|
- return
|
|
|
+ var earlyData io.Reader
|
|
|
+ if !h.earlyDataEnabled {
|
|
|
+ if request.URL.Path != h.path {
|
|
|
+ writer.WriteHeader(http.StatusNotFound)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ if strings.HasPrefix(request.URL.RequestURI(), h.path) {
|
|
|
+ earlyDataStr := request.URL.RequestURI()[len(h.path):]
|
|
|
+ earlyData = base64.NewDecoder(base64.RawURLEncoding, bytes.NewReader([]byte(earlyDataStr)))
|
|
|
+ } else {
|
|
|
+ writer.WriteHeader(http.StatusNotFound)
|
|
|
+ return
|
|
|
+ }
|
|
|
}
|
|
|
+
|
|
|
conn, err := upgrader.Upgrade(writer, request, nil)
|
|
|
if err != nil {
|
|
|
newError("failed to convert to WebSocket connection").Base(err).WriteToLog()
|
|
|
@@ -52,8 +69,12 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
|
|
|
Port: int(0),
|
|
|
}
|
|
|
}
|
|
|
+ if earlyData == nil {
|
|
|
+ h.ln.addConn(newConnection(conn, remoteAddr))
|
|
|
+ } else {
|
|
|
+ h.ln.addConn(newConnectionWithEarlyData(conn, remoteAddr, earlyData))
|
|
|
+ }
|
|
|
|
|
|
- h.ln.addConn(newConnection(conn, remoteAddr))
|
|
|
}
|
|
|
|
|
|
type Listener struct {
|
|
|
@@ -114,11 +135,16 @@ func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSet
|
|
|
}
|
|
|
|
|
|
l.listener = listener
|
|
|
+ var useEarlyData = false
|
|
|
+ if wsSettings.MaxEarlyData != 0 {
|
|
|
+ useEarlyData = true
|
|
|
+ }
|
|
|
|
|
|
l.server = http.Server{
|
|
|
Handler: &requestHandler{
|
|
|
- path: wsSettings.GetNormalizedPath(),
|
|
|
- ln: l,
|
|
|
+ path: wsSettings.GetNormalizedPath(),
|
|
|
+ ln: l,
|
|
|
+ earlyDataEnabled: useEarlyData,
|
|
|
},
|
|
|
ReadHeaderTimeout: time.Second * 4,
|
|
|
MaxHeaderBytes: 2048,
|