Browse Source

added browser forwarder early data

Shelikhoo 4 years ago
parent
commit
af641f3219

+ 38 - 16
transport/internet/websocket/config.pb.go

@@ -1,7 +1,7 @@
 // Code generated by protoc-gen-go. DO NOT EDIT.
 // versions:
 // 	protoc-gen-go v1.25.0
-// 	protoc        v3.15.5
+// 	protoc        v3.13.0
 // source: transport/internet/websocket/config.proto
 
 package websocket
@@ -86,9 +86,11 @@ type Config struct {
 	unknownFields protoimpl.UnknownFields
 
 	// URL path to the WebSocket service. Empty value means root(/).
-	Path                string    `protobuf:"bytes,2,opt,name=path,proto3" json:"path,omitempty"`
-	Header              []*Header `protobuf:"bytes,3,rep,name=header,proto3" json:"header,omitempty"`
-	AcceptProxyProtocol bool      `protobuf:"varint,4,opt,name=accept_proxy_protocol,json=acceptProxyProtocol,proto3" json:"accept_proxy_protocol,omitempty"`
+	Path                 string    `protobuf:"bytes,2,opt,name=path,proto3" json:"path,omitempty"`
+	Header               []*Header `protobuf:"bytes,3,rep,name=header,proto3" json:"header,omitempty"`
+	AcceptProxyProtocol  bool      `protobuf:"varint,4,opt,name=accept_proxy_protocol,json=acceptProxyProtocol,proto3" json:"accept_proxy_protocol,omitempty"`
+	MaxEarlyData         int32     `protobuf:"varint,5,opt,name=max_early_data,json=maxEarlyData,proto3" json:"max_early_data,omitempty"`
+	UseBrowserForwarding bool      `protobuf:"varint,6,opt,name=use_browser_forwarding,json=useBrowserForwarding,proto3" json:"use_browser_forwarding,omitempty"`
 }
 
 func (x *Config) Reset() {
@@ -144,6 +146,20 @@ func (x *Config) GetAcceptProxyProtocol() bool {
 	return false
 }
 
+func (x *Config) GetMaxEarlyData() int32 {
+	if x != nil {
+		return x.MaxEarlyData
+	}
+	return 0
+}
+
+func (x *Config) GetUseBrowserForwarding() bool {
+	if x != nil {
+		return x.UseBrowserForwarding
+	}
+	return false
+}
+
 var File_transport_internet_websocket_config_proto protoreflect.FileDescriptor
 
 var file_transport_internet_websocket_config_proto_rawDesc = []byte{
@@ -155,7 +171,7 @@ var file_transport_internet_websocket_config_proto_rawDesc = []byte{
 	0x63, 0x6b, 0x65, 0x74, 0x22, 0x30, 0x0a, 0x06, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x12, 0x10,
 	0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79,
 	0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52,
-	0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x22, 0x9f, 0x01, 0x0a, 0x06, 0x43, 0x6f, 0x6e, 0x66, 0x69,
+	0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x22, 0xfb, 0x01, 0x0a, 0x06, 0x43, 0x6f, 0x6e, 0x66, 0x69,
 	0x67, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52,
 	0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x47, 0x0a, 0x06, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x18,
 	0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2f, 0x2e, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f,
@@ -165,17 +181,23 @@ var file_transport_internet_websocket_config_proto_rawDesc = []byte{
 	0x0a, 0x15, 0x61, 0x63, 0x63, 0x65, 0x70, 0x74, 0x5f, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x5f, 0x70,
 	0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x61,
 	0x63, 0x63, 0x65, 0x70, 0x74, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63,
-	0x6f, 0x6c, 0x4a, 0x04, 0x08, 0x01, 0x10, 0x02, 0x42, 0x96, 0x01, 0x0a, 0x2b, 0x63, 0x6f, 0x6d,
-	0x2e, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x74, 0x72, 0x61, 0x6e,
-	0x73, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x2e, 0x77,
-	0x65, 0x62, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x50, 0x01, 0x5a, 0x3b, 0x67, 0x69, 0x74, 0x68,
-	0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x76, 0x32, 0x66, 0x6c, 0x79, 0x2f, 0x76, 0x32, 0x72,
-	0x61, 0x79, 0x2d, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x76, 0x34, 0x2f, 0x74, 0x72, 0x61, 0x6e, 0x73,
-	0x70, 0x6f, 0x72, 0x74, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x2f, 0x77, 0x65,
-	0x62, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0xaa, 0x02, 0x27, 0x56, 0x32, 0x52, 0x61, 0x79, 0x2e,
-	0x43, 0x6f, 0x72, 0x65, 0x2e, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x49,
-	0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x2e, 0x57, 0x65, 0x62, 0x73, 0x6f, 0x63, 0x6b, 0x65,
-	0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
+	0x6f, 0x6c, 0x12, 0x24, 0x0a, 0x0e, 0x6d, 0x61, 0x78, 0x5f, 0x65, 0x61, 0x72, 0x6c, 0x79, 0x5f,
+	0x64, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0c, 0x6d, 0x61, 0x78, 0x45,
+	0x61, 0x72, 0x6c, 0x79, 0x44, 0x61, 0x74, 0x61, 0x12, 0x34, 0x0a, 0x16, 0x75, 0x73, 0x65, 0x5f,
+	0x62, 0x72, 0x6f, 0x77, 0x73, 0x65, 0x72, 0x5f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69,
+	0x6e, 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x75, 0x73, 0x65, 0x42, 0x72, 0x6f,
+	0x77, 0x73, 0x65, 0x72, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x4a, 0x04,
+	0x08, 0x01, 0x10, 0x02, 0x42, 0x96, 0x01, 0x0a, 0x2b, 0x63, 0x6f, 0x6d, 0x2e, 0x76, 0x32, 0x72,
+	0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72,
+	0x74, 0x2e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x2e, 0x77, 0x65, 0x62, 0x73, 0x6f,
+	0x63, 0x6b, 0x65, 0x74, 0x50, 0x01, 0x5a, 0x3b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63,
+	0x6f, 0x6d, 0x2f, 0x76, 0x32, 0x66, 0x6c, 0x79, 0x2f, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2d, 0x63,
+	0x6f, 0x72, 0x65, 0x2f, 0x76, 0x34, 0x2f, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74,
+	0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x2f, 0x77, 0x65, 0x62, 0x73, 0x6f, 0x63,
+	0x6b, 0x65, 0x74, 0xaa, 0x02, 0x27, 0x56, 0x32, 0x52, 0x61, 0x79, 0x2e, 0x43, 0x6f, 0x72, 0x65,
+	0x2e, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x49, 0x6e, 0x74, 0x65, 0x72,
+	0x6e, 0x65, 0x74, 0x2e, 0x57, 0x65, 0x62, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x62, 0x06, 0x70,
+	0x72, 0x6f, 0x74, 0x6f, 0x33,
 }
 
 var (

+ 4 - 0
transport/internet/websocket/config.proto

@@ -20,4 +20,8 @@ message Config {
   repeated Header header = 3;
 
   bool accept_proxy_protocol = 4;
+
+  int32 max_early_data = 5;
+
+  bool use_browser_forwarding = 6;
 }

+ 92 - 0
transport/internet/websocket/connection.go

@@ -3,6 +3,7 @@
 package websocket
 
 import (
+	"context"
 	"io"
 	"net"
 	"time"
@@ -23,6 +24,15 @@ type connection struct {
 	conn       *websocket.Conn
 	reader     io.Reader
 	remoteAddr net.Addr
+
+	shouldWait        bool
+	delayedDialFinish context.Context
+	finishedDial      context.CancelFunc
+	dialer            DelayedDialer
+}
+
+type DelayedDialer interface {
+	Dial(earlyData []byte) (*websocket.Conn, error)
 }
 
 func newConnection(conn *websocket.Conn, remoteAddr net.Addr) *connection {
@@ -32,6 +42,41 @@ func newConnection(conn *websocket.Conn, remoteAddr net.Addr) *connection {
 	}
 }
 
+func newConnectionWithEarlyData(conn *websocket.Conn, remoteAddr net.Addr, earlyData io.Reader) *connection {
+	return &connection{
+		conn:       conn,
+		remoteAddr: remoteAddr,
+		reader:     earlyData,
+	}
+}
+
+func newConnectionWithDelayedDial(dialer DelayedDialer) *connection {
+	delayedDialContext, CancellFunc := context.WithCancel(context.Background())
+	return &connection{
+		shouldWait:        true,
+		delayedDialFinish: delayedDialContext,
+		finishedDial:      CancellFunc,
+		dialer:            dialer,
+	}
+}
+
+func newRelayedConnectionWithDelayedDial(dialer DelayedDialerForwarded) *connectionForwarder {
+	delayedDialContext, CancellFunc := context.WithCancel(context.Background())
+	return &connectionForwarder{
+		shouldWait:        true,
+		delayedDialFinish: delayedDialContext,
+		finishedDial:      CancellFunc,
+		dialer:            dialer,
+	}
+}
+
+func newRelayedConnection(conn io.ReadWriteCloser) *connectionForwarder {
+	return &connectionForwarder{
+		ReadWriteCloser: conn,
+		shouldWait:      false,
+	}
+}
+
 // Read implements net.Conn.Read()
 func (c *connection) Read(b []byte) (int, error) {
 	for {
@@ -50,6 +95,12 @@ func (c *connection) Read(b []byte) (int, error) {
 }
 
 func (c *connection) getReader() (io.Reader, error) {
+	if c.shouldWait {
+		<-c.delayedDialFinish.Done()
+		if c.conn == nil {
+			return nil, newError("unable to read delayed dial websocket connection as it do not exist")
+		}
+	}
 	if c.reader != nil {
 		return c.reader, nil
 	}
@@ -64,6 +115,17 @@ func (c *connection) getReader() (io.Reader, error) {
 
 // Write implements io.Writer.
 func (c *connection) Write(b []byte) (int, error) {
+	if c.shouldWait {
+		var err error
+		c.conn, err = c.dialer.Dial(b)
+		c.finishedDial()
+		if err != nil {
+			return 0, newError("Unable to proceed with delayed write").Base(err)
+		}
+		c.remoteAddr = c.conn.RemoteAddr()
+		c.shouldWait = false
+		return len(b), nil
+	}
 	if err := c.conn.WriteMessage(websocket.BinaryMessage, b); err != nil {
 		return 0, err
 	}
@@ -78,6 +140,12 @@ func (c *connection) WriteMultiBuffer(mb buf.MultiBuffer) error {
 }
 
 func (c *connection) Close() error {
+	if c.shouldWait {
+		<-c.delayedDialFinish.Done()
+		if c.conn == nil {
+			return newError("unable to close delayed dial websocket connection as it do not exist")
+		}
+	}
 	var errors []interface{}
 	if err := c.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5)); err != nil {
 		errors = append(errors, err)
@@ -92,6 +160,16 @@ func (c *connection) Close() error {
 }
 
 func (c *connection) LocalAddr() net.Addr {
+	if c.shouldWait {
+		<-c.delayedDialFinish.Done()
+		if c.conn == nil {
+			newError("websocket transport is not materialized when LocalAddr() is called").AtWarning().WriteToLog()
+			return &net.UnixAddr{
+				Name: "@placeholder",
+				Net:  "unix",
+			}
+		}
+	}
 	return c.conn.LocalAddr()
 }
 
@@ -107,9 +185,23 @@ func (c *connection) SetDeadline(t time.Time) error {
 }
 
 func (c *connection) SetReadDeadline(t time.Time) error {
+	if c.shouldWait {
+		<-c.delayedDialFinish.Done()
+		if c.conn == nil {
+			newError("websocket transport is not materialized when SetReadDeadline() is called").AtWarning().WriteToLog()
+			return nil
+		}
+	}
 	return c.conn.SetReadDeadline(t)
 }
 
 func (c *connection) SetWriteDeadline(t time.Time) error {
+	if c.shouldWait {
+		<-c.delayedDialFinish.Done()
+		if c.conn == nil {
+			newError("websocket transport is not materialized when SetWriteDeadline() is called").AtWarning().WriteToLog()
+			return nil
+		}
+	}
 	return c.conn.SetWriteDeadline(t)
 }

+ 81 - 0
transport/internet/websocket/connforwarder.go

@@ -0,0 +1,81 @@
+package websocket
+
+import (
+	"context"
+	"io"
+	"net"
+	"time"
+)
+
+type connectionForwarder struct {
+	io.ReadWriteCloser
+
+	shouldWait        bool
+	delayedDialFinish context.Context
+	finishedDial      context.CancelFunc
+	dialer            DelayedDialerForwarded
+}
+
+func (c *connectionForwarder) Read(p []byte) (n int, err error) {
+	if c.shouldWait {
+		<-c.delayedDialFinish.Done()
+		if c.ReadWriteCloser == nil {
+			return 0, newError("unable to read delayed dial websocket connection as it do not exist")
+		}
+	}
+	return c.ReadWriteCloser.Read(p)
+}
+
+func (c *connectionForwarder) Write(p []byte) (n int, err error) {
+	if c.shouldWait {
+		var err error
+		c.ReadWriteCloser, err = c.dialer.Dial(p)
+		c.finishedDial()
+		if err != nil {
+			return 0, newError("Unable to proceed with delayed write").Base(err)
+		}
+		c.shouldWait = false
+		return len(p), nil
+	}
+	return c.ReadWriteCloser.Write(p)
+}
+
+func (c *connectionForwarder) Close() error {
+	if c.shouldWait {
+		<-c.delayedDialFinish.Done()
+		if c.ReadWriteCloser == nil {
+			return newError("unable to close delayed dial websocket connection as it do not exist")
+		}
+	}
+	return c.ReadWriteCloser.Close()
+}
+
+func (c connectionForwarder) LocalAddr() net.Addr {
+	return &net.UnixAddr{
+		Name: "not available",
+		Net:  "",
+	}
+}
+
+func (c connectionForwarder) RemoteAddr() net.Addr {
+	return &net.UnixAddr{
+		Name: "not available",
+		Net:  "",
+	}
+}
+
+func (c connectionForwarder) SetDeadline(t time.Time) error {
+	return nil
+}
+
+func (c connectionForwarder) SetReadDeadline(t time.Time) error {
+	return nil
+}
+
+func (c connectionForwarder) SetWriteDeadline(t time.Time) error {
+	return nil
+}
+
+type DelayedDialerForwarded interface {
+	Dial(earlyData []byte) (io.ReadWriteCloser, error)
+}

+ 106 - 0
transport/internet/websocket/dialer.go

@@ -3,7 +3,12 @@
 package websocket
 
 import (
+	"bytes"
 	"context"
+	"encoding/base64"
+	core "github.com/v2fly/v2ray-core/v4"
+	"github.com/v2fly/v2ray-core/v4/features/ext"
+	"io"
 	"time"
 
 	"github.com/gorilla/websocket"
@@ -55,6 +60,36 @@ func dialWebsocket(ctx context.Context, dest net.Destination, streamSettings *in
 	}
 	uri := protocol + "://" + host + wsSettings.GetNormalizedPath()
 
+	if wsSettings.UseBrowserForwarding {
+		var forwarder ext.BrowserForwarder
+		err := core.RequireFeatures(ctx, func(Forwarder ext.BrowserForwarder) {
+			forwarder = Forwarder
+		})
+		if err != nil {
+			return nil, newError("cannot find browser forwarder service").Base(err)
+		}
+		if wsSettings.MaxEarlyData != 0 {
+			return newRelayedConnectionWithDelayedDial(&dialerWithEarlyDataRelayed{
+				forwarder: forwarder,
+				uriBase:   uri,
+				config:    wsSettings,
+			}), nil
+		}
+		conn, err := forwarder.DialWebsocket(uri, nil)
+		if err != nil {
+			return nil, newError("cannot dial with browser forwarder service").Base(err)
+		}
+		return newRelayedConnection(conn), nil
+	}
+
+	if wsSettings.MaxEarlyData != 0 {
+		return newConnectionWithDelayedDial(&dialerWithEarlyData{
+			dialer:  dialer,
+			uriBase: uri,
+			config:  wsSettings,
+		}), nil
+	}
+
 	conn, resp, err := dialer.Dial(uri, wsSettings.GetRequestHeader())
 	if err != nil {
 		var reason string
@@ -66,3 +101,74 @@ func dialWebsocket(ctx context.Context, dest net.Destination, streamSettings *in
 
 	return newConnection(conn, conn.RemoteAddr()), nil
 }
+
+type dialerWithEarlyData struct {
+	dialer  *websocket.Dialer
+	uriBase string
+	config  *Config
+}
+
+func (d dialerWithEarlyData) Dial(earlyData []byte) (*websocket.Conn, error) {
+	earlyDataBuf := bytes.NewBuffer(nil)
+	base64EarlyDataEncoder := base64.NewEncoder(base64.RawURLEncoding, earlyDataBuf)
+
+	earlydata := bytes.NewReader(earlyData)
+	limitedEarlyDatareader := io.LimitReader(earlydata, int64(d.config.MaxEarlyData))
+	n, encerr := io.Copy(base64EarlyDataEncoder, limitedEarlyDatareader)
+	if encerr != nil {
+		return nil, newError("websocket delayed dialer cannot encode early data").Base(encerr)
+	}
+
+	if errc := base64EarlyDataEncoder.Close(); errc != nil {
+		return nil, newError("websocket delayed dialer cannot encode early data tail").Base(errc)
+	}
+
+	conn, resp, err := d.dialer.Dial(d.uriBase+string(earlyDataBuf.Bytes()), d.config.GetRequestHeader())
+	if err != nil {
+		var reason string
+		if resp != nil {
+			reason = resp.Status
+		}
+		return nil, newError("failed to dial to (", d.uriBase, ") with early data: ", reason).Base(err)
+	}
+	if n != int64(len(earlyData)) {
+		if errWrite := conn.WriteMessage(websocket.BinaryMessage, earlyData[n:]); errWrite != nil {
+			return nil, newError("failed to dial to (", d.uriBase, ") with early data as write of remainder early data failed: ").Base(err)
+		}
+	}
+	return conn, nil
+}
+
+type dialerWithEarlyDataRelayed struct {
+	forwarder ext.BrowserForwarder
+	uriBase   string
+	config    *Config
+}
+
+func (d dialerWithEarlyDataRelayed) Dial(earlyData []byte) (io.ReadWriteCloser, error) {
+	earlyDataBuf := bytes.NewBuffer(nil)
+	base64EarlyDataEncoder := base64.NewEncoder(base64.RawURLEncoding, earlyDataBuf)
+
+	earlydata := bytes.NewReader(earlyData)
+	limitedEarlyDatareader := io.LimitReader(earlydata, int64(d.config.MaxEarlyData))
+	n, encerr := io.Copy(base64EarlyDataEncoder, limitedEarlyDatareader)
+	if encerr != nil {
+		return nil, newError("websocket delayed dialer cannot encode early data").Base(encerr)
+	}
+
+	if errc := base64EarlyDataEncoder.Close(); errc != nil {
+		return nil, newError("websocket delayed dialer cannot encode early data tail").Base(errc)
+	}
+
+	conn, err := d.forwarder.DialWebsocket(d.uriBase+string(earlyDataBuf.Bytes()), d.config.GetRequestHeader())
+	if err != nil {
+		var reason string
+		return nil, newError("failed to dial to (", d.uriBase, ") with early data: ", reason).Base(err)
+	}
+	if n != int64(len(earlyData)) {
+		if _, errWrite := conn.Write(earlyData[n:]); errWrite != nil {
+			return nil, newError("failed to dial to (", d.uriBase, ") with early data as write of remainder early data failed: ").Base(err)
+		}
+	}
+	return conn, nil
+}