|
|
@@ -2,12 +2,14 @@ package http
|
|
|
|
|
|
import (
|
|
|
"bufio"
|
|
|
+ "bytes"
|
|
|
"context"
|
|
|
"encoding/base64"
|
|
|
"io"
|
|
|
"net/http"
|
|
|
"net/url"
|
|
|
"sync"
|
|
|
+ "time"
|
|
|
|
|
|
"golang.org/x/net/http2"
|
|
|
|
|
|
@@ -29,8 +31,9 @@ import (
|
|
|
)
|
|
|
|
|
|
type Client struct {
|
|
|
- serverPicker protocol.ServerPicker
|
|
|
- policyManager policy.Manager
|
|
|
+ serverPicker protocol.ServerPicker
|
|
|
+ policyManager policy.Manager
|
|
|
+ h1SkipWaitForReply bool
|
|
|
}
|
|
|
|
|
|
type h2Conn struct {
|
|
|
@@ -59,8 +62,9 @@ func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) {
|
|
|
|
|
|
v := core.MustFromContext(ctx)
|
|
|
return &Client{
|
|
|
- serverPicker: protocol.NewRoundRobinServerPicker(serverList),
|
|
|
- policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
|
|
|
+ serverPicker: protocol.NewRoundRobinServerPicker(serverList),
|
|
|
+ policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
|
|
|
+ h1SkipWaitForReply: config.H1SkipWaitForReply,
|
|
|
}, nil
|
|
|
}
|
|
|
|
|
|
@@ -87,7 +91,13 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
|
|
|
// transmitted together. Note we should not get stuck here, as the payload may
|
|
|
// not exist (considering to access MySQL database via a HTTP proxy, where the
|
|
|
// server sends hello to the client first).
|
|
|
- if mbuf, _ := reader.ReadMultiBufferTimeout(proxy.FirstPayloadTimeout); mbuf != nil {
|
|
|
+ waitTime := proxy.FirstPayloadTimeout
|
|
|
+ if c.h1SkipWaitForReply {
|
|
|
+ // Some server require first write to be present in client hello.
|
|
|
+ // Increase timeout to if the client have explicitly requested to skip waiting for reply.
|
|
|
+ waitTime = time.Second
|
|
|
+ }
|
|
|
+ if mbuf, _ := reader.ReadMultiBufferTimeout(waitTime); mbuf != nil {
|
|
|
mlen := mbuf.Len()
|
|
|
firstPayload = bytespool.Alloc(mlen)
|
|
|
mbuf, _ = buf.SplitBytes(mbuf, firstPayload)
|
|
|
@@ -103,14 +113,19 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
|
|
|
dest := server.Destination()
|
|
|
user = server.PickUser()
|
|
|
|
|
|
- netConn, err := setUpHTTPTunnel(ctx, dest, targetAddr, user, dialer, firstPayload)
|
|
|
+ netConn, firstResp, err := setUpHTTPTunnel(ctx, dest, targetAddr, user, dialer, firstPayload, c.h1SkipWaitForReply)
|
|
|
if netConn != nil {
|
|
|
- if _, ok := netConn.(*http2Conn); !ok {
|
|
|
+ if _, ok := netConn.(*http2Conn); !ok && !c.h1SkipWaitForReply {
|
|
|
if _, err := netConn.Write(firstPayload); err != nil {
|
|
|
netConn.Close()
|
|
|
return err
|
|
|
}
|
|
|
}
|
|
|
+ if firstResp != nil {
|
|
|
+ if err := link.Writer.WriteMultiBuffer(firstResp); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ }
|
|
|
conn = internet.Connection(netConn)
|
|
|
}
|
|
|
return err
|
|
|
@@ -150,7 +165,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
|
|
|
}
|
|
|
|
|
|
// setUpHTTPTunnel will create a socket tunnel via HTTP CONNECT method
|
|
|
-func setUpHTTPTunnel(ctx context.Context, dest net.Destination, target string, user *protocol.MemoryUser, dialer internet.Dialer, firstPayload []byte) (net.Conn, error) {
|
|
|
+func setUpHTTPTunnel(ctx context.Context, dest net.Destination, target string, user *protocol.MemoryUser, dialer internet.Dialer, firstPayload []byte, writeFirstPayloadInH1 bool,
|
|
|
+) (net.Conn, buf.MultiBuffer, error) {
|
|
|
req := &http.Request{
|
|
|
Method: http.MethodConnect,
|
|
|
URL: &url.URL{Host: target},
|
|
|
@@ -164,27 +180,53 @@ func setUpHTTPTunnel(ctx context.Context, dest net.Destination, target string, u
|
|
|
req.Header.Set("Proxy-Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
|
|
|
}
|
|
|
|
|
|
- connectHTTP1 := func(rawConn net.Conn) (net.Conn, error) {
|
|
|
+ connectHTTP1 := func(rawConn net.Conn) (net.Conn, buf.MultiBuffer, error) {
|
|
|
req.Header.Set("Proxy-Connection", "Keep-Alive")
|
|
|
|
|
|
- err := req.Write(rawConn)
|
|
|
- if err != nil {
|
|
|
- rawConn.Close()
|
|
|
- return nil, err
|
|
|
+ if !writeFirstPayloadInH1 {
|
|
|
+ err := req.Write(rawConn)
|
|
|
+ if err != nil {
|
|
|
+ rawConn.Close()
|
|
|
+ return nil, nil, err
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ buffer := bytes.NewBuffer(nil)
|
|
|
+ err := req.Write(buffer)
|
|
|
+ if err != nil {
|
|
|
+ rawConn.Close()
|
|
|
+ return nil, nil, err
|
|
|
+ }
|
|
|
+ _, err = io.Copy(buffer, bytes.NewReader(firstPayload))
|
|
|
+ if err != nil {
|
|
|
+ rawConn.Close()
|
|
|
+ return nil, nil, err
|
|
|
+ }
|
|
|
+ _, err = rawConn.Write(buffer.Bytes())
|
|
|
+ if err != nil {
|
|
|
+ rawConn.Close()
|
|
|
+ return nil, nil, err
|
|
|
+ }
|
|
|
}
|
|
|
-
|
|
|
- resp, err := http.ReadResponse(bufio.NewReader(rawConn), req)
|
|
|
+ bufferedReader := bufio.NewReader(rawConn)
|
|
|
+ resp, err := http.ReadResponse(bufferedReader, req)
|
|
|
if err != nil {
|
|
|
rawConn.Close()
|
|
|
- return nil, err
|
|
|
+ return nil, nil, err
|
|
|
}
|
|
|
defer resp.Body.Close()
|
|
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
|
rawConn.Close()
|
|
|
- return nil, newError("Proxy responded with non 200 code: " + resp.Status)
|
|
|
+ return nil, nil, newError("Proxy responded with non 200 code: " + resp.Status)
|
|
|
}
|
|
|
- return rawConn, nil
|
|
|
+ if bufferedReader.Buffered() > 0 {
|
|
|
+ payload, err := buf.ReadFrom(io.LimitReader(bufferedReader, int64(bufferedReader.Buffered())))
|
|
|
+ if err != nil {
|
|
|
+ return nil, nil, newError("unable to drain buffer: ").Base(err)
|
|
|
+ }
|
|
|
+ return rawConn, payload, nil
|
|
|
+ }
|
|
|
+ return rawConn, nil, nil
|
|
|
}
|
|
|
|
|
|
connectHTTP2 := func(rawConn net.Conn, h2clientConn *http2.ClientConn) (net.Conn, error) {
|
|
|
@@ -228,16 +270,16 @@ func setUpHTTPTunnel(ctx context.Context, dest net.Destination, target string, u
|
|
|
if cc.CanTakeNewRequest() {
|
|
|
proxyConn, err := connectHTTP2(rc, cc)
|
|
|
if err != nil {
|
|
|
- return nil, err
|
|
|
+ return nil, nil, err
|
|
|
}
|
|
|
|
|
|
- return proxyConn, nil
|
|
|
+ return proxyConn, nil, nil
|
|
|
}
|
|
|
}
|
|
|
|
|
|
rawConn, err := dialer.Dial(ctx, dest)
|
|
|
if err != nil {
|
|
|
- return nil, err
|
|
|
+ return nil, nil, err
|
|
|
}
|
|
|
|
|
|
iConn := rawConn
|
|
|
@@ -249,7 +291,7 @@ func setUpHTTPTunnel(ctx context.Context, dest net.Destination, target string, u
|
|
|
if tlsConn, ok := iConn.(*tls.Conn); ok {
|
|
|
if err := tlsConn.Handshake(); err != nil {
|
|
|
rawConn.Close()
|
|
|
- return nil, err
|
|
|
+ return nil, nil, err
|
|
|
}
|
|
|
nextProto = tlsConn.ConnectionState().NegotiatedProtocol
|
|
|
}
|
|
|
@@ -262,13 +304,13 @@ func setUpHTTPTunnel(ctx context.Context, dest net.Destination, target string, u
|
|
|
h2clientConn, err := t.NewClientConn(rawConn)
|
|
|
if err != nil {
|
|
|
rawConn.Close()
|
|
|
- return nil, err
|
|
|
+ return nil, nil, err
|
|
|
}
|
|
|
|
|
|
proxyConn, err := connectHTTP2(rawConn, h2clientConn)
|
|
|
if err != nil {
|
|
|
rawConn.Close()
|
|
|
- return nil, err
|
|
|
+ return nil, nil, err
|
|
|
}
|
|
|
|
|
|
cachedH2Mutex.Lock()
|
|
|
@@ -282,9 +324,9 @@ func setUpHTTPTunnel(ctx context.Context, dest net.Destination, target string, u
|
|
|
}
|
|
|
cachedH2Mutex.Unlock()
|
|
|
|
|
|
- return proxyConn, err
|
|
|
+ return proxyConn, nil, err
|
|
|
default:
|
|
|
- return nil, newError("negotiated unsupported application layer protocol: " + nextProto)
|
|
|
+ return nil, nil, newError("negotiated unsupported application layer protocol: " + nextProto)
|
|
|
}
|
|
|
}
|
|
|
|