Browse Source

fix early data listener bug

Shelikhoo 4 years ago
parent
commit
3a77bbdf65
1 changed files with 34 additions and 8 deletions
  1. 34 8
      transport/internet/websocket/hub.go

+ 34 - 8
transport/internet/websocket/hub.go

@@ -3,9 +3,13 @@
 package websocket
 
 import (
+	"bytes"
 	"context"
 	"crypto/tls"
+	"encoding/base64"
+	"io"
 	"net/http"
+	"strings"
 	"sync"
 	"time"
 
@@ -20,8 +24,9 @@ import (
 )
 
 type requestHandler struct {
-	path string
-	ln   *Listener
+	path             string
+	ln               *Listener
+	earlyDataEnabled bool
 }
 
 var upgrader = &websocket.Upgrader{
@@ -34,10 +39,22 @@ var upgrader = &websocket.Upgrader{
 }
 
 func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
-	if request.URL.Path != h.path {
-		writer.WriteHeader(http.StatusNotFound)
-		return
+	var earlyData io.Reader
+	if !h.earlyDataEnabled {
+		if request.URL.Path != h.path {
+			writer.WriteHeader(http.StatusNotFound)
+			return
+		}
+	} else {
+		if strings.HasPrefix(request.URL.RequestURI(), h.path) {
+			earlyDataStr := request.URL.RequestURI()[len(h.path):]
+			earlyData = base64.NewDecoder(base64.RawURLEncoding, bytes.NewReader([]byte(earlyDataStr)))
+		} else {
+			writer.WriteHeader(http.StatusNotFound)
+			return
+		}
 	}
+
 	conn, err := upgrader.Upgrade(writer, request, nil)
 	if err != nil {
 		newError("failed to convert to WebSocket connection").Base(err).WriteToLog()
@@ -52,8 +69,12 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
 			Port: int(0),
 		}
 	}
+	if earlyData == nil {
+		h.ln.addConn(newConnection(conn, remoteAddr))
+	} else {
+		h.ln.addConn(newConnectionWithEarlyData(conn, remoteAddr, earlyData))
+	}
 
-	h.ln.addConn(newConnection(conn, remoteAddr))
 }
 
 type Listener struct {
@@ -114,11 +135,16 @@ func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSet
 	}
 
 	l.listener = listener
+	var useEarlyData = false
+	if wsSettings.MaxEarlyData != 0 {
+		useEarlyData = true
+	}
 
 	l.server = http.Server{
 		Handler: &requestHandler{
-			path: wsSettings.GetNormalizedPath(),
-			ln:   l,
+			path:             wsSettings.GetNormalizedPath(),
+			ln:               l,
+			earlyDataEnabled: useEarlyData,
 		},
 		ReadHeaderTimeout: time.Second * 4,
 		MaxHeaderBytes:    2048,