Browse Source

Add h1SkipWaitForReply Option to http proxy protocol

Shelikhoo 2 years ago
parent
commit
cfc6bd465b

+ 68 - 26
proxy/http/client.go

@@ -2,12 +2,14 @@ package http
 
 import (
 	"bufio"
+	"bytes"
 	"context"
 	"encoding/base64"
 	"io"
 	"net/http"
 	"net/url"
 	"sync"
+	"time"
 
 	"golang.org/x/net/http2"
 
@@ -29,8 +31,9 @@ import (
 )
 
 type Client struct {
-	serverPicker  protocol.ServerPicker
-	policyManager policy.Manager
+	serverPicker       protocol.ServerPicker
+	policyManager      policy.Manager
+	h1SkipWaitForReply bool
 }
 
 type h2Conn struct {
@@ -59,8 +62,9 @@ func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) {
 
 	v := core.MustFromContext(ctx)
 	return &Client{
-		serverPicker:  protocol.NewRoundRobinServerPicker(serverList),
-		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
+		serverPicker:       protocol.NewRoundRobinServerPicker(serverList),
+		policyManager:      v.GetFeature(policy.ManagerType()).(policy.Manager),
+		h1SkipWaitForReply: config.H1SkipWaitForReply,
 	}, nil
 }
 
@@ -87,7 +91,13 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
 		// transmitted together. Note we should not get stuck here, as the payload may
 		// not exist (considering to access MySQL database via a HTTP proxy, where the
 		// server sends hello to the client first).
-		if mbuf, _ := reader.ReadMultiBufferTimeout(proxy.FirstPayloadTimeout); mbuf != nil {
+		waitTime := proxy.FirstPayloadTimeout
+		if c.h1SkipWaitForReply {
+			// Some server require first write to be present in client hello.
+			// Increase timeout to if the client have explicitly requested to skip waiting for reply.
+			waitTime = time.Second
+		}
+		if mbuf, _ := reader.ReadMultiBufferTimeout(waitTime); mbuf != nil {
 			mlen := mbuf.Len()
 			firstPayload = bytespool.Alloc(mlen)
 			mbuf, _ = buf.SplitBytes(mbuf, firstPayload)
@@ -103,14 +113,19 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
 		dest := server.Destination()
 		user = server.PickUser()
 
-		netConn, err := setUpHTTPTunnel(ctx, dest, targetAddr, user, dialer, firstPayload)
+		netConn, firstResp, err := setUpHTTPTunnel(ctx, dest, targetAddr, user, dialer, firstPayload, c.h1SkipWaitForReply)
 		if netConn != nil {
-			if _, ok := netConn.(*http2Conn); !ok {
+			if _, ok := netConn.(*http2Conn); !ok && !c.h1SkipWaitForReply {
 				if _, err := netConn.Write(firstPayload); err != nil {
 					netConn.Close()
 					return err
 				}
 			}
+			if firstResp != nil {
+				if err := link.Writer.WriteMultiBuffer(firstResp); err != nil {
+					return err
+				}
+			}
 			conn = internet.Connection(netConn)
 		}
 		return err
@@ -150,7 +165,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
 }
 
 // setUpHTTPTunnel will create a socket tunnel via HTTP CONNECT method
-func setUpHTTPTunnel(ctx context.Context, dest net.Destination, target string, user *protocol.MemoryUser, dialer internet.Dialer, firstPayload []byte) (net.Conn, error) {
+func setUpHTTPTunnel(ctx context.Context, dest net.Destination, target string, user *protocol.MemoryUser, dialer internet.Dialer, firstPayload []byte, writeFirstPayloadInH1 bool,
+) (net.Conn, buf.MultiBuffer, error) {
 	req := &http.Request{
 		Method: http.MethodConnect,
 		URL:    &url.URL{Host: target},
@@ -164,27 +180,53 @@ func setUpHTTPTunnel(ctx context.Context, dest net.Destination, target string, u
 		req.Header.Set("Proxy-Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
 	}
 
-	connectHTTP1 := func(rawConn net.Conn) (net.Conn, error) {
+	connectHTTP1 := func(rawConn net.Conn) (net.Conn, buf.MultiBuffer, error) {
 		req.Header.Set("Proxy-Connection", "Keep-Alive")
 
-		err := req.Write(rawConn)
-		if err != nil {
-			rawConn.Close()
-			return nil, err
+		if !writeFirstPayloadInH1 {
+			err := req.Write(rawConn)
+			if err != nil {
+				rawConn.Close()
+				return nil, nil, err
+			}
+		} else {
+			buffer := bytes.NewBuffer(nil)
+			err := req.Write(buffer)
+			if err != nil {
+				rawConn.Close()
+				return nil, nil, err
+			}
+			_, err = io.Copy(buffer, bytes.NewReader(firstPayload))
+			if err != nil {
+				rawConn.Close()
+				return nil, nil, err
+			}
+			_, err = rawConn.Write(buffer.Bytes())
+			if err != nil {
+				rawConn.Close()
+				return nil, nil, err
+			}
 		}
-
-		resp, err := http.ReadResponse(bufio.NewReader(rawConn), req)
+		bufferedReader := bufio.NewReader(rawConn)
+		resp, err := http.ReadResponse(bufferedReader, req)
 		if err != nil {
 			rawConn.Close()
-			return nil, err
+			return nil, nil, err
 		}
 		defer resp.Body.Close()
 
 		if resp.StatusCode != http.StatusOK {
 			rawConn.Close()
-			return nil, newError("Proxy responded with non 200 code: " + resp.Status)
+			return nil, nil, newError("Proxy responded with non 200 code: " + resp.Status)
 		}
-		return rawConn, nil
+		if bufferedReader.Buffered() > 0 {
+			payload, err := buf.ReadFrom(io.LimitReader(bufferedReader, int64(bufferedReader.Buffered())))
+			if err != nil {
+				return nil, nil, newError("unable to drain buffer: ").Base(err)
+			}
+			return rawConn, payload, nil
+		}
+		return rawConn, nil, nil
 	}
 
 	connectHTTP2 := func(rawConn net.Conn, h2clientConn *http2.ClientConn) (net.Conn, error) {
@@ -228,16 +270,16 @@ func setUpHTTPTunnel(ctx context.Context, dest net.Destination, target string, u
 		if cc.CanTakeNewRequest() {
 			proxyConn, err := connectHTTP2(rc, cc)
 			if err != nil {
-				return nil, err
+				return nil, nil, err
 			}
 
-			return proxyConn, nil
+			return proxyConn, nil, nil
 		}
 	}
 
 	rawConn, err := dialer.Dial(ctx, dest)
 	if err != nil {
-		return nil, err
+		return nil, nil, err
 	}
 
 	iConn := rawConn
@@ -249,7 +291,7 @@ func setUpHTTPTunnel(ctx context.Context, dest net.Destination, target string, u
 	if tlsConn, ok := iConn.(*tls.Conn); ok {
 		if err := tlsConn.Handshake(); err != nil {
 			rawConn.Close()
-			return nil, err
+			return nil, nil, err
 		}
 		nextProto = tlsConn.ConnectionState().NegotiatedProtocol
 	}
@@ -262,13 +304,13 @@ func setUpHTTPTunnel(ctx context.Context, dest net.Destination, target string, u
 		h2clientConn, err := t.NewClientConn(rawConn)
 		if err != nil {
 			rawConn.Close()
-			return nil, err
+			return nil, nil, err
 		}
 
 		proxyConn, err := connectHTTP2(rawConn, h2clientConn)
 		if err != nil {
 			rawConn.Close()
-			return nil, err
+			return nil, nil, err
 		}
 
 		cachedH2Mutex.Lock()
@@ -282,9 +324,9 @@ func setUpHTTPTunnel(ctx context.Context, dest net.Destination, target string, u
 		}
 		cachedH2Mutex.Unlock()
 
-		return proxyConn, err
+		return proxyConn, nil, err
 	default:
-		return nil, newError("negotiated unsupported application layer protocol: " + nextProto)
+		return nil, nil, newError("negotiated unsupported application layer protocol: " + nextProto)
 	}
 }
 

+ 1 - 0
proxy/http/config.proto

@@ -25,4 +25,5 @@ message ServerConfig {
 message ClientConfig {
   // Sever is a list of HTTP server addresses.
   repeated v2ray.core.common.protocol.ServerEndpoint server = 1;
+  bool h1_skip_wait_for_reply = 2;
 }

+ 1 - 0
proxy/http/simplified/config.go

@@ -25,6 +25,7 @@ func init() {
 					Port:    simplifiedClient.Port,
 				},
 			},
+			H1SkipWaitForReply: simplifiedClient.H1SkipWaitForReply,
 		}
 		return common.CreateObject(ctx, fullClient)
 	}))

+ 1 - 0
proxy/http/simplified/config.proto

@@ -20,4 +20,5 @@ message ClientConfig {
 
   v2ray.core.common.net.IPOrDomain address = 1;
   uint32 port = 2;
+  bool h1_skip_wait_for_reply = 3;
 }