Ver Fonte

integration test case for policy

Darien Raymond há 7 anos atrás
pai
commit
c25a76a0cf
4 ficheiros alterados com 210 adições e 29 exclusões
  1. 34 0
      app/policy/config.go
  2. 6 4
      app/policy/manager.go
  3. 2 25
      policy.go
  4. 168 0
      testing/scenarios/policy_test.go

+ 34 - 0
app/policy/config.go

@@ -14,6 +14,40 @@ func (s *Second) Duration() time.Duration {
 	return time.Second * time.Duration(s.Value)
 }
 
+func defaultPolicy() *Policy {
+	p := core.DefaultPolicy()
+
+	return &Policy{
+		Timeout: &Policy_Timeout{
+			Handshake:      &Second{Value: uint32(p.Timeouts.Handshake / time.Second)},
+			ConnectionIdle: &Second{Value: uint32(p.Timeouts.ConnectionIdle / time.Second)},
+			UplinkOnly:     &Second{Value: uint32(p.Timeouts.UplinkOnly / time.Second)},
+			DownlinkOnly:   &Second{Value: uint32(p.Timeouts.DownlinkOnly / time.Second)},
+		},
+	}
+}
+
+func (p *Policy_Timeout) overrideWith(another *Policy_Timeout) {
+	if another.Handshake != nil {
+		p.Handshake = &Second{Value: another.Handshake.Value}
+	}
+	if another.ConnectionIdle != nil {
+		p.ConnectionIdle = &Second{Value: another.ConnectionIdle.Value}
+	}
+	if another.UplinkOnly != nil {
+		p.UplinkOnly = &Second{Value: another.UplinkOnly.Value}
+	}
+	if another.DownlinkOnly != nil {
+		p.DownlinkOnly = &Second{Value: another.DownlinkOnly.Value}
+	}
+}
+
+func (p *Policy) overrideWith(another *Policy) {
+	if another.Timeout != nil {
+		p.Timeout.overrideWith(another.Timeout)
+	}
+}
+
 func (p *Policy) ToCorePolicy() core.Policy {
 	var cp core.Policy
 	if p.Timeout != nil {

+ 6 - 4
app/policy/manager.go

@@ -9,17 +9,19 @@ import (
 
 // Instance is an instance of Policy manager.
 type Instance struct {
-	levels map[uint32]core.Policy
+	levels map[uint32]*Policy
 }
 
 // New creates new Policy manager instance.
 func New(ctx context.Context, config *Config) (*Instance, error) {
 	m := &Instance{
-		levels: make(map[uint32]core.Policy),
+		levels: make(map[uint32]*Policy),
 	}
 	if len(config.Level) > 0 {
 		for lv, p := range config.Level {
-			m.levels[lv] = p.ToCorePolicy().OverrideWith(core.DefaultPolicy())
+			pp := defaultPolicy()
+			pp.overrideWith(p)
+			m.levels[lv] = pp
 		}
 	}
 
@@ -36,7 +38,7 @@ func New(ctx context.Context, config *Config) (*Instance, error) {
 // ForLevel implements core.PolicyManager.
 func (m *Instance) ForLevel(level uint32) core.Policy {
 	if p, ok := m.levels[level]; ok {
-		return p
+		return p.ToCorePolicy()
 	}
 	return core.DefaultPolicy()
 }

+ 2 - 25
policy.go

@@ -13,40 +13,17 @@ type TimeoutPolicy struct {
 	Handshake time.Duration
 	// Timeout for connection being idle, i.e., there is no egress or ingress traffic in this connection.
 	ConnectionIdle time.Duration
-	// Timeout for an uplink only connection, i.e., the downlink of the connection has ben closed.
+	// Timeout for an uplink only connection, i.e., the downlink of the connection has been closed.
 	UplinkOnly time.Duration
-	// Timeout for an downlink only connection, i.e., the uplink of the connection has ben closed.
+	// Timeout for an downlink only connection, i.e., the uplink of the connection has been closed.
 	DownlinkOnly time.Duration
 }
 
-// OverrideWith overrides the current TimeoutPolicy with another one. All timeouts with zero value will be overridden with the new value.
-func (p TimeoutPolicy) OverrideWith(another TimeoutPolicy) TimeoutPolicy {
-	if p.Handshake == 0 {
-		p.Handshake = another.Handshake
-	}
-	if p.ConnectionIdle == 0 {
-		p.ConnectionIdle = another.ConnectionIdle
-	}
-	if p.UplinkOnly == 0 {
-		p.UplinkOnly = another.UplinkOnly
-	}
-	if p.DownlinkOnly == 0 {
-		p.DownlinkOnly = another.DownlinkOnly
-	}
-	return p
-}
-
 // Policy is session based settings for controlling V2Ray requests. It contains various settings (or limits) that may differ for different users in the context.
 type Policy struct {
 	Timeouts TimeoutPolicy // Timeout settings
 }
 
-// OverrideWith overrides the current Policy with another one. All values with default value will be overridden.
-func (p Policy) OverrideWith(another Policy) Policy {
-	p.Timeouts = p.Timeouts.OverrideWith(another.Timeouts)
-	return p
-}
-
 // PolicyManager is a feature that provides Policy for the given user by its id or level.
 type PolicyManager interface {
 	Feature

+ 168 - 0
testing/scenarios/policy_test.go

@@ -0,0 +1,168 @@
+package scenarios
+
+import (
+	"io"
+	"testing"
+	"time"
+
+	"v2ray.com/core"
+	"v2ray.com/core/app/policy"
+	"v2ray.com/core/app/proxyman"
+	"v2ray.com/core/common/net"
+	"v2ray.com/core/common/protocol"
+	"v2ray.com/core/common/serial"
+	"v2ray.com/core/common/uuid"
+	"v2ray.com/core/proxy/dokodemo"
+	"v2ray.com/core/proxy/freedom"
+	"v2ray.com/core/proxy/vmess"
+	"v2ray.com/core/proxy/vmess/inbound"
+	"v2ray.com/core/proxy/vmess/outbound"
+	. "v2ray.com/ext/assert"
+)
+
+func startQuickClosingTCPServer() (net.Listener, error) {
+	listener, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		return nil, err
+	}
+	go func() {
+		for {
+			conn, err := listener.Accept()
+			if err != nil {
+				break
+			}
+			b := make([]byte, 1024)
+			conn.Read(b)
+			conn.Close()
+		}
+	}()
+	return listener, nil
+}
+
+func TestVMessClosing(t *testing.T) {
+	assert := With(t)
+
+	tcpServer, err := startQuickClosingTCPServer()
+	assert(err, IsNil)
+	defer tcpServer.Close()
+
+	dest := net.DestinationFromAddr(tcpServer.Addr())
+
+	userID := protocol.NewID(uuid.New())
+	serverPort := pickPort()
+	serverConfig := &core.Config{
+		App: []*serial.TypedMessage{
+			serial.ToTypedMessage(&policy.Config{
+				Level: map[uint32]*policy.Policy{
+					0: &policy.Policy{
+						Timeout: &policy.Policy_Timeout{
+							UplinkOnly:   &policy.Second{Value: 0},
+							DownlinkOnly: &policy.Second{Value: 0},
+						},
+					},
+				},
+			}),
+		},
+		Inbound: []*core.InboundHandlerConfig{
+			{
+				ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{
+					PortRange: net.SinglePortRange(serverPort),
+					Listen:    net.NewIPOrDomain(net.LocalHostIP),
+				}),
+				ProxySettings: serial.ToTypedMessage(&inbound.Config{
+					User: []*protocol.User{
+						{
+							Account: serial.ToTypedMessage(&vmess.Account{
+								Id:      userID.String(),
+								AlterId: 64,
+							}),
+						},
+					},
+				}),
+			},
+		},
+		Outbound: []*core.OutboundHandlerConfig{
+			{
+				ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
+			},
+		},
+	}
+
+	clientPort := pickPort()
+	clientConfig := &core.Config{
+		App: []*serial.TypedMessage{
+			serial.ToTypedMessage(&policy.Config{
+				Level: map[uint32]*policy.Policy{
+					0: &policy.Policy{
+						Timeout: &policy.Policy_Timeout{
+							UplinkOnly:   &policy.Second{Value: 0},
+							DownlinkOnly: &policy.Second{Value: 0},
+						},
+					},
+				},
+			}),
+		},
+		Inbound: []*core.InboundHandlerConfig{
+			{
+				ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{
+					PortRange: net.SinglePortRange(clientPort),
+					Listen:    net.NewIPOrDomain(net.LocalHostIP),
+				}),
+				ProxySettings: serial.ToTypedMessage(&dokodemo.Config{
+					Address: net.NewIPOrDomain(dest.Address),
+					Port:    uint32(dest.Port),
+					NetworkList: &net.NetworkList{
+						Network: []net.Network{net.Network_TCP},
+					},
+				}),
+			},
+		},
+		Outbound: []*core.OutboundHandlerConfig{
+			{
+				ProxySettings: serial.ToTypedMessage(&outbound.Config{
+					Receiver: []*protocol.ServerEndpoint{
+						{
+							Address: net.NewIPOrDomain(net.LocalHostIP),
+							Port:    uint32(serverPort),
+							User: []*protocol.User{
+								{
+									Account: serial.ToTypedMessage(&vmess.Account{
+										Id:      userID.String(),
+										AlterId: 64,
+										SecuritySettings: &protocol.SecurityConfig{
+											Type: protocol.SecurityType_AES128_GCM,
+										},
+									}),
+								},
+							},
+						},
+					},
+				}),
+			},
+		},
+	}
+
+	servers, err := InitializeServerConfigs(serverConfig, clientConfig)
+	assert(err, IsNil)
+
+	defer CloseAllServers(servers)
+
+	conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{
+		IP:   []byte{127, 0, 0, 1},
+		Port: int(clientPort),
+	})
+	assert(err, IsNil)
+
+	conn.SetDeadline(time.Now().Add(time.Second * 2))
+
+	nBytes, err := conn.Write([]byte("test payload"))
+	assert(nBytes, GreaterThan, 0)
+	assert(err, IsNil)
+
+	resp := make([]byte, 1024)
+	nBytes, err = conn.Read(resp)
+	assert(err, Equals, io.EOF)
+	assert(nBytes, Equals, 0)
+
+	CloseAllServers(servers)
+}