Browse Source

Apply timeout to dns outbound (#1330)

世界 4 years ago
parent
commit
7b0699e8a5
4 changed files with 41 additions and 16 deletions
  1. 5 3
      infra/conf/dns_proxy.go
  2. 20 10
      proxy/dns/config.pb.go
  3. 1 0
      proxy/dns/config.proto
  4. 15 3
      proxy/dns/dns.go

+ 5 - 3
infra/conf/dns_proxy.go

@@ -9,9 +9,10 @@ import (
 )
 )
 
 
 type DNSOutboundConfig struct {
 type DNSOutboundConfig struct {
-	Network cfgcommon.Network  `json:"network"`
-	Address *cfgcommon.Address `json:"address"`
-	Port    uint16             `json:"port"`
+	Network   cfgcommon.Network  `json:"network"`
+	Address   *cfgcommon.Address `json:"address"`
+	Port      uint16             `json:"port"`
+	UserLevel uint32             `json:"userLevel"`
 }
 }
 
 
 func (c *DNSOutboundConfig) Build() (proto.Message, error) {
 func (c *DNSOutboundConfig) Build() (proto.Message, error) {
@@ -20,6 +21,7 @@ func (c *DNSOutboundConfig) Build() (proto.Message, error) {
 			Network: c.Network.Build(),
 			Network: c.Network.Build(),
 			Port:    uint32(c.Port),
 			Port:    uint32(c.Port),
 		},
 		},
+		UserLevel: c.UserLevel,
 	}
 	}
 	if c.Address != nil {
 	if c.Address != nil {
 		config.Server.Address = c.Address.Build()
 		config.Server.Address = c.Address.Build()

+ 20 - 10
proxy/dns/config.pb.go

@@ -28,7 +28,8 @@ type Config struct {
 
 
 	// Server is the DNS server address. If specified, this address overrides the
 	// Server is the DNS server address. If specified, this address overrides the
 	// original one.
 	// original one.
-	Server *net.Endpoint `protobuf:"bytes,1,opt,name=server,proto3" json:"server,omitempty"`
+	Server    *net.Endpoint `protobuf:"bytes,1,opt,name=server,proto3" json:"server,omitempty"`
+	UserLevel uint32        `protobuf:"varint,2,opt,name=user_level,json=userLevel,proto3" json:"user_level,omitempty"`
 }
 }
 
 
 func (x *Config) Reset() {
 func (x *Config) Reset() {
@@ -70,6 +71,13 @@ func (x *Config) GetServer() *net.Endpoint {
 	return nil
 	return nil
 }
 }
 
 
+func (x *Config) GetUserLevel() uint32 {
+	if x != nil {
+		return x.UserLevel
+	}
+	return 0
+}
+
 var File_proxy_dns_config_proto protoreflect.FileDescriptor
 var File_proxy_dns_config_proto protoreflect.FileDescriptor
 
 
 var file_proxy_dns_config_proto_rawDesc = []byte{
 var file_proxy_dns_config_proto_rawDesc = []byte{
@@ -77,18 +85,20 @@ var file_proxy_dns_config_proto_rawDesc = []byte{
 	0x69, 0x67, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x14, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e,
 	0x69, 0x67, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x14, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e,
 	0x63, 0x6f, 0x72, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x64, 0x6e, 0x73, 0x1a, 0x1c,
 	0x63, 0x6f, 0x72, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x64, 0x6e, 0x73, 0x1a, 0x1c,
 	0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2f, 0x6e, 0x65, 0x74, 0x2f, 0x64, 0x65, 0x73, 0x74, 0x69,
 	0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2f, 0x6e, 0x65, 0x74, 0x2f, 0x64, 0x65, 0x73, 0x74, 0x69,
-	0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x41, 0x0a, 0x06,
+	0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x60, 0x0a, 0x06,
 	0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x37, 0x0a, 0x06, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72,
 	0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x37, 0x0a, 0x06, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72,
 	0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63,
 	0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63,
 	0x6f, 0x72, 0x65, 0x2e, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2e, 0x6e, 0x65, 0x74, 0x2e, 0x45,
 	0x6f, 0x72, 0x65, 0x2e, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2e, 0x6e, 0x65, 0x74, 0x2e, 0x45,
-	0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x52, 0x06, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x42,
-	0x5d, 0x0a, 0x18, 0x63, 0x6f, 0x6d, 0x2e, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72,
-	0x65, 0x2e, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x64, 0x6e, 0x73, 0x50, 0x01, 0x5a, 0x28, 0x67,
-	0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x76, 0x32, 0x66, 0x6c, 0x79, 0x2f,
-	0x76, 0x32, 0x72, 0x61, 0x79, 0x2d, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x76, 0x34, 0x2f, 0x70, 0x72,
-	0x6f, 0x78, 0x79, 0x2f, 0x64, 0x6e, 0x73, 0xaa, 0x02, 0x14, 0x56, 0x32, 0x52, 0x61, 0x79, 0x2e,
-	0x43, 0x6f, 0x72, 0x65, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x44, 0x6e, 0x73, 0x62, 0x06,
-	0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
+	0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x52, 0x06, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12,
+	0x1d, 0x0a, 0x0a, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x02, 0x20,
+	0x01, 0x28, 0x0d, 0x52, 0x09, 0x75, 0x73, 0x65, 0x72, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x42, 0x5d,
+	0x0a, 0x18, 0x63, 0x6f, 0x6d, 0x2e, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65,
+	0x2e, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x64, 0x6e, 0x73, 0x50, 0x01, 0x5a, 0x28, 0x67, 0x69,
+	0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x76, 0x32, 0x66, 0x6c, 0x79, 0x2f, 0x76,
+	0x32, 0x72, 0x61, 0x79, 0x2d, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x76, 0x34, 0x2f, 0x70, 0x72, 0x6f,
+	0x78, 0x79, 0x2f, 0x64, 0x6e, 0x73, 0xaa, 0x02, 0x14, 0x56, 0x32, 0x52, 0x61, 0x79, 0x2e, 0x43,
+	0x6f, 0x72, 0x65, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x44, 0x6e, 0x73, 0x62, 0x06, 0x70,
+	0x72, 0x6f, 0x74, 0x6f, 0x33,
 }
 }
 
 
 var (
 var (

+ 1 - 0
proxy/dns/config.proto

@@ -12,4 +12,5 @@ message Config {
   // Server is the DNS server address. If specified, this address overrides the
   // Server is the DNS server address. If specified, this address overrides the
   // original one.
   // original one.
   v2ray.core.common.net.Endpoint server = 1;
   v2ray.core.common.net.Endpoint server = 1;
+  uint32 user_level = 2;
 }
 }

+ 15 - 3
proxy/dns/dns.go

@@ -7,6 +7,7 @@ import (
 	"context"
 	"context"
 	"io"
 	"io"
 	"sync"
 	"sync"
+	"time"
 
 
 	"golang.org/x/net/dns/dnsmessage"
 	"golang.org/x/net/dns/dnsmessage"
 
 
@@ -16,8 +17,10 @@ import (
 	"github.com/v2fly/v2ray-core/v4/common/net"
 	"github.com/v2fly/v2ray-core/v4/common/net"
 	dns_proto "github.com/v2fly/v2ray-core/v4/common/protocol/dns"
 	dns_proto "github.com/v2fly/v2ray-core/v4/common/protocol/dns"
 	"github.com/v2fly/v2ray-core/v4/common/session"
 	"github.com/v2fly/v2ray-core/v4/common/session"
+	"github.com/v2fly/v2ray-core/v4/common/signal"
 	"github.com/v2fly/v2ray-core/v4/common/task"
 	"github.com/v2fly/v2ray-core/v4/common/task"
 	"github.com/v2fly/v2ray-core/v4/features/dns"
 	"github.com/v2fly/v2ray-core/v4/features/dns"
+	"github.com/v2fly/v2ray-core/v4/features/policy"
 	"github.com/v2fly/v2ray-core/v4/transport"
 	"github.com/v2fly/v2ray-core/v4/transport"
 	"github.com/v2fly/v2ray-core/v4/transport/internet"
 	"github.com/v2fly/v2ray-core/v4/transport/internet"
 )
 )
@@ -25,8 +28,8 @@ import (
 func init() {
 func init() {
 	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
 	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
 		h := new(Handler)
 		h := new(Handler)
-		if err := core.RequireFeatures(ctx, func(dnsClient dns.Client) error {
-			return h.Init(config.(*Config), dnsClient)
+		if err := core.RequireFeatures(ctx, func(dnsClient dns.Client, policyManager policy.Manager) error {
+			return h.Init(config.(*Config), dnsClient, policyManager)
 		}); err != nil {
 		}); err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
@@ -44,10 +47,12 @@ type Handler struct {
 	ipv6Lookup      dns.IPv6Lookup
 	ipv6Lookup      dns.IPv6Lookup
 	ownLinkVerifier ownLinkVerifier
 	ownLinkVerifier ownLinkVerifier
 	server          net.Destination
 	server          net.Destination
+	timeout         time.Duration
 }
 }
 
 
-func (h *Handler) Init(config *Config, dnsClient dns.Client) error {
+func (h *Handler) Init(config *Config, dnsClient dns.Client, policyManager policy.Manager) error {
 	h.client = dnsClient
 	h.client = dnsClient
+	h.timeout = policyManager.ForLevel(config.UserLevel).Timeouts.ConnectionIdle
 
 
 	if ipv4lookup, ok := dnsClient.(dns.IPv4Lookup); ok {
 	if ipv4lookup, ok := dnsClient.(dns.IPv4Lookup); ok {
 		h.ipv4Lookup = ipv4lookup
 		h.ipv4Lookup = ipv4lookup
@@ -160,6 +165,9 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.
 		}
 		}
 	}
 	}
 
 
+	ctx, cancel := context.WithCancel(ctx)
+	timer := signal.CancelAfterInactivity(ctx, cancel, h.timeout)
+
 	request := func() error {
 	request := func() error {
 		defer conn.Close()
 		defer conn.Close()
 
 
@@ -173,6 +181,8 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.
 				return err
 				return err
 			}
 			}
 
 
+			timer.Update()
+
 			if !h.isOwnLink(ctx) {
 			if !h.isOwnLink(ctx) {
 				isIPQuery, domain, id, qType := parseIPQuery(b.Bytes())
 				isIPQuery, domain, id, qType := parseIPQuery(b.Bytes())
 				if isIPQuery {
 				if isIPQuery {
@@ -198,6 +208,8 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.
 				return err
 				return err
 			}
 			}
 
 
+			timer.Update()
+
 			if err := writer.WriteMessage(b); err != nil {
 			if err := writer.WriteMessage(b); err != nil {
 				return err
 				return err
 			}
 			}