Browse Source

Support using custom resolver when dialing domain address

Vigilans 2 years ago
parent
commit
b1d38db30a
2 changed files with 21 additions and 1 deletions
  1. 2 0
      common/session/session.go
  2. 19 1
      transport/internet/dialer.go

+ 2 - 0
common/session/session.go

@@ -51,6 +51,8 @@ type Outbound struct {
 	Target net.Destination
 	// Gateway address
 	Gateway net.Address
+	// Domain resolver to use when dialing
+	Resolver func(ctx context.Context, domain string) net.Address
 }
 
 // SniffingRequest controls the behavior of content sniffing.

+ 19 - 1
transport/internet/dialer.go

@@ -68,8 +68,10 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *MemoryStrea
 
 // DialSystem calls system dialer to create a network connection.
 func DialSystem(ctx context.Context, dest net.Destination, sockopt *SocketConfig) (net.Conn, error) {
+	outbound := session.OutboundFromContext(ctx)
+
 	var src net.Address
-	if outbound := session.OutboundFromContext(ctx); outbound != nil {
+	if outbound != nil {
 		src = outbound.Gateway
 	}
 
@@ -77,6 +79,22 @@ func DialSystem(ctx context.Context, dest net.Destination, sockopt *SocketConfig
 		return DialTaggedOutbound(ctx, dest, transportLayerOutgoingTag)
 	}
 
+	originalAddr := dest.Address
+	if outbound != nil && outbound.Resolver != nil && dest.Address.Family().IsDomain() {
+		if addr := outbound.Resolver(ctx, dest.Address.Domain()); addr != nil {
+			dest.Address = addr
+		}
+	}
+
+	switch {
+	case src != nil && dest.Address != originalAddr:
+		newError("dialing to ", dest, " resolved from ", originalAddr, " via ", src).WriteToLog(session.ExportIDToError(ctx))
+	case src != nil:
+		newError("dialing to ", dest, " via ", src).WriteToLog(session.ExportIDToError(ctx))
+	case dest.Address != originalAddr:
+		newError("dialing to ", dest, " resolved from ", originalAddr).WriteToLog(session.ExportIDToError(ctx))
+	}
+
 	return effectiveSystemDialer.Dial(ctx, src, dest, sockopt)
 }