Browse Source

added client support for header based websocket early data

Shelikhoo 4 years ago
parent
commit
54d0c3d400
1 changed files with 29 additions and 3 deletions
  1. 29 3
      transport/internet/websocket/dialer.go

+ 29 - 3
transport/internet/websocket/dialer.go

@@ -7,11 +7,13 @@ import (
 	"context"
 	"encoding/base64"
 	"io"
+	"net/http"
 	"time"
 
 	"github.com/v2fly/v2ray-core/v4/features/extension"
 
 	"github.com/gorilla/websocket"
+
 	core "github.com/v2fly/v2ray-core/v4"
 
 	"github.com/v2fly/v2ray-core/v4/common"
@@ -91,7 +93,7 @@ func dialWebsocket(ctx context.Context, dest net.Destination, streamSettings *in
 		}), nil
 	}
 
-	conn, resp, err := dialer.Dial(uri, wsSettings.GetRequestHeader())
+	conn, resp, err := dialer.Dial(uri, wsSettings.GetRequestHeader()) // nolint: bodyclose
 	if err != nil {
 		var reason string
 		if resp != nil {
@@ -124,7 +126,20 @@ func (d dialerWithEarlyData) Dial(earlyData []byte) (*websocket.Conn, error) {
 		return nil, newError("websocket delayed dialer cannot encode early data tail").Base(errc)
 	}
 
-	conn, resp, err := d.dialer.Dial(d.uriBase+earlyDataBuf.String(), d.config.GetRequestHeader())
+	dialFunction := func() (*websocket.Conn, *http.Response, error) {
+		return d.dialer.Dial(d.uriBase+earlyDataBuf.String(), d.config.GetRequestHeader())
+	}
+
+	if d.config.EarlyDataHeaderName != "" {
+		dialFunction = func() (*websocket.Conn, *http.Response, error) {
+			earlyDataStr := earlyDataBuf.String()
+			currentHeader := d.config.GetRequestHeader()
+			currentHeader.Set(d.config.EarlyDataHeaderName, earlyDataStr)
+			return d.dialer.Dial(d.uriBase, currentHeader)
+		}
+	}
+
+	conn, resp, err := dialFunction() // nolint: bodyclose
 	if err != nil {
 		var reason string
 		if resp != nil {
@@ -161,7 +176,18 @@ func (d dialerWithEarlyDataRelayed) Dial(earlyData []byte) (io.ReadWriteCloser,
 		return nil, newError("websocket delayed dialer cannot encode early data tail").Base(errc)
 	}
 
-	conn, err := d.forwarder.DialWebsocket(d.uriBase+earlyDataBuf.String(), d.config.GetRequestHeader())
+	dialFunction := func() (io.ReadWriteCloser, error) {
+		return d.forwarder.DialWebsocket(d.uriBase+earlyDataBuf.String(), d.config.GetRequestHeader())
+	}
+
+	if d.config.EarlyDataHeaderName != "" {
+		earlyDataStr := earlyDataBuf.String()
+		currentHeader := d.config.GetRequestHeader()
+		currentHeader.Set(d.config.EarlyDataHeaderName, earlyDataStr)
+		return d.forwarder.DialWebsocket(d.uriBase, currentHeader)
+	}
+
+	conn, err := dialFunction()
 	if err != nil {
 		var reason string
 		return nil, newError("failed to dial to (", d.uriBase, ") with early data: ", reason).Base(err)