Quellcode durchsuchen

better termination logic

v2ray vor 9 Jahren
Ursprung
Commit
bac9304e05

+ 29 - 0
common/signal/once.go

@@ -0,0 +1,29 @@
+package signal
+
+import (
+	"sync"
+	"sync/atomic"
+)
+
+type Once struct {
+	m    sync.Mutex
+	done uint32
+}
+
+func (o *Once) Do(f func()) {
+	if atomic.LoadUint32(&o.done) == 1 {
+		return
+	}
+	o.m.Lock()
+	defer o.m.Unlock()
+	if o.done == 0 {
+		atomic.StoreUint32(&o.done, 1)
+		f()
+	}
+}
+
+func (o *Once) Reset() {
+	o.m.Lock()
+	defer o.m.Unlock()
+	atomic.StoreUint32(&o.done, 0)
+}

+ 44 - 18
transport/internet/kcp/connection.go

@@ -9,6 +9,7 @@ import (
 
 	"github.com/v2ray/v2ray-core/common/alloc"
 	"github.com/v2ray/v2ray-core/common/log"
+	"github.com/v2ray/v2ray-core/common/signal"
 )
 
 var (
@@ -63,6 +64,7 @@ type Connection struct {
 	chReadEvent   chan struct{}
 	writer        io.WriteCloser
 	since         int64
+	terminateOnce signal.Once
 }
 
 // NewConnection create a new KCP connection between local and remote.
@@ -76,21 +78,7 @@ func NewConnection(conv uint32, writerCloser io.WriteCloser, local *net.UDPAddr,
 	conn.since = nowMillisec()
 
 	mtu := uint32(effectiveConfig.Mtu - block.HeaderSize() - headerSize)
-	conn.kcp = NewKCP(conv, mtu, func(buf []byte, size int) {
-		if size >= IKCP_OVERHEAD {
-			ext := alloc.NewBuffer().Clear().Append(buf[:size])
-			cmd := CommandData
-			opt := Option(0)
-			if conn.state == ConnStateReadyToClose {
-				opt = OptionClose
-			}
-			ext.Prepend([]byte{byte(cmd), byte(opt)})
-			go conn.output(ext)
-		}
-		if conn.state == ConnStateReadyToClose && conn.kcp.WaitSnd() == 0 {
-			go conn.NotifyTermination()
-		}
-	})
+	conn.kcp = NewKCP(conv, mtu, conn.output)
 	conn.kcp.WndSize(effectiveConfig.Sndwnd, effectiveConfig.Rcvwnd)
 	conn.kcp.NoDelay(1, 20, 2, 1)
 	conn.kcp.current = conn.Elapsed()
@@ -199,7 +187,7 @@ func (this *Connection) NotifyTermination() {
 		this.RUnlock()
 		buffer := alloc.NewSmallBuffer().Clear()
 		buffer.AppendBytes(byte(CommandTerminate), byte(OptionClose), byte(0), byte(0), byte(0), byte(0))
-		this.output(buffer)
+		this.outputBuffer(buffer)
 
 		time.Sleep(time.Second)
 
@@ -207,6 +195,19 @@ func (this *Connection) NotifyTermination() {
 	this.Terminate()
 }
 
+func (this *Connection) ForceTimeout() {
+	if this == nil {
+		return
+	}
+	for i := 0; i < 5; i++ {
+		if this.state == ConnStateClosed {
+			return
+		}
+		time.Sleep(time.Minute)
+	}
+	go this.terminateOnce.Do(this.NotifyTermination)
+}
+
 // Close closes the connection.
 func (this *Connection) Close() error {
 	if this == nil || this.state == ConnStateClosed || this.state == ConnStateReadyToClose {
@@ -219,7 +220,9 @@ func (this *Connection) Close() error {
 	if this.state == ConnStateActive {
 		this.state = ConnStateReadyToClose
 		if this.kcp.WaitSnd() == 0 {
-			go this.NotifyTermination()
+			go this.terminateOnce.Do(this.NotifyTermination)
+		} else {
+			go this.ForceTimeout()
 		}
 	}
 
@@ -280,7 +283,7 @@ func (this *Connection) SetWriteDeadline(t time.Time) error {
 	return nil
 }
 
-func (this *Connection) output(payload *alloc.Buffer) {
+func (this *Connection) outputBuffer(payload *alloc.Buffer) {
 	defer payload.Release()
 	if this == nil {
 		return
@@ -296,6 +299,29 @@ func (this *Connection) output(payload *alloc.Buffer) {
 	this.writer.Write(payload.Value)
 }
 
+func (this *Connection) output(payload []byte) {
+	if this == nil || this.state == ConnStateClosed {
+		return
+	}
+
+	if this.state == ConnStateReadyToClose && this.kcp.WaitSnd() == 0 {
+		go this.terminateOnce.Do(this.NotifyTermination)
+	}
+
+	if len(payload) < IKCP_OVERHEAD {
+		return
+	}
+
+	buffer := alloc.NewBuffer().Clear().Append(payload)
+	cmd := CommandData
+	opt := Option(0)
+	if this.state == ConnStateReadyToClose {
+		opt = OptionClose
+	}
+	buffer.Prepend([]byte{byte(cmd), byte(opt)})
+	this.outputBuffer(buffer)
+}
+
 // kcp update, input loop
 func (this *Connection) updateTask() {
 	for this.state != ConnStateClosed {

+ 6 - 6
transport/internet/kcp/kcp.go

@@ -34,7 +34,7 @@ const (
 )
 
 // Output is a closure which captures conn and calls conn.Write
-type Output func(buf []byte, size int)
+type Output func(buf []byte)
 
 /* encode 8 bits unsigned int */
 func ikcp_encode8u(p []byte, c byte) []byte {
@@ -573,7 +573,7 @@ func (kcp *KCP) flush() {
 	for i := 0; i < count; i++ {
 		size := len(buffer) - len(ptr)
 		if size+IKCP_OVERHEAD > int(kcp.mtu) {
-			kcp.output(buffer, size)
+			kcp.output(buffer[:size])
 			ptr = buffer
 		}
 		seg.sn, seg.ts = kcp.ack_get(i)
@@ -609,7 +609,7 @@ func (kcp *KCP) flush() {
 		seg.cmd = IKCP_CMD_WASK
 		size := len(buffer) - len(ptr)
 		if size+IKCP_OVERHEAD > int(kcp.mtu) {
-			kcp.output(buffer, size)
+			kcp.output(buffer[:size])
 			ptr = buffer
 		}
 		ptr = seg.encode(ptr)
@@ -620,7 +620,7 @@ func (kcp *KCP) flush() {
 		seg.cmd = IKCP_CMD_WINS
 		size := len(buffer) - len(ptr)
 		if size+IKCP_OVERHEAD > int(kcp.mtu) {
-			kcp.output(buffer, size)
+			kcp.output(buffer[:size])
 			ptr = buffer
 		}
 		ptr = seg.encode(ptr)
@@ -703,7 +703,7 @@ func (kcp *KCP) flush() {
 			need := IKCP_OVERHEAD + len(segment.data)
 
 			if size+need >= int(kcp.mtu) {
-				kcp.output(buffer, size)
+				kcp.output(buffer[:size])
 				ptr = buffer
 			}
 
@@ -720,7 +720,7 @@ func (kcp *KCP) flush() {
 	// flash remain segments
 	size := len(buffer) - len(ptr)
 	if size > 0 {
-		kcp.output(buffer, size)
+		kcp.output(buffer[:size])
 	}
 
 	// update ssthresh

+ 3 - 25
transport/internet/tcp/connection_cache.go

@@ -3,32 +3,10 @@ package tcp
 import (
 	"net"
 	"sync"
-	"sync/atomic"
 	"time"
-)
-
-type Once struct {
-	m    sync.Mutex
-	done uint32
-}
 
-func (o *Once) Do(f func()) {
-	if atomic.LoadUint32(&o.done) == 1 {
-		return
-	}
-	o.m.Lock()
-	defer o.m.Unlock()
-	if o.done == 0 {
-		atomic.StoreUint32(&o.done, 1)
-		f()
-	}
-}
-
-func (o *Once) Reset() {
-	o.m.Lock()
-	defer o.m.Unlock()
-	atomic.StoreUint32(&o.done, 0)
-}
+	"github.com/v2ray/v2ray-core/common/signal"
+)
 
 type AwaitingConnection struct {
 	conn   net.Conn
@@ -42,7 +20,7 @@ func (this *AwaitingConnection) Expired() bool {
 type ConnectionCache struct {
 	sync.Mutex
 	cache       map[string][]*AwaitingConnection
-	cleanupOnce Once
+	cleanupOnce signal.Once
 }
 
 func NewConnectionCache() *ConnectionCache {