Browse Source

long running reverse test case

Darien Raymond 7 years ago
parent
commit
f1ab89d9d8
4 changed files with 239 additions and 14 deletions
  1. 22 11
      app/reverse/portal.go
  2. 4 0
      common/mux/client.go
  3. 1 0
      common/mux/session.go
  4. 212 3
      testing/scenarios/reverse_test.go

+ 22 - 11
app/reverse/portal.go

@@ -68,10 +68,7 @@ func (s *Portal) HandleConnection(ctx context.Context, link *vio.Link) error {
 	}
 	}
 
 
 	if isDomain(outboundMeta.Target, s.domain) {
 	if isDomain(outboundMeta.Target, s.domain) {
-		muxClient, err := mux.NewClientWorker(*link, mux.ClientStrategy{
-			MaxConcurrency: 0,
-			MaxConnection:  256,
-		})
+		muxClient, err := mux.NewClientWorker(*link, mux.ClientStrategy{})
 		if err != nil {
 		if err != nil {
 			return newError("failed to create mux client worker").Base(err).AtWarning()
 			return newError("failed to create mux client worker").Base(err).AtWarning()
 		}
 		}
@@ -157,7 +154,7 @@ func (p *StaticMuxPicker) PickAvailable() (*mux.ClientWorker, error) {
 	var minIdx int = -1
 	var minIdx int = -1
 	var minConn uint32 = 9999
 	var minConn uint32 = 9999
 	for i, w := range p.workers {
 	for i, w := range p.workers {
-		if w.IsFull() {
+		if w.draining {
 			continue
 			continue
 		}
 		}
 		if w.client.ActiveConnections() < minConn {
 		if w.client.ActiveConnections() < minConn {
@@ -166,6 +163,18 @@ func (p *StaticMuxPicker) PickAvailable() (*mux.ClientWorker, error) {
 		}
 		}
 	}
 	}
 
 
+	if minIdx == -1 {
+		for i, w := range p.workers {
+			if w.IsFull() {
+				continue
+			}
+			if w.client.ActiveConnections() < minConn {
+				minConn = w.client.ActiveConnections()
+				minIdx = i
+			}
+		}
+	}
+
 	if minIdx != -1 {
 	if minIdx != -1 {
 		return p.workers[minIdx].client, nil
 		return p.workers[minIdx].client, nil
 	}
 	}
@@ -181,10 +190,11 @@ func (p *StaticMuxPicker) AddWorker(worker *PortalWorker) {
 }
 }
 
 
 type PortalWorker struct {
 type PortalWorker struct {
-	client  *mux.ClientWorker
-	control *task.Periodic
-	writer  buf.Writer
-	reader  buf.Reader
+	client   *mux.ClientWorker
+	control  *task.Periodic
+	writer   buf.Writer
+	reader   buf.Reader
+	draining bool
 }
 }
 
 
 func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) {
 func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) {
@@ -221,14 +231,15 @@ func (w *PortalWorker) heartbeat() error {
 		return newError("client worker stopped")
 		return newError("client worker stopped")
 	}
 	}
 
 
-	if w.writer == nil {
+	if w.draining || w.writer == nil {
 		return newError("already disposed")
 		return newError("already disposed")
 	}
 	}
 
 
 	msg := &Control{}
 	msg := &Control{}
 	msg.FillInRandom()
 	msg.FillInRandom()
 
 
-	if w.client.IsClosing() {
+	if w.client.TotalConnections() > 256 {
+		w.draining = true
 		msg.State = Control_DRAIN
 		msg.State = Control_DRAIN
 
 
 		defer func() {
 		defer func() {

+ 4 - 0
common/mux/client.go

@@ -190,6 +190,10 @@ func NewClientWorker(stream vio.Link, s ClientStrategy) (*ClientWorker, error) {
 	return c, nil
 	return c, nil
 }
 }
 
 
+func (m *ClientWorker) TotalConnections() uint32 {
+	return uint32(m.sessionManager.Count())
+}
+
 func (m *ClientWorker) ActiveConnections() uint32 {
 func (m *ClientWorker) ActiveConnections() uint32 {
 	return uint32(m.sessionManager.Size())
 	return uint32(m.sessionManager.Size())
 }
 }

+ 1 - 0
common/mux/session.go

@@ -61,6 +61,7 @@ func (m *SessionManager) Add(s *Session) {
 		return
 		return
 	}
 	}
 
 
+	m.count++
 	m.sessions[s.ID] = s
 	m.sessions[s.ID] = s
 }
 }
 
 

+ 212 - 3
testing/scenarios/reverse_test.go

@@ -6,13 +6,15 @@ import (
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
-	"v2ray.com/core/app/reverse"
-	"v2ray.com/core/app/router"
-
 	"v2ray.com/core"
 	"v2ray.com/core"
+	"v2ray.com/core/app/log"
+	"v2ray.com/core/app/policy"
 	"v2ray.com/core/app/proxyman"
 	"v2ray.com/core/app/proxyman"
+	"v2ray.com/core/app/reverse"
+	"v2ray.com/core/app/router"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/compare"
 	"v2ray.com/core/common/compare"
+	clog "v2ray.com/core/common/log"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/common/serial"
 	"v2ray.com/core/common/serial"
@@ -210,3 +212,210 @@ func TestReverseProxy(t *testing.T) {
 	}
 	}
 	wg.Wait()
 	wg.Wait()
 }
 }
+
+func TestReverseProxyLongRunning(t *testing.T) {
+	tcpServer := tcp.Server{
+		MsgProcessor: xor,
+	}
+	dest, err := tcpServer.Start()
+	common.Must(err)
+
+	defer tcpServer.Close()
+
+	userID := protocol.NewID(uuid.New())
+	externalPort := tcp.PickPort()
+	reversePort := tcp.PickPort()
+
+	serverConfig := &core.Config{
+		App: []*serial.TypedMessage{
+			serial.ToTypedMessage(&log.Config{
+				ErrorLogLevel: clog.Severity_Warning,
+				ErrorLogType:  log.LogType_Console,
+			}),
+			serial.ToTypedMessage(&policy.Config{
+				Level: map[uint32]*policy.Policy{
+					0: {
+						Timeout: &policy.Policy_Timeout{
+							UplinkOnly:   &policy.Second{Value: 0},
+							DownlinkOnly: &policy.Second{Value: 0},
+						},
+					},
+				},
+			}),
+			serial.ToTypedMessage(&reverse.Config{
+				PortalConfig: []*reverse.PortalConfig{
+					{
+						Tag:    "portal",
+						Domain: "test.v2ray.com",
+					},
+				},
+			}),
+			serial.ToTypedMessage(&router.Config{
+				Rule: []*router.RoutingRule{
+					{
+						Domain: []*router.Domain{
+							{Type: router.Domain_Full, Value: "test.v2ray.com"},
+						},
+						Tag: "portal",
+					},
+					{
+						InboundTag: []string{"external"},
+						Tag:        "portal",
+					},
+				},
+			}),
+		},
+		Inbound: []*core.InboundHandlerConfig{
+			{
+				Tag: "external",
+				ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{
+					PortRange: net.SinglePortRange(externalPort),
+					Listen:    net.NewIPOrDomain(net.LocalHostIP),
+				}),
+				ProxySettings: serial.ToTypedMessage(&dokodemo.Config{
+					Address: net.NewIPOrDomain(dest.Address),
+					Port:    uint32(dest.Port),
+					NetworkList: &net.NetworkList{
+						Network: []net.Network{net.Network_TCP},
+					},
+				}),
+			},
+			{
+				ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{
+					PortRange: net.SinglePortRange(reversePort),
+					Listen:    net.NewIPOrDomain(net.LocalHostIP),
+				}),
+				ProxySettings: serial.ToTypedMessage(&inbound.Config{
+					User: []*protocol.User{
+						{
+							Account: serial.ToTypedMessage(&vmess.Account{
+								Id:      userID.String(),
+								AlterId: 64,
+							}),
+						},
+					},
+				}),
+			},
+		},
+		Outbound: []*core.OutboundHandlerConfig{
+			{
+				ProxySettings: serial.ToTypedMessage(&blackhole.Config{}),
+			},
+		},
+	}
+
+	clientPort := tcp.PickPort()
+	clientConfig := &core.Config{
+		App: []*serial.TypedMessage{
+			serial.ToTypedMessage(&log.Config{
+				ErrorLogLevel: clog.Severity_Warning,
+				ErrorLogType:  log.LogType_Console,
+			}),
+			serial.ToTypedMessage(&policy.Config{
+				Level: map[uint32]*policy.Policy{
+					0: {
+						Timeout: &policy.Policy_Timeout{
+							UplinkOnly:   &policy.Second{Value: 0},
+							DownlinkOnly: &policy.Second{Value: 0},
+						},
+					},
+				},
+			}),
+			serial.ToTypedMessage(&reverse.Config{
+				BridgeConfig: []*reverse.BridgeConfig{
+					{
+						Tag:    "bridge",
+						Domain: "test.v2ray.com",
+					},
+				},
+			}),
+			serial.ToTypedMessage(&router.Config{
+				Rule: []*router.RoutingRule{
+					{
+						Domain: []*router.Domain{
+							{Type: router.Domain_Full, Value: "test.v2ray.com"},
+						},
+						Tag: "reverse",
+					},
+					{
+						InboundTag: []string{"bridge"},
+						Tag:        "freedom",
+					},
+				},
+			}),
+		},
+		Inbound: []*core.InboundHandlerConfig{
+			{
+				ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{
+					PortRange: net.SinglePortRange(clientPort),
+					Listen:    net.NewIPOrDomain(net.LocalHostIP),
+				}),
+				ProxySettings: serial.ToTypedMessage(&dokodemo.Config{
+					Address: net.NewIPOrDomain(dest.Address),
+					Port:    uint32(dest.Port),
+					NetworkList: &net.NetworkList{
+						Network: []net.Network{net.Network_TCP},
+					},
+				}),
+			},
+		},
+		Outbound: []*core.OutboundHandlerConfig{
+			{
+				Tag:           "freedom",
+				ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
+			},
+			{
+				Tag: "reverse",
+				ProxySettings: serial.ToTypedMessage(&outbound.Config{
+					Receiver: []*protocol.ServerEndpoint{
+						{
+							Address: net.NewIPOrDomain(net.LocalHostIP),
+							Port:    uint32(reversePort),
+							User: []*protocol.User{
+								{
+									Account: serial.ToTypedMessage(&vmess.Account{
+										Id:      userID.String(),
+										AlterId: 64,
+										SecuritySettings: &protocol.SecurityConfig{
+											Type: protocol.SecurityType_AES128_GCM,
+										},
+									}),
+								},
+							},
+						},
+					},
+				}),
+			},
+		},
+	}
+
+	servers, err := InitializeServerConfigs(serverConfig, clientConfig)
+	common.Must(err)
+
+	defer CloseAllServers(servers)
+
+	for i := 0; i < 4096; i++ {
+		conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{
+			IP:   []byte{127, 0, 0, 1},
+			Port: int(externalPort),
+		})
+		common.Must(err)
+
+		payload := make([]byte, 1024)
+		rand.Read(payload)
+
+		nBytes, err := conn.Write([]byte(payload))
+		common.Must(err)
+
+		if nBytes != len(payload) {
+			t.Error("only part of payload is written: ", nBytes)
+		}
+
+		response := readFrom(conn, time.Second*5, 1024)
+		if err := compare.BytesEqualWithDetail(response, xor([]byte(payload))); err != nil {
+			t.Error(err)
+		}
+
+		conn.Close()
+	}
+}