Browse Source

refactor tcp worker

Darien Raymond 8 years ago
parent
commit
d93ff628bc

+ 27 - 3
app/proxyman/inbound/worker.go

@@ -37,7 +37,7 @@ type tcpWorker struct {
 
 	ctx    context.Context
 	cancel context.CancelFunc
-	hub    *internet.TCPHub
+	hub    internet.Listener
 }
 
 func (w *tcpWorker) callback(conn internet.Connection) {
@@ -73,17 +73,41 @@ func (w *tcpWorker) Start() error {
 	ctx, cancel := context.WithCancel(context.Background())
 	w.ctx = ctx
 	w.cancel = cancel
-	hub, err := internet.ListenTCP(w.address, w.port, w.callback, w.stream)
+	ctx = internet.ContextWithStreamSettings(ctx, w.stream)
+	conns := make(chan internet.Connection, 16)
+	hub, err := internet.ListenTCP(ctx, w.address, w.port, conns)
 	if err != nil {
 		return err
 	}
+	go w.handleConnections(conns)
 	w.hub = hub
 	return nil
 }
 
+func (w *tcpWorker) handleConnections(conns <-chan internet.Connection) {
+	for {
+		select {
+		case <-w.ctx.Done():
+			w.hub.Close()
+			nconns := len(conns)
+		L:
+			for i := 0; i < nconns; i++ {
+				select {
+				case conn := <-conns:
+					conn.Close()
+				default:
+					break L
+				}
+			}
+			return
+		case conn := <-conns:
+			go w.callback(conn)
+		}
+	}
+}
+
 func (w *tcpWorker) Close() {
 	if w.hub != nil {
-		w.hub.Close()
 		w.cancel()
 	}
 }

+ 6 - 3
transport/internet/context.go

@@ -19,9 +19,12 @@ func ContextWithStreamSettings(ctx context.Context, streamSettings *StreamConfig
 	return context.WithValue(ctx, streamSettingsKey, streamSettings)
 }
 
-func StreamSettingsFromContext(ctx context.Context) (*StreamConfig, bool) {
-	ss, ok := ctx.Value(streamSettingsKey).(*StreamConfig)
-	return ss, ok
+func StreamSettingsFromContext(ctx context.Context) *StreamConfig {
+	ss := ctx.Value(streamSettingsKey)
+	if ss == nil {
+		return nil
+	}
+	return ss.(*StreamConfig)
 }
 
 func ContextWithDialerSource(ctx context.Context, addr net.Address) context.Context {

+ 1 - 1
transport/internet/dialer.go

@@ -24,7 +24,7 @@ func RegisterTransportDialer(protocol TransportProtocol, dialer Dialer) error {
 
 func Dial(ctx context.Context, dest v2net.Destination) (Connection, error) {
 	if dest.Network == v2net.Network_TCP {
-		streamSettings, _ := StreamSettingsFromContext(ctx)
+		streamSettings := StreamSettingsFromContext(ctx)
 		protocol := streamSettings.GetEffectiveProtocol()
 		transportSettings, err := streamSettings.GetEffectiveTransportSettings()
 		if err != nil {

+ 9 - 11
transport/internet/kcp/kcp_test.go

@@ -18,30 +18,27 @@ import (
 func TestDialAndListen(t *testing.T) {
 	assert := assert.On(t)
 
-	listerner, err := NewListener(internet.ContextWithTransportSettings(context.Background(), &Config{}), v2net.LocalHostIP, v2net.Port(0))
+	conns := make(chan internet.Connection, 16)
+	listerner, err := NewListener(internet.ContextWithTransportSettings(context.Background(), &Config{}), v2net.LocalHostIP, v2net.Port(0), conns)
 	assert.Error(err).IsNil()
 	port := v2net.Port(listerner.Addr().(*net.UDPAddr).Port)
 
 	go func() {
-		for {
-			conn, err := listerner.Accept()
-			if err != nil {
-				break
-			}
-			go func() {
+		for conn := range conns {
+			go func(c internet.Connection) {
 				payload := make([]byte, 4096)
 				for {
-					nBytes, err := conn.Read(payload)
+					nBytes, err := c.Read(payload)
 					if err != nil {
 						break
 					}
 					for idx, b := range payload[:nBytes] {
 						payload[idx] = b ^ 'c'
 					}
-					conn.Write(payload[:nBytes])
+					c.Write(payload[:nBytes])
 				}
-				conn.Close()
-			}()
+				c.Close()
+			}(conn)
 		}
 	}()
 
@@ -79,4 +76,5 @@ func TestDialAndListen(t *testing.T) {
 	assert.Int(listerner.ActiveConnections()).Equals(0)
 
 	listerner.Close()
+	close(conns)
 }

+ 23 - 39
transport/internet/kcp/listener.go

@@ -80,18 +80,18 @@ func (o *ServerConnection) Id() internal.ConnectionID {
 // Listener defines a server listening for connections
 type Listener struct {
 	sync.Mutex
-	closed        chan bool
-	sessions      map[ConnectionID]*Connection
-	awaitingConns chan *Connection
-	hub           *udp.Hub
-	tlsConfig     *tls.Config
-	config        *Config
-	reader        PacketReader
-	header        internet.PacketHeader
-	security      cipher.AEAD
+	closed    chan bool
+	sessions  map[ConnectionID]*Connection
+	hub       *udp.Hub
+	tlsConfig *tls.Config
+	config    *Config
+	reader    PacketReader
+	header    internet.PacketHeader
+	security  cipher.AEAD
+	conns     chan<- internet.Connection
 }
 
-func NewListener(ctx context.Context, address v2net.Address, port v2net.Port) (*Listener, error) {
+func NewListener(ctx context.Context, address v2net.Address, port v2net.Port, conns chan<- internet.Connection) (*Listener, error) {
 	networkSettings := internet.TransportSettingsFromContext(ctx)
 	kcpSettings := networkSettings.(*Config)
 	kcpSettings.ConnectionReuse = &ConnectionReuse{Enable: false}
@@ -111,10 +111,10 @@ func NewListener(ctx context.Context, address v2net.Address, port v2net.Port) (*
 			Header:   header,
 			Security: security,
 		},
-		sessions:      make(map[ConnectionID]*Connection),
-		awaitingConns: make(chan *Connection, 64),
-		closed:        make(chan bool),
-		config:        kcpSettings,
+		sessions: make(map[ConnectionID]*Connection),
+		closed:   make(chan bool),
+		config:   kcpSettings,
+		conns:    conns,
 	}
 	securitySettings := internet.SecuritySettingsFromContext(ctx)
 	if securitySettings != nil {
@@ -194,8 +194,14 @@ func (v *Listener) OnReceive(payload *buf.Buffer, src v2net.Destination, origina
 			closer: writer,
 		}
 		conn = NewConnection(conv, sConn, v, v.config)
+		var netConn internet.Connection = conn
+		if v.tlsConfig != nil {
+			tlsConn := tls.Server(conn, v.tlsConfig)
+			netConn = UnreusableConnection{Conn: tlsConn}
+		}
+
 		select {
-		case v.awaitingConns <- conn:
+		case v.conns <- netConn:
 		case <-time.After(time.Second * 5):
 			conn.Close()
 			return
@@ -216,27 +222,6 @@ func (v *Listener) Remove(id ConnectionID) {
 	}
 }
 
-// Accept implements the Accept method in the Listener interface; it waits for the next call and returns a generic Conn.
-func (v *Listener) Accept() (internet.Connection, error) {
-	for {
-		select {
-		case <-v.closed:
-			return nil, ErrClosedListener
-		case conn, open := <-v.awaitingConns:
-			if !open {
-				break
-			}
-			if v.tlsConfig != nil {
-				tlsConn := tls.Server(conn, v.tlsConfig)
-				return UnreusableConnection{Conn: tlsConn}, nil
-			}
-			return conn, nil
-		case <-time.After(time.Second):
-
-		}
-	}
-}
-
 // Close stops listening on the UDP address. Already Accepted connections are not closed.
 func (v *Listener) Close() error {
 
@@ -249,7 +234,6 @@ func (v *Listener) Close() error {
 	}
 
 	close(v.closed)
-	close(v.awaitingConns)
 	for _, conn := range v.sessions {
 		go conn.Terminate()
 	}
@@ -288,8 +272,8 @@ func (v *Writer) Close() error {
 	return nil
 }
 
-func ListenKCP(ctx context.Context, address v2net.Address, port v2net.Port) (internet.Listener, error) {
-	return NewListener(ctx, address, port)
+func ListenKCP(ctx context.Context, address v2net.Address, port v2net.Port, conns chan<- internet.Connection) (internet.Listener, error) {
+	return NewListener(ctx, address, port, conns)
 }
 
 func init() {

+ 1 - 2
transport/internet/system_dialer.go

@@ -1,11 +1,10 @@
 package internet
 
 import (
+	"context"
 	"net"
 	"time"
 
-	"context"
-
 	v2net "v2ray.com/core/common/net"
 )
 

+ 23 - 52
transport/internet/tcp/hub.go

@@ -20,22 +20,17 @@ var (
 	ErrClosedListener = errors.New("Listener is closed.")
 )
 
-type ConnectionWithError struct {
-	conn net.Conn
-	err  error
-}
-
 type TCPListener struct {
 	sync.Mutex
-	acccepting    bool
-	listener      *net.TCPListener
-	awaitingConns chan *ConnectionWithError
-	tlsConfig     *tls.Config
-	authConfig    internet.ConnectionAuthenticator
-	config        *Config
+	acccepting bool
+	listener   *net.TCPListener
+	tlsConfig  *tls.Config
+	authConfig internet.ConnectionAuthenticator
+	config     *Config
+	conns      chan<- internet.Connection
 }
 
-func ListenTCP(ctx context.Context, address v2net.Address, port v2net.Port) (internet.Listener, error) {
+func ListenTCP(ctx context.Context, address v2net.Address, port v2net.Port, conns chan<- internet.Connection) (internet.Listener, error) {
 	listener, err := net.ListenTCP("tcp", &net.TCPAddr{
 		IP:   address.IP(),
 		Port: int(port),
@@ -48,10 +43,10 @@ func ListenTCP(ctx context.Context, address v2net.Address, port v2net.Port) (int
 	tcpSettings := networkSettings.(*Config)
 
 	l := &TCPListener{
-		acccepting:    true,
-		listener:      listener,
-		awaitingConns: make(chan *ConnectionWithError, 32),
-		config:        tcpSettings,
+		acccepting: true,
+		listener:   listener,
+		config:     tcpSettings,
+		conns:      conns,
 	}
 	if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
 		tlsConfig, ok := securitySettings.(*v2tls.Config)
@@ -74,24 +69,6 @@ func ListenTCP(ctx context.Context, address v2net.Address, port v2net.Port) (int
 	return l, nil
 }
 
-func (v *TCPListener) Accept() (internet.Connection, error) {
-	for v.acccepting {
-		select {
-		case connErr, open := <-v.awaitingConns:
-			if !open {
-				return nil, ErrClosedListener
-			}
-			if connErr.err != nil {
-				return nil, connErr.err
-			}
-			conn := connErr.conn
-			return internal.NewConnection(internal.ConnectionID{}, conn, v, internal.ReuseConnection(v.config.IsConnectionReuse())), nil
-		case <-time.After(time.Second * 2):
-		}
-	}
-	return nil, ErrClosedListener
-}
-
 func (v *TCPListener) KeepAccepting() {
 	for v.acccepting {
 		conn, err := v.listener.Accept()
@@ -100,22 +77,22 @@ func (v *TCPListener) KeepAccepting() {
 			v.Unlock()
 			break
 		}
-		if conn != nil && v.tlsConfig != nil {
+		if err != nil {
+			log.Warning("TCP|Listener: Failed to accepted raw connections: ", err)
+			v.Unlock()
+			continue
+		}
+		if v.tlsConfig != nil {
 			conn = tls.Server(conn, v.tlsConfig)
 		}
-		if conn != nil && v.authConfig != nil {
+		if v.authConfig != nil {
 			conn = v.authConfig.Server(conn)
 		}
 
 		select {
-		case v.awaitingConns <- &ConnectionWithError{
-			conn: conn,
-			err:  err,
-		}:
-		default:
-			if conn != nil {
-				conn.Close()
-			}
+		case v.conns <- internal.NewConnection(internal.ConnectionID{}, conn, v, internal.ReuseConnection(v.config.IsConnectionReuse())):
+		case <-time.After(time.Second * 5):
+			conn.Close()
 		}
 
 		v.Unlock()
@@ -129,8 +106,8 @@ func (v *TCPListener) Put(id internal.ConnectionID, conn net.Conn) {
 		return
 	}
 	select {
-	case v.awaitingConns <- &ConnectionWithError{conn: conn}:
-	default:
+	case v.conns <- internal.NewConnection(internal.ConnectionID{}, conn, v, internal.ReuseConnection(v.config.IsConnectionReuse())):
+	case <-time.After(time.Second * 5):
 		conn.Close()
 	}
 }
@@ -144,12 +121,6 @@ func (v *TCPListener) Close() error {
 	defer v.Unlock()
 	v.acccepting = false
 	v.listener.Close()
-	close(v.awaitingConns)
-	for connErr := range v.awaitingConns {
-		if connErr.conn != nil {
-			connErr.conn.Close()
-		}
-	}
 	return nil
 }
 

+ 6 - 66
transport/internet/tcp_hub.go

@@ -1,14 +1,11 @@
 package internet
 
 import (
-	"net"
-
 	"context"
+	"net"
 
-	"v2ray.com/core/app/log"
 	"v2ray.com/core/common/errors"
 	v2net "v2ray.com/core/common/net"
-	"v2ray.com/core/common/retry"
 )
 
 var (
@@ -23,10 +20,9 @@ func RegisterTransportListener(protocol TransportProtocol, listener ListenFunc)
 	return nil
 }
 
-type ListenFunc func(ctx context.Context, address v2net.Address, port v2net.Port) (Listener, error)
+type ListenFunc func(ctx context.Context, address v2net.Address, port v2net.Port, conns chan<- Connection) (Listener, error)
 
 type Listener interface {
-	Accept() (Connection, error)
 	Close() error
 	Addr() net.Addr
 }
@@ -37,8 +33,8 @@ type TCPHub struct {
 	closed       chan bool
 }
 
-func ListenTCP(address v2net.Address, port v2net.Port, callback ConnectionHandler, settings *StreamConfig) (*TCPHub, error) {
-	ctx := context.Background()
+func ListenTCP(ctx context.Context, address v2net.Address, port v2net.Port, conns chan<- Connection) (Listener, error) {
+	settings := StreamSettingsFromContext(ctx)
 	protocol := settings.GetEffectiveProtocol()
 	transportSettings, err := settings.GetEffectiveTransportSettings()
 	if err != nil {
@@ -56,65 +52,9 @@ func ListenTCP(address v2net.Address, port v2net.Port, callback ConnectionHandle
 	if listenFunc == nil {
 		return nil, errors.New("Internet|TCPHub: ", protocol, " listener not registered.")
 	}
-	listener, err := listenFunc(ctx, address, port)
+	listener, err := listenFunc(ctx, address, port, conns)
 	if err != nil {
 		return nil, errors.Base(err).Message("Internet|TCPHub: Failed to listen on address: ", address, ":", port)
 	}
-
-	hub := &TCPHub{
-		listener:     listener,
-		connCallback: callback,
-	}
-
-	go hub.start()
-	return hub, nil
-}
-
-func (v *TCPHub) Close() {
-	defer func() {
-		recover()
-	}()
-
-	select {
-	case <-v.closed:
-		return
-	default:
-		v.listener.Close()
-		close(v.closed)
-	}
-}
-
-func (v *TCPHub) start() {
-	for {
-		select {
-		case <-v.closed:
-			return
-		default:
-		}
-		var newConn Connection
-		err := retry.ExponentialBackoff(10, 500).On(func() error {
-			select {
-			case <-v.closed:
-				return nil
-			default:
-				conn, err := v.listener.Accept()
-				if err != nil {
-					return errors.Base(err).RequireUserAction().Message("Internet|Listener: Failed to accept new TCP connection.")
-				}
-				newConn = conn
-				return nil
-			}
-		})
-		if err != nil {
-			if errors.IsActionRequired(err) {
-				log.Warning(err)
-			} else {
-				log.Info(err)
-			}
-			continue
-		}
-		if newConn != nil {
-			go v.connCallback(newConn)
-		}
-	}
+	return listener, nil
 }

+ 17 - 46
transport/internet/websocket/hub.go

@@ -23,14 +23,9 @@ var (
 	ErrClosedListener = errors.New("Listener is closed.")
 )
 
-type ConnectionWithError struct {
-	conn net.Conn
-	err  error
-}
-
 type requestHandler struct {
-	path  string
-	conns chan *ConnectionWithError
+	path string
+	ln   *Listener
 }
 
 func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
@@ -45,29 +40,29 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
 	}
 
 	select {
-	case h.conns <- &ConnectionWithError{conn: conn}:
-	default:
+	case h.ln.conns <- internal.NewConnection(internal.ConnectionID{}, conn, h.ln, internal.ReuseConnection(h.ln.config.IsConnectionReuse())):
+	case <-time.After(time.Second * 5):
 		conn.Close()
 	}
 }
 
 type Listener struct {
 	sync.Mutex
-	closed        chan bool
-	awaitingConns chan *ConnectionWithError
-	listener      net.Listener
-	tlsConfig     *tls.Config
-	config        *Config
+	closed    chan bool
+	listener  net.Listener
+	tlsConfig *tls.Config
+	config    *Config
+	conns     chan<- internet.Connection
 }
 
-func ListenWS(ctx context.Context, address v2net.Address, port v2net.Port) (internet.Listener, error) {
+func ListenWS(ctx context.Context, address v2net.Address, port v2net.Port, conns chan<- internet.Connection) (internet.Listener, error) {
 	networkSettings := internet.TransportSettingsFromContext(ctx)
 	wsSettings := networkSettings.(*Config)
 
 	l := &Listener{
-		closed:        make(chan bool),
-		awaitingConns: make(chan *ConnectionWithError, 32),
-		config:        wsSettings,
+		closed: make(chan bool),
+		config: wsSettings,
+		conns:  conns,
 	}
 	if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
 		tlsConfig, ok := securitySettings.(*v2tls.Config)
@@ -101,8 +96,8 @@ func (ln *Listener) listenws(address v2net.Address, port v2net.Port) error {
 
 	go func() {
 		http.Serve(listener, &requestHandler{
-			path:  ln.config.GetNormailzedPath(),
-			conns: ln.awaitingConns,
+			path: ln.config.GetNormailzedPath(),
+			ln:   ln,
 		})
 	}()
 
@@ -123,24 +118,6 @@ func converttovws(w http.ResponseWriter, r *http.Request) (*connection, error) {
 	return &connection{wsc: conn}, nil
 }
 
-func (ln *Listener) Accept() (internet.Connection, error) {
-	for {
-		select {
-		case <-ln.closed:
-			return nil, ErrClosedListener
-		case connErr, open := <-ln.awaitingConns:
-			if !open {
-				return nil, ErrClosedListener
-			}
-			if connErr.err != nil {
-				return nil, connErr.err
-			}
-			return internal.NewConnection(internal.ConnectionID{}, connErr.conn, ln, internal.ReuseConnection(ln.config.IsConnectionReuse())), nil
-		case <-time.After(time.Second * 2):
-		}
-	}
-}
-
 func (ln *Listener) Put(id internal.ConnectionID, conn net.Conn) {
 	ln.Lock()
 	defer ln.Unlock()
@@ -150,8 +127,8 @@ func (ln *Listener) Put(id internal.ConnectionID, conn net.Conn) {
 	default:
 	}
 	select {
-	case ln.awaitingConns <- &ConnectionWithError{conn: conn}:
-	default:
+	case ln.conns <- internal.NewConnection(internal.ConnectionID{}, conn, ln, internal.ReuseConnection(ln.config.IsConnectionReuse())):
+	case <-time.After(time.Second * 5):
 		conn.Close()
 	}
 }
@@ -170,12 +147,6 @@ func (ln *Listener) Close() error {
 	}
 	close(ln.closed)
 	ln.listener.Close()
-	close(ln.awaitingConns)
-	for connErr := range ln.awaitingConns {
-		if connErr.conn != nil {
-			connErr.conn.Close()
-		}
-	}
 	return nil
 }
 

+ 14 - 15
transport/internet/websocket/ws_test.go

@@ -16,31 +16,28 @@ import (
 
 func Test_listenWSAndDial(t *testing.T) {
 	assert := assert.On(t)
+	conns := make(chan internet.Connection, 16)
 	listen, err := ListenWS(internet.ContextWithTransportSettings(context.Background(), &Config{
 		Path: "ws",
-	}), v2net.DomainAddress("localhost"), 13146)
+	}), v2net.DomainAddress("localhost"), 13146, conns)
 	assert.Error(err).IsNil()
 	go func() {
-		for {
-			conn, err := listen.Accept()
-			if err != nil {
-				break
-			}
-			go func() {
-				defer conn.Close()
+		for conn := range conns {
+			go func(c internet.Connection) {
+				defer c.Close()
 
 				var b [1024]byte
-				n, err := conn.Read(b[:])
+				n, err := c.Read(b[:])
 				//assert.Error(err).IsNil()
 				if err != nil {
-					conn.SetReusable(false)
+					c.SetReusable(false)
 					return
 				}
 				assert.Bool(bytes.HasPrefix(b[:n], []byte("Test connection"))).IsTrue()
 
-				_, err = conn.Write([]byte("Response"))
+				_, err = c.Write([]byte("Response"))
 				assert.Error(err).IsNil()
-			}()
+			}(conn)
 		}
 	}()
 
@@ -77,6 +74,8 @@ func Test_listenWSAndDial(t *testing.T) {
 	assert.Error(conn.Close()).IsNil()
 
 	assert.Error(listen.Close()).IsNil()
+
+	close(conns)
 }
 
 func Test_listenWSAndDial_TLS(t *testing.T) {
@@ -96,11 +95,11 @@ func Test_listenWSAndDial_TLS(t *testing.T) {
 		AllowInsecure: true,
 		Certificate:   []*v2tls.Certificate{tlsgen.GenerateCertificateForTest()},
 	})
-	listen, err := ListenWS(ctx, v2net.DomainAddress("localhost"), 13143)
+	conns := make(chan internet.Connection, 16)
+	listen, err := ListenWS(ctx, v2net.DomainAddress("localhost"), 13143, conns)
 	assert.Error(err).IsNil()
 	go func() {
-		conn, err := listen.Accept()
-		assert.Error(err).IsNil()
+		conn := <-conns
 		conn.Close()
 		listen.Close()
 	}()