Browse Source

http request decide protocol based on ALPN

Shelikhoo 2 years ago
parent
commit
b7e8554ee3

+ 13 - 9
transport/internet/request/roundtripper/httprt/httprt.go

@@ -7,10 +7,12 @@ import (
 	"context"
 	"encoding/base64"
 	"io"
+	gonet "net"
 	"net/http"
 
-	"github.com/v2fly/v2ray-core/v5/common"
+	"github.com/v2fly/v2ray-core/v5/transport/internet/transportcommon"
 
+	"github.com/v2fly/v2ray-core/v5/common"
 	"github.com/v2fly/v2ray-core/v5/common/net"
 	"github.com/v2fly/v2ray-core/v5/transport/internet/request"
 )
@@ -25,20 +27,22 @@ type httpTripperClient struct {
 	assembly request.TransportClientAssembly
 }
 
+type unimplementedBackDrop struct {
+}
+
+func (u unimplementedBackDrop) RoundTrip(r *http.Request) (*http.Response, error) {
+	return nil, newError("unimplemented")
+}
+
 func (h *httpTripperClient) OnTransportClientAssemblyReady(assembly request.TransportClientAssembly) {
 	h.assembly = assembly
 }
 
 func (h *httpTripperClient) RoundTrip(ctx context.Context, req request.Request, opts ...request.RoundTripperOption) (resp request.Response, err error) {
 	if h.httpRTT == nil {
-		h.httpRTT = &http.Transport{
-			DialContext: func(dialCtx context.Context, network, addr string) (net.Conn, error) {
-				return h.assembly.AutoImplDialer().Dial(ctx)
-			},
-			DialTLSContext: func(dialCtx context.Context, network, addr string) (net.Conn, error) {
-				return h.assembly.AutoImplDialer().Dial(ctx)
-			},
-		}
+		h.httpRTT = transportcommon.NewALPNAwareHTTPRoundTripper(ctx, func(ctx context.Context, addr string) (gonet.Conn, error) {
+			return h.assembly.AutoImplDialer().Dial(ctx)
+		}, unimplementedBackDrop{})
 	}
 
 	connectionTagStr := base64.RawURLEncoding.EncodeToString(req.ConnectionTag)

+ 2 - 1
transport/internet/request/stereotype/meek/meek.go

@@ -38,7 +38,8 @@ func meekDial(ctx context.Context, dest net.Destination, streamSettings *interne
 	}
 	httprtSetting := &httprt.ClientConfig{Http: &httprt.HTTPConfig{
 		UrlPrefix: meekSetting.Url,
-	}}
+	},
+	}
 	request := &assembly.Config{
 		Assembler:    serial.ToTypedMessage(simpleAssembler),
 		Roundtripper: serial.ToTypedMessage(httprtSetting),

+ 5 - 0
transport/internet/security/connprop.go

@@ -0,0 +1,5 @@
+package security
+
+type ConnectionApplicationProtocol interface {
+	GetConnectionApplicationProtocol() (string, error)
+}

+ 7 - 0
transport/internet/tls/tls.go

@@ -17,6 +17,13 @@ type Conn struct {
 	*tls.Conn
 }
 
+func (c *Conn) GetConnectionApplicationProtocol() (string, error) {
+	if err := c.Handshake(); err != nil {
+		return "", err
+	}
+	return c.ConnectionState().NegotiatedProtocol, nil
+}
+
 func (c *Conn) WriteMultiBuffer(mb buf.MultiBuffer) error {
 	mb = buf.Compact(mb)
 	mb, err := buf.WriteMultiBuffer(c, mb)

+ 12 - 1
transport/internet/tls/utls/utls.go

@@ -90,7 +90,18 @@ func (e Engine) Client(conn net.Conn, opts ...security.Option) (security.Conn, e
 	if err != nil {
 		return nil, newError("unable to finish utls handshake").Base(err)
 	}
-	return utlsClientConn, nil
+	return uTLSClientConnection{utlsClientConn}, nil
+}
+
+type uTLSClientConnection struct {
+	*utls.UConn
+}
+
+func (u uTLSClientConnection) GetConnectionApplicationProtocol() (string, error) {
+	if err := u.Handshake(); err != nil {
+		return "", err
+	}
+	return u.ConnectionState().NegotiatedProtocol, nil
 }
 
 func uTLSConfigFromTLSConfig(config *systls.Config) (*utls.Config, error) { // nolint: unparam

+ 215 - 0
transport/internet/transportcommon/httpDialer.go

@@ -0,0 +1,215 @@
+package transportcommon
+
+import (
+	"context"
+	"crypto/tls"
+	"errors"
+	"fmt"
+	"net"
+	"net/http"
+	"sync"
+	"time"
+
+	"github.com/v2fly/v2ray-core/v5/transport/internet/security"
+	"golang.org/x/net/http2"
+)
+
+type DialerFunc func(ctx context.Context, addr string) (net.Conn, error)
+
+// NewALPNAwareHTTPRoundTripper creates an instance of RoundTripper that dial to remote HTTPS endpoint with
+// an alternative version of TLS implementation.
+func NewALPNAwareHTTPRoundTripper(ctx context.Context, dialer DialerFunc,
+	backdropTransport http.RoundTripper) http.RoundTripper {
+	rtImpl := &alpnAwareHTTPRoundTripperImpl{
+		connectWithH1:     map[string]bool{},
+		backdropTransport: backdropTransport,
+		pendingConn:       map[pendingConnKey]*unclaimedConnection{},
+		dialer:            dialer,
+		ctx:               ctx,
+	}
+	rtImpl.init()
+	return rtImpl
+}
+
+type alpnAwareHTTPRoundTripperImpl struct {
+	accessConnectWithH1 sync.Mutex
+	connectWithH1       map[string]bool
+
+	httpsH1Transport  http.RoundTripper
+	httpsH2Transport  http.RoundTripper
+	backdropTransport http.RoundTripper
+
+	accessDialingConnection sync.Mutex
+	pendingConn             map[pendingConnKey]*unclaimedConnection
+
+	ctx    context.Context
+	dialer DialerFunc
+}
+
+type pendingConnKey struct {
+	isH2 bool
+	dest string
+}
+
+var errEAGAIN = errors.New("incorrect ALPN negotiated, try again with another ALPN")
+var errEAGAINTooMany = errors.New("incorrect ALPN negotiated")
+var errExpired = errors.New("connection have expired")
+
+func (r *alpnAwareHTTPRoundTripperImpl) RoundTrip(req *http.Request) (*http.Response, error) {
+	if req.URL.Scheme != "https" {
+		return r.backdropTransport.RoundTrip(req)
+	}
+	for retryCount := 0; retryCount < 5; retryCount++ {
+		effectivePort := req.URL.Port()
+		if effectivePort == "" {
+			effectivePort = "443"
+		}
+		if r.getShouldConnectWithH1(fmt.Sprintf("%v:%v", req.URL.Hostname(), effectivePort)) {
+			resp, err := r.httpsH1Transport.RoundTrip(req)
+			if errors.Is(err, errEAGAIN) {
+				continue
+			}
+			return resp, err
+		}
+		resp, err := r.httpsH2Transport.RoundTrip(req)
+		if errors.Is(err, errEAGAIN) {
+			continue
+		}
+		return resp, err
+	}
+	return nil, errEAGAINTooMany
+}
+
+func (r *alpnAwareHTTPRoundTripperImpl) getShouldConnectWithH1(domainName string) bool {
+	r.accessConnectWithH1.Lock()
+	defer r.accessConnectWithH1.Unlock()
+	if value, set := r.connectWithH1[domainName]; set {
+		return value
+	}
+	return false
+}
+
+func (r *alpnAwareHTTPRoundTripperImpl) setShouldConnectWithH1(domainName string) {
+	r.accessConnectWithH1.Lock()
+	defer r.accessConnectWithH1.Unlock()
+	r.connectWithH1[domainName] = true
+}
+
+func (r *alpnAwareHTTPRoundTripperImpl) clearShouldConnectWithH1(domainName string) {
+	r.accessConnectWithH1.Lock()
+	defer r.accessConnectWithH1.Unlock()
+	r.connectWithH1[domainName] = false
+}
+
+func getPendingConnectionID(dest string, alpnIsH2 bool) pendingConnKey {
+	return pendingConnKey{isH2: alpnIsH2, dest: dest}
+}
+
+func (r *alpnAwareHTTPRoundTripperImpl) putConn(addr string, alpnIsH2 bool, conn net.Conn) {
+	connId := getPendingConnectionID(addr, alpnIsH2)
+	r.pendingConn[connId] = NewUnclaimedConnection(conn, time.Minute)
+}
+func (r *alpnAwareHTTPRoundTripperImpl) getConn(addr string, alpnIsH2 bool) net.Conn {
+	connId := getPendingConnectionID(addr, alpnIsH2)
+	if conn, ok := r.pendingConn[connId]; ok {
+		delete(r.pendingConn, connId)
+		if claimedConnection, err := conn.claimConnection(); err == nil {
+			return claimedConnection
+		}
+	}
+	return nil
+}
+func (r *alpnAwareHTTPRoundTripperImpl) dialOrGetTLSWithExpectedALPN(ctx context.Context, addr string, expectedH2 bool) (net.Conn, error) {
+	r.accessDialingConnection.Lock()
+	defer r.accessDialingConnection.Unlock()
+
+	if r.getShouldConnectWithH1(addr) == expectedH2 {
+		return nil, errEAGAIN
+	}
+
+	//Get a cached connection if possible to reduce preflight connection closed without sending data
+	if gconn := r.getConn(addr, expectedH2); gconn != nil {
+		return gconn, nil
+	}
+
+	conn, err := r.dialTLS(ctx, addr)
+	if err != nil {
+		return nil, err
+	}
+
+	protocol := ""
+	if connAPLNGetter, ok := conn.(security.ConnectionApplicationProtocol); ok {
+		connectionALPN, err := connAPLNGetter.GetConnectionApplicationProtocol()
+		if err != nil {
+			return nil, newError("failed to get connection ALPN").Base(err).AtWarning()
+		}
+		protocol = connectionALPN
+	}
+
+	protocolIsH2 := protocol == http2.NextProtoTLS
+
+	if protocolIsH2 == expectedH2 {
+		return conn, err
+	}
+
+	r.putConn(addr, protocolIsH2, conn)
+
+	if protocolIsH2 {
+		r.clearShouldConnectWithH1(addr)
+	} else {
+		r.setShouldConnectWithH1(addr)
+	}
+
+	return nil, errEAGAIN
+}
+
+func (r *alpnAwareHTTPRoundTripperImpl) dialTLS(ctx context.Context, addr string) (net.Conn, error) {
+	return r.dialer(r.ctx, addr)
+}
+
+func (r *alpnAwareHTTPRoundTripperImpl) init() {
+	r.httpsH2Transport = &http2.Transport{
+		DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
+			return r.dialOrGetTLSWithExpectedALPN(context.Background(), addr, true)
+		},
+	}
+	r.httpsH1Transport = &http.Transport{
+		DialTLSContext: func(ctx context.Context, network string, addr string) (net.Conn, error) {
+			return r.dialOrGetTLSWithExpectedALPN(ctx, addr, false)
+		},
+	}
+}
+
+func NewUnclaimedConnection(conn net.Conn, expireTime time.Duration) *unclaimedConnection {
+	c := &unclaimedConnection{
+		Conn: conn,
+	}
+	time.AfterFunc(expireTime, c.tick)
+	return c
+}
+
+type unclaimedConnection struct {
+	net.Conn
+	claimed bool
+	access  sync.Mutex
+}
+
+func (c *unclaimedConnection) claimConnection() (net.Conn, error) {
+	c.access.Lock()
+	defer c.access.Unlock()
+	if !c.claimed {
+		c.claimed = true
+		return c.Conn, nil
+	}
+	return nil, errExpired
+}
+
+func (c *unclaimedConnection) tick() {
+	c.access.Lock()
+	defer c.access.Unlock()
+	if !c.claimed {
+		c.claimed = true
+		c.Conn.Close()
+		c.Conn = nil
+	}
+}