Browse Source

:bug: http dialer add socket config; sockopt.mark use uint32 (#1264)

(cherry picked from commit 4d155bc2bf9dc5bdc3d7433aa67fbc2f5f93677d)
Ye Zhihao 4 years ago
parent
commit
27614e56fc
3 changed files with 16 additions and 7 deletions
  1. 1 1
      common/session/session.go
  2. 1 1
      transport/internet/config.proto
  3. 14 5
      transport/internet/http/dialer.go

+ 1 - 1
common/session/session.go

@@ -75,7 +75,7 @@ type Content struct {
 // Sockopt is the settings for socket connection.
 type Sockopt struct {
 	// Mark of the socket connection.
-	Mark int32
+	Mark uint32
 }
 
 // SetAttribute attachs additional string attributes to content.

+ 1 - 1
transport/internet/config.proto

@@ -56,7 +56,7 @@ message ProxyConfig {
 // SocketConfig is options to be applied on network sockets.
 message SocketConfig {
   // Mark of the connection. If non-zero, the value will be set to SO_MARK.
-  int32 mark = 1;
+  uint32 mark = 1;
 
   enum TCPFastOpenState {
     // AsIs is to leave the current TFO state as is, unmodified.

+ 14 - 5
transport/internet/http/dialer.go

@@ -24,16 +24,24 @@ var (
 	globalDialerAccess sync.Mutex
 )
 
-func getHTTPClient(ctx context.Context, dest net.Destination, tlsSettings *tls.Config) *http.Client {
+type dialerCanceller func()
+
+func getHTTPClient(ctx context.Context, dest net.Destination, tlsSettings *tls.Config, streamSettings *internet.MemoryStreamConfig) (*http.Client, dialerCanceller) {
 	globalDialerAccess.Lock()
 	defer globalDialerAccess.Unlock()
 
+	canceller := func() {
+		globalDialerAccess.Lock()
+		defer globalDialerAccess.Unlock()
+		delete(globalDialerMap, dest)
+	}
+
 	if globalDialerMap == nil {
 		globalDialerMap = make(map[net.Destination]*http.Client)
 	}
 
 	if client, found := globalDialerMap[dest]; found {
-		return client
+		return client, canceller
 	}
 
 	transport := &http2.Transport{
@@ -52,7 +60,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, tlsSettings *tls.C
 			address := net.ParseAddress(rawHost)
 
 			detachedContext := core.ToBackgroundDetachedContext(ctx)
-			pconn, err := internet.DialSystem(detachedContext, net.TCPDestination(address, port), nil)
+			pconn, err := internet.DialSystem(detachedContext, net.TCPDestination(address, port), streamSettings.SocketSettings)
 			if err != nil {
 				return nil, err
 			}
@@ -80,7 +88,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, tlsSettings *tls.C
 	}
 
 	globalDialerMap[dest] = client
-	return client
+	return client, canceller
 }
 
 // Dial dials a new TCP connection to the given destination.
@@ -90,7 +98,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 	if tlsConfig == nil {
 		return nil, newError("TLS must be enabled for http transport.").AtWarning()
 	}
-	client := getHTTPClient(ctx, dest, tlsConfig)
+	client, canceller := getHTTPClient(ctx, dest, tlsConfig, streamSettings)
 
 	opts := pipe.OptionsFromContext(ctx)
 	preader, pwriter := pipe.New(opts...)
@@ -128,6 +136,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 
 	response, err := client.Do(request) // nolint: bodyclose
 	if err != nil {
+		canceller()
 		return nil, newError("failed to dial to ", dest).Base(err).AtWarning()
 	}
 	if response.StatusCode != 200 {