|
|
@@ -7,11 +7,13 @@ import (
|
|
|
"context"
|
|
|
"encoding/base64"
|
|
|
"io"
|
|
|
+ "net/http"
|
|
|
"time"
|
|
|
|
|
|
"github.com/v2fly/v2ray-core/v4/features/extension"
|
|
|
|
|
|
"github.com/gorilla/websocket"
|
|
|
+
|
|
|
core "github.com/v2fly/v2ray-core/v4"
|
|
|
|
|
|
"github.com/v2fly/v2ray-core/v4/common"
|
|
|
@@ -91,7 +93,7 @@ func dialWebsocket(ctx context.Context, dest net.Destination, streamSettings *in
|
|
|
}), nil
|
|
|
}
|
|
|
|
|
|
- conn, resp, err := dialer.Dial(uri, wsSettings.GetRequestHeader())
|
|
|
+ conn, resp, err := dialer.Dial(uri, wsSettings.GetRequestHeader()) // nolint: bodyclose
|
|
|
if err != nil {
|
|
|
var reason string
|
|
|
if resp != nil {
|
|
|
@@ -124,7 +126,20 @@ func (d dialerWithEarlyData) Dial(earlyData []byte) (*websocket.Conn, error) {
|
|
|
return nil, newError("websocket delayed dialer cannot encode early data tail").Base(errc)
|
|
|
}
|
|
|
|
|
|
- conn, resp, err := d.dialer.Dial(d.uriBase+earlyDataBuf.String(), d.config.GetRequestHeader())
|
|
|
+ dialFunction := func() (*websocket.Conn, *http.Response, error) {
|
|
|
+ return d.dialer.Dial(d.uriBase+earlyDataBuf.String(), d.config.GetRequestHeader())
|
|
|
+ }
|
|
|
+
|
|
|
+ if d.config.EarlyDataHeaderName != "" {
|
|
|
+ dialFunction = func() (*websocket.Conn, *http.Response, error) {
|
|
|
+ earlyDataStr := earlyDataBuf.String()
|
|
|
+ currentHeader := d.config.GetRequestHeader()
|
|
|
+ currentHeader.Set(d.config.EarlyDataHeaderName, earlyDataStr)
|
|
|
+ return d.dialer.Dial(d.uriBase, currentHeader)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ conn, resp, err := dialFunction() // nolint: bodyclose
|
|
|
if err != nil {
|
|
|
var reason string
|
|
|
if resp != nil {
|
|
|
@@ -161,7 +176,18 @@ func (d dialerWithEarlyDataRelayed) Dial(earlyData []byte) (io.ReadWriteCloser,
|
|
|
return nil, newError("websocket delayed dialer cannot encode early data tail").Base(errc)
|
|
|
}
|
|
|
|
|
|
- conn, err := d.forwarder.DialWebsocket(d.uriBase+earlyDataBuf.String(), d.config.GetRequestHeader())
|
|
|
+ dialFunction := func() (io.ReadWriteCloser, error) {
|
|
|
+ return d.forwarder.DialWebsocket(d.uriBase+earlyDataBuf.String(), d.config.GetRequestHeader())
|
|
|
+ }
|
|
|
+
|
|
|
+ if d.config.EarlyDataHeaderName != "" {
|
|
|
+ earlyDataStr := earlyDataBuf.String()
|
|
|
+ currentHeader := d.config.GetRequestHeader()
|
|
|
+ currentHeader.Set(d.config.EarlyDataHeaderName, earlyDataStr)
|
|
|
+ return d.forwarder.DialWebsocket(d.uriBase, currentHeader)
|
|
|
+ }
|
|
|
+
|
|
|
+ conn, err := dialFunction()
|
|
|
if err != nil {
|
|
|
var reason string
|
|
|
return nil, newError("failed to dial to (", d.uriBase, ") with early data: ", reason).Base(err)
|