Browse Source

Optimize HTTP tunnel setup in TFO environment

Anonymous-Someneese 5 years ago
parent
commit
a5caa01cb6
1 changed files with 43 additions and 8 deletions
  1. 43 8
      proxy/http/client.go

+ 43 - 8
proxy/http/client.go

@@ -4,9 +4,9 @@ package http
 
 import (
 	"bufio"
+	"io"
 	"context"
 	"encoding/base64"
-	"io"
 	"net/http"
 	"strings"
 
@@ -93,9 +93,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
 		p = c.policyManager.ForLevel(user.Level)
 	}
 
-	if err := setUpHTTPTunnel(conn, &destination, user); err != nil {
-		return err
-	}
+	conn = setUpHTTPTunnel(conn, &destination, user)
 
 	ctx, cancel := context.WithCancel(ctx)
 	timer := signal.CancelAfterInactivity(ctx, cancel, p.Timeouts.ConnectionIdle)
@@ -125,8 +123,13 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
 	return nil
 }
 
+type tunnelConn struct {
+	internet.Connection
+	header *buf.Buffer
+}
+
 // setUpHTTPTunnel will create a socket tunnel via HTTP CONNECT method
-func setUpHTTPTunnel(writer io.Writer, destination *net.Destination, user *protocol.MemoryUser) error {
+func setUpHTTPTunnel(conn internet.Connection, destination *net.Destination, user *protocol.MemoryUser) *tunnelConn {
 	var headers []string
 	destNetAddr := destination.NetAddr()
 	headers = append(headers, "CONNECT "+destNetAddr+" HTTP/1.1")
@@ -140,11 +143,43 @@ func setUpHTTPTunnel(writer io.Writer, destination *net.Destination, user *proto
 
 	b := buf.New()
 	b.WriteString(strings.Join(headers, "\r\n") + "\r\n\r\n")
-	if err := buf.WriteAllBytes(writer, b.Bytes()); err != nil {
-		return err
+	return &tunnelConn {
+		Connection: conn,
+		header:     b,
 	}
+}
 
-	return nil
+func (c *tunnelConn) Write(b []byte) (n int, err error) {
+	if c.header == nil {
+		return c.Connection.Write(b)
+	}
+	buffer := c.header
+	lenheader := c.header.Len()
+	// Concate header and b
+	_, err = buffer.Write(b)
+	if err != nil {
+		c.header.Resize(0, lenheader)
+		return 0, err
+	}
+	// Write buffer
+	nc, err := io.Copy(c.Connection, buffer)
+	if int32(nc) < lenheader {
+		c.header.Resize(int32(nc), lenheader)
+		return 0, err
+	}
+	c.header.Release()
+	c.header = nil
+	n = int(nc) - int(lenheader)
+	if err != nil {
+		return n, err
+	}
+	// Write trailing bytes
+	if n < len(b) {
+		var nw int
+		nw, err = c.Connection.Write(b[:n])
+		n += nw
+	}
+	return n, err
 }
 
 func init() {