Browse Source

update quic connection handling

Darien Raymond 7 years ago
parent
commit
4ff26a36ad
2 changed files with 67 additions and 41 deletions
  1. 39 32
      transport/internet/quic/dialer.go
  2. 28 9
      transport/internet/quic/hub.go

+ 39 - 32
transport/internet/quic/dialer.go

@@ -13,47 +13,54 @@ import (
 	"v2ray.com/core/transport/internet/tls"
 	"v2ray.com/core/transport/internet/tls"
 )
 )
 
 
+type sessionContext struct {
+	rawConn *sysConn
+	session quic.Session
+}
+
 type clientSessions struct {
 type clientSessions struct {
 	access   sync.Mutex
 	access   sync.Mutex
-	sessions map[net.Destination][]quic.Session
+	sessions map[net.Destination][]*sessionContext
+}
+
+func isActive(s quic.Session) bool {
+	select {
+	case <-s.Context().Done():
+		return false
+	default:
+		return true
+	}
 }
 }
 
 
-func removeInactiveSessions(sessions []quic.Session) []quic.Session {
-	lastActive := 0
+func removeInactiveSessions(sessions []*sessionContext) []*sessionContext {
+	activeSessions := make([]*sessionContext, 0, len(sessions))
 	for _, s := range sessions {
 	for _, s := range sessions {
-		active := true
-		select {
-		case <-s.Context().Done():
-			active = false
-		default:
-		}
-		if active {
-			sessions[lastActive] = s
-			lastActive++
+		if isActive(s.session) {
+			activeSessions = append(activeSessions, s)
+		} else {
+			s.rawConn.Close()
+			s.session.Close()
 		}
 		}
 	}
 	}
 
 
-	if lastActive < len(sessions) {
-		for i := lastActive; i < len(sessions); i++ {
-			sessions[i] = nil
-		}
-		sessions = sessions[:lastActive]
+	if len(activeSessions) < len(sessions) {
+		return activeSessions
 	}
 	}
 
 
 	return sessions
 	return sessions
 }
 }
 
 
-func openStream(sessions []quic.Session) (quic.Stream, net.Addr, error) {
+func openStream(sessions []*sessionContext) (quic.Stream, net.Addr) {
 	for _, s := range sessions {
 	for _, s := range sessions {
-		stream, err := s.OpenStream()
+		stream, err := s.session.OpenStream()
 		if err != nil {
 		if err != nil {
-			newError("failed to create stream").Base(err).WriteToLog()
+			newError("failed to create stream").Base(err).AtWarning().WriteToLog()
 			continue
 			continue
 		}
 		}
-		return stream, s.LocalAddr(), nil
+		return stream, s.session.LocalAddr()
 	}
 	}
 
 
-	return nil, nil, nil
+	return nil, nil
 }
 }
 
 
 func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (internet.Connection, error) {
 func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (internet.Connection, error) {
@@ -61,12 +68,12 @@ func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsCo
 	defer s.access.Unlock()
 	defer s.access.Unlock()
 
 
 	if s.sessions == nil {
 	if s.sessions == nil {
-		s.sessions = make(map[net.Destination][]quic.Session)
+		s.sessions = make(map[net.Destination][]*sessionContext)
 	}
 	}
 
 
 	dest := net.DestinationFromAddr(destAddr)
 	dest := net.DestinationFromAddr(destAddr)
 
 
-	var sessions []quic.Session
+	var sessions []*sessionContext
 	if s, found := s.sessions[dest]; found {
 	if s, found := s.sessions[dest]; found {
 		sessions = s
 		sessions = s
 	}
 	}
@@ -74,10 +81,7 @@ func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsCo
 	sessions = removeInactiveSessions(sessions)
 	sessions = removeInactiveSessions(sessions)
 	s.sessions[dest] = sessions
 	s.sessions[dest] = sessions
 
 
-	stream, local, err := openStream(sessions)
-	if err != nil {
-		return nil, err
-	}
+	stream, local := openStream(sessions)
 	if stream != nil {
 	if stream != nil {
 		return &interConn{
 		return &interConn{
 			stream: stream,
 			stream: stream,
@@ -96,8 +100,8 @@ func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsCo
 
 
 	quicConfig := &quic.Config{
 	quicConfig := &quic.Config{
 		ConnectionIDLength:                    12,
 		ConnectionIDLength:                    12,
-		HandshakeTimeout:                      time.Second * 4,
-		IdleTimeout:                           time.Second * 60,
+		HandshakeTimeout:                      time.Second * 8,
+		IdleTimeout:                           time.Second * 600,
 		MaxReceiveStreamFlowControlWindow:     512 * 1024,
 		MaxReceiveStreamFlowControlWindow:     512 * 1024,
 		MaxReceiveConnectionFlowControlWindow: 2 * 1024 * 1024,
 		MaxReceiveConnectionFlowControlWindow: 2 * 1024 * 1024,
 		MaxIncomingUniStreams:                 -1,
 		MaxIncomingUniStreams:                 -1,
@@ -111,11 +115,14 @@ func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsCo
 
 
 	session, err := quic.DialContext(context.Background(), conn, destAddr, "", tlsConfig.GetTLSConfig(tls.WithDestination(dest)), quicConfig)
 	session, err := quic.DialContext(context.Background(), conn, destAddr, "", tlsConfig.GetTLSConfig(tls.WithDestination(dest)), quicConfig)
 	if err != nil {
 	if err != nil {
-		rawConn.Close()
+		conn.Close()
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	s.sessions[dest] = append(sessions, session)
+	s.sessions[dest] = append(sessions, &sessionContext{
+		session: session,
+		rawConn: conn,
+	})
 	stream, err = session.OpenStream()
 	stream, err = session.OpenStream()
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err

+ 28 - 9
transport/internet/quic/hub.go

@@ -8,13 +8,16 @@ import (
 	"v2ray.com/core/common"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol/tls/cert"
 	"v2ray.com/core/common/protocol/tls/cert"
+	"v2ray.com/core/common/signal/done"
 	"v2ray.com/core/transport/internet"
 	"v2ray.com/core/transport/internet"
 	"v2ray.com/core/transport/internet/tls"
 	"v2ray.com/core/transport/internet/tls"
 )
 )
 
 
 // Listener is an internet.Listener that listens for TCP connections.
 // Listener is an internet.Listener that listens for TCP connections.
 type Listener struct {
 type Listener struct {
+	rawConn  *sysConn
 	listener quic.Listener
 	listener quic.Listener
+	done     *done.Instance
 	addConn  internet.ConnHandler
 	addConn  internet.ConnHandler
 }
 }
 
 
@@ -22,9 +25,17 @@ func (l *Listener) acceptStreams(session quic.Session) {
 	for {
 	for {
 		stream, err := session.AcceptStream()
 		stream, err := session.AcceptStream()
 		if err != nil {
 		if err != nil {
-			newError("failed to accept stream").Base(err).WriteToLog()
-			session.Close()
-			return
+			newError("failed to accept stream").Base(err).AtWarning().WriteToLog()
+			select {
+			case <-session.Context().Done():
+				return
+			case <-l.done.Wait():
+				session.Close()
+				return
+			default:
+				time.Sleep(time.Second)
+				continue
+			}
 		}
 		}
 
 
 		conn := &interConn{
 		conn := &interConn{
@@ -42,7 +53,10 @@ func (l *Listener) keepAccepting() {
 	for {
 	for {
 		conn, err := l.listener.Accept()
 		conn, err := l.listener.Accept()
 		if err != nil {
 		if err != nil {
-			newError("failed to accept QUIC sessions").Base(err).WriteToLog()
+			newError("failed to accept QUIC sessions").Base(err).AtWarning().WriteToLog()
+			if l.done.Done() {
+				break
+			}
 			time.Sleep(time.Second)
 			time.Sleep(time.Second)
 			continue
 			continue
 		}
 		}
@@ -57,7 +71,10 @@ func (l *Listener) Addr() net.Addr {
 
 
 // Close implements internet.Listener.Close.
 // Close implements internet.Listener.Close.
 func (l *Listener) Close() error {
 func (l *Listener) Close() error {
-	return l.listener.Close()
+	l.done.Close()
+	l.listener.Close()
+	l.rawConn.Close()
+	return nil
 }
 }
 
 
 // Listen creates a new Listener based on configurations.
 // Listen creates a new Listener based on configurations.
@@ -85,11 +102,11 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti
 
 
 	quicConfig := &quic.Config{
 	quicConfig := &quic.Config{
 		ConnectionIDLength:                    12,
 		ConnectionIDLength:                    12,
-		HandshakeTimeout:                      time.Second * 4,
-		IdleTimeout:                           time.Second * 60,
+		HandshakeTimeout:                      time.Second * 8,
+		IdleTimeout:                           time.Second * 600,
 		MaxReceiveStreamFlowControlWindow:     512 * 1024,
 		MaxReceiveStreamFlowControlWindow:     512 * 1024,
 		MaxReceiveConnectionFlowControlWindow: 4 * 1024 * 1024,
 		MaxReceiveConnectionFlowControlWindow: 4 * 1024 * 1024,
-		MaxIncomingStreams:                    64,
+		MaxIncomingStreams:                    8192,
 		MaxIncomingUniStreams:                 -1,
 		MaxIncomingUniStreams:                 -1,
 	}
 	}
 
 
@@ -101,11 +118,13 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti
 
 
 	qListener, err := quic.Listen(conn, tlsConfig.GetTLSConfig(), quicConfig)
 	qListener, err := quic.Listen(conn, tlsConfig.GetTLSConfig(), quicConfig)
 	if err != nil {
 	if err != nil {
-		rawConn.Close()
+		conn.Close()
 		return nil, err
 		return nil, err
 	}
 	}
 
 
 	listener := &Listener{
 	listener := &Listener{
+		done:     done.New(),
+		rawConn:  conn,
 		listener: qListener,
 		listener: qListener,
 		addConn:  handler,
 		addConn:  handler,
 	}
 	}