Browse Source

update quic connection handling

Darien Raymond 7 years ago
parent
commit
45fbf6f059
3 changed files with 150 additions and 50 deletions
  1. 18 5
      transport/internet/quic/conn.go
  2. 126 38
      transport/internet/quic/dialer.go
  3. 6 7
      transport/internet/quic/hub.go

+ 18 - 5
transport/internet/quic/conn.go

@@ -10,6 +10,7 @@ import (
 	"v2ray.com/core/common"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/net"
+	"v2ray.com/core/common/signal/done"
 	"v2ray.com/core/transport/internet"
 	"v2ray.com/core/transport/internet"
 )
 )
 
 
@@ -133,9 +134,11 @@ func (c *sysConn) SetWriteDeadline(t time.Time) error {
 }
 }
 
 
 type interConn struct {
 type interConn struct {
-	stream quic.Stream
-	local  net.Addr
-	remote net.Addr
+	context *sessionContext
+	stream  quic.Stream
+	done    *done.Instance
+	local   net.Addr
+	remote  net.Addr
 }
 }
 
 
 func (c *interConn) Read(b []byte) (int, error) {
 func (c *interConn) Read(b []byte) (int, error) {
@@ -162,10 +165,13 @@ func (c *interConn) WriteMultiBuffer(mb buf.MultiBuffer) error {
 	defer reader.Close()
 	defer reader.Close()
 
 
 	for {
 	for {
-		nBytes, err := reader.Read(b[:1380])
+		nBytes, err := reader.Read(b[:1200])
 		if err != nil {
 		if err != nil {
 			break
 			break
 		}
 		}
+		if nBytes == 0 {
+			continue
+		}
 		if _, err := c.Write(b[:nBytes]); err != nil {
 		if _, err := c.Write(b[:nBytes]); err != nil {
 			return err
 			return err
 		}
 		}
@@ -179,7 +185,14 @@ func (c *interConn) Write(b []byte) (int, error) {
 }
 }
 
 
 func (c *interConn) Close() error {
 func (c *interConn) Close() error {
-	return c.stream.Close()
+	if c.context != nil {
+		defer c.context.onInterConnClose()
+	}
+
+	common.Must(c.done.Close())
+	c.stream.CancelRead(1)
+	c.stream.CancelWrite(1)
+	return nil
 }
 }
 
 
 func (c *interConn) LocalAddr() net.Addr {
 func (c *interConn) LocalAddr() net.Addr {

+ 126 - 38
transport/internet/quic/dialer.go

@@ -6,21 +6,77 @@ import (
 	"time"
 	"time"
 
 
 	quic "github.com/lucas-clemente/quic-go"
 	quic "github.com/lucas-clemente/quic-go"
-
 	"v2ray.com/core/common"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/net"
+	"v2ray.com/core/common/signal/done"
+	"v2ray.com/core/common/task"
 	"v2ray.com/core/transport/internet"
 	"v2ray.com/core/transport/internet"
 	"v2ray.com/core/transport/internet/tls"
 	"v2ray.com/core/transport/internet/tls"
 )
 )
 
 
 type sessionContext struct {
 type sessionContext struct {
-	rawConn *sysConn
-	session quic.Session
+	access     sync.Mutex
+	done       *done.Instance
+	rawConn    *sysConn
+	session    quic.Session
+	interConns []*interConn
+}
+
+var errSessionClosed = newError("session closed")
+
+func (c *sessionContext) openStream(destAddr net.Addr) (*interConn, error) {
+	c.access.Lock()
+	defer c.access.Unlock()
+
+	if c.done.Done() {
+		return nil, errSessionClosed
+	}
+
+	stream, err := c.session.OpenStream()
+	if err != nil {
+		return nil, err
+	}
+
+	conn := &interConn{
+		stream:  stream,
+		done:    done.New(),
+		local:   c.session.LocalAddr(),
+		remote:  destAddr,
+		context: c,
+	}
+
+	c.interConns = append(c.interConns, conn)
+	return conn, nil
+}
+
+func (c *sessionContext) onInterConnClose() {
+	c.access.Lock()
+	defer c.access.Unlock()
+
+	if c.done.Done() {
+		return
+	}
+
+	activeConns := 0
+	for _, conn := range c.interConns {
+		if !conn.done.Done() {
+			activeConns++
+		}
+	}
+
+	if activeConns > 0 {
+		return
+	}
+
+	c.done.Close()
+	c.session.Close()
+	c.rawConn.Close()
 }
 }
 
 
 type clientSessions struct {
 type clientSessions struct {
 	access   sync.Mutex
 	access   sync.Mutex
 	sessions map[net.Destination][]*sessionContext
 	sessions map[net.Destination][]*sessionContext
+	cleanup  *task.Periodic
 }
 }
 
 
 func isActive(s quic.Session) bool {
 func isActive(s quic.Session) bool {
@@ -37,8 +93,13 @@ func removeInactiveSessions(sessions []*sessionContext) []*sessionContext {
 	for _, s := range sessions {
 	for _, s := range sessions {
 		if isActive(s.session) {
 		if isActive(s.session) {
 			activeSessions = append(activeSessions, s)
 			activeSessions = append(activeSessions, s)
-		} else {
-			s.rawConn.Close()
+			continue
+		}
+		if err := s.session.Close(); err != nil {
+			newError("failed to close session").Base(err).AtWarning().WriteToLog()
+		}
+		if err := s.rawConn.Close(); err != nil {
+			newError("failed to close raw connection").Base(err).AtWarning().WriteToLog()
 		}
 		}
 	}
 	}
 
 
@@ -49,21 +110,42 @@ func removeInactiveSessions(sessions []*sessionContext) []*sessionContext {
 	return sessions
 	return sessions
 }
 }
 
 
-func openStream(sessions []*sessionContext) (quic.Stream, net.Addr) {
+func openStream(sessions []*sessionContext, destAddr net.Addr) *interConn {
 	for _, s := range sessions {
 	for _, s := range sessions {
 		if !isActive(s.session) {
 		if !isActive(s.session) {
 			continue
 			continue
 		}
 		}
 
 
-		stream, err := s.session.OpenStream()
+		conn, err := s.openStream(destAddr)
 		if err != nil {
 		if err != nil {
-			newError("failed to create stream").Base(err).AtWarning().WriteToLog()
 			continue
 			continue
 		}
 		}
-		return stream, s.session.LocalAddr()
+
+		return conn
+	}
+
+	return nil
+}
+
+func (s *clientSessions) cleanSessions() error {
+	s.access.Lock()
+	defer s.access.Unlock()
+
+	if len(s.sessions) == 0 {
+		return nil
+	}
+
+	newSessionMap := make(map[net.Destination][]*sessionContext)
+
+	for dest, sessions := range s.sessions {
+		sessions = removeInactiveSessions(sessions)
+		if len(sessions) > 0 {
+			newSessionMap[dest] = sessions
+		}
 	}
 	}
 
 
-	return nil, nil
+	s.sessions = newSessionMap
+	return nil
 }
 }
 
 
 func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (internet.Connection, error) {
 func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (internet.Connection, error) {
@@ -81,14 +163,10 @@ func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsCo
 		sessions = s
 		sessions = s
 	}
 	}
 
 
-	{
-		stream, local := openStream(sessions)
-		if stream != nil {
-			return &interConn{
-				stream: stream,
-				local:  local,
-				remote: destAddr,
-			}, nil
+	if true {
+		conn := openStream(sessions, destAddr)
+		if conn != nil {
+			return conn, nil
 		}
 		}
 	}
 	}
 
 
@@ -103,13 +181,11 @@ func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsCo
 	}
 	}
 
 
 	quicConfig := &quic.Config{
 	quicConfig := &quic.Config{
-		ConnectionIDLength:                    8,
-		HandshakeTimeout:                      time.Second * 8,
-		IdleTimeout:                           time.Second * 30,
-		MaxReceiveStreamFlowControlWindow:     128 * 1024,
-		MaxReceiveConnectionFlowControlWindow: 2 * 1024 * 1024,
-		MaxIncomingUniStreams:                 -1,
-		MaxIncomingStreams:                    32,
+		ConnectionIDLength:    12,
+		HandshakeTimeout:      time.Second * 8,
+		IdleTimeout:           time.Second * 30,
+		MaxIncomingUniStreams: -1,
+		MaxIncomingStreams:    -1,
 	}
 	}
 
 
 	conn, err := wrapSysConn(rawConn, config)
 	conn, err := wrapSysConn(rawConn, config)
@@ -124,23 +200,26 @@ func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsCo
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	s.sessions[dest] = append(sessions, &sessionContext{
+	context := &sessionContext{
 		session: session,
 		session: session,
 		rawConn: conn,
 		rawConn: conn,
-	})
-	stream, err := session.OpenStream()
-	if err != nil {
-		return nil, err
+		done:    done.New(),
 	}
 	}
-	return &interConn{
-		stream: stream,
-		local:  session.LocalAddr(),
-		remote: destAddr,
-	}, nil
+	s.sessions[dest] = append(sessions, context)
+	return context.openStream(destAddr)
 }
 }
 
 
 var client clientSessions
 var client clientSessions
 
 
+func init() {
+	client.sessions = make(map[net.Destination][]*sessionContext)
+	client.cleanup = &task.Periodic{
+		Interval: time.Minute,
+		Execute:  client.cleanSessions,
+	}
+	common.Must(client.cleanup.Start())
+}
+
 func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
 func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
 	tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
 	tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
 	if tlsConfig == nil {
 	if tlsConfig == nil {
@@ -150,9 +229,18 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 		}
 		}
 	}
 	}
 
 
-	destAddr, err := net.ResolveUDPAddr("udp", dest.NetAddr())
-	if err != nil {
-		return nil, err
+	var destAddr *net.UDPAddr
+	if dest.Address.Family().IsIP() {
+		destAddr = &net.UDPAddr{
+			IP:   dest.Address.IP(),
+			Port: int(dest.Port),
+		}
+	} else {
+		addr, err := net.ResolveUDPAddr("udp", dest.NetAddr())
+		if err != nil {
+			return nil, err
+		}
+		destAddr = addr
 	}
 	}
 
 
 	config := streamSettings.ProtocolSettings.(*Config)
 	config := streamSettings.ProtocolSettings.(*Config)

+ 6 - 7
transport/internet/quic/hub.go

@@ -40,6 +40,7 @@ func (l *Listener) acceptStreams(session quic.Session) {
 
 
 		conn := &interConn{
 		conn := &interConn{
 			stream: stream,
 			stream: stream,
+			done:   done.New(),
 			local:  session.LocalAddr(),
 			local:  session.LocalAddr(),
 			remote: session.RemoteAddr(),
 			remote: session.RemoteAddr(),
 		}
 		}
@@ -101,13 +102,11 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti
 	}
 	}
 
 
 	quicConfig := &quic.Config{
 	quicConfig := &quic.Config{
-		ConnectionIDLength:                    8,
-		HandshakeTimeout:                      time.Second * 8,
-		IdleTimeout:                           time.Second * 30,
-		MaxReceiveStreamFlowControlWindow:     128 * 1024,
-		MaxReceiveConnectionFlowControlWindow: 2 * 1024 * 1024,
-		MaxIncomingStreams:                    32,
-		MaxIncomingUniStreams:                 -1,
+		ConnectionIDLength:    12,
+		HandshakeTimeout:      time.Second * 8,
+		IdleTimeout:           time.Second * 30,
+		MaxIncomingStreams:    256,
+		MaxIncomingUniStreams: -1,
 	}
 	}
 
 
 	conn, err := wrapSysConn(rawConn, config)
 	conn, err := wrapSysConn(rawConn, config)