Browse Source

Refine header based websocket earlydata fix

Shelikhoo 2 years ago
parent
commit
c055a08b2c
1 changed files with 7 additions and 8 deletions
  1. 7 8
      transport/internet/websocket/hub.go

+ 7 - 8
transport/internet/websocket/hub.go

@@ -38,7 +38,8 @@ var upgrader = &websocket.Upgrader{
 }
 
 func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
-	var earlyDataStr string
+	responseHeader := http.Header{}
+
 	var earlyData io.Reader
 	if !h.earlyDataEnabled { // nolint: gocritic
 		if request.URL.Path != h.path {
@@ -50,11 +51,14 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
 			writer.WriteHeader(http.StatusNotFound)
 			return
 		}
-		earlyDataStr = request.Header.Get(h.earlyDataHeaderName)
+		earlyDataStr := request.Header.Get(h.earlyDataHeaderName)
 		earlyData = base64.NewDecoder(base64.RawURLEncoding, bytes.NewReader([]byte(earlyDataStr)))
+		if strings.EqualFold("Sec-WebSocket-Protocol", h.earlyDataHeaderName) {
+			responseHeader.Set(h.earlyDataHeaderName, earlyDataStr)
+		}
 	} else {
 		if strings.HasPrefix(request.URL.RequestURI(), h.path) {
-			earlyDataStr = request.URL.RequestURI()[len(h.path):]
+			earlyDataStr := request.URL.RequestURI()[len(h.path):]
 			earlyData = base64.NewDecoder(base64.RawURLEncoding, bytes.NewReader([]byte(earlyDataStr)))
 		} else {
 			writer.WriteHeader(http.StatusNotFound)
@@ -62,11 +66,6 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
 		}
 	}
 
-	responseHeader := http.Header{}
-	if h.earlyDataEnabled && h.earlyDataHeaderName != "" {
-		responseHeader.Set(h.earlyDataHeaderName, earlyDataStr)
-	}
-
 	conn, err := upgrader.Upgrade(writer, request, responseHeader)
 	if err != nil {
 		newError("failed to convert to WebSocket connection").Base(err).WriteToLog()