Darien Raymond преди 8 години
родител
ревизия
bf7b8798a9
променени са 4 файла, в които са добавени 60 реда и са изтрити 188 реда
  1. 15 13
      transport/internet/kcp/connection.go
  2. 15 39
      transport/internet/kcp/connection_test.go
  3. 22 77
      transport/internet/kcp/dialer.go
  4. 8 59
      transport/internet/kcp/listener.go

+ 15 - 13
transport/internet/kcp/connection.go

@@ -168,13 +168,15 @@ type SystemConnection interface {
 	Overhead() int
 }
 
-var (
-	_ buf.Reader = (*Connection)(nil)
-)
+type ConnMetadata struct {
+	LocalAddr  net.Addr
+	RemoteAddr net.Addr
+}
 
 // Connection is a KCP connection over UDP.
 type Connection struct {
-	conn       SystemConnection
+	meta       *ConnMetadata
+	closer     io.Closer
 	rd         time.Time
 	wd         time.Time // write deadline
 	since      int64
@@ -201,24 +203,24 @@ type Connection struct {
 }
 
 // NewConnection create a new KCP connection between local and remote.
-func NewConnection(conv uint16, sysConn SystemConnection, config *Config) *Connection {
+func NewConnection(conv uint16, meta *ConnMetadata, writer PacketWriter, closer io.Closer, config *Config) *Connection {
 	log.Trace(newError("creating connection ", conv))
 
 	conn := &Connection{
 		conv:       conv,
-		conn:       sysConn,
+		meta:       meta,
+		closer:     closer,
 		since:      nowMillisec(),
 		dataInput:  make(chan bool, 1),
 		dataOutput: make(chan bool, 1),
 		Config:     config,
-		output:     NewRetryableWriter(NewSegmentWriter(sysConn)),
-		mss:        config.GetMTUValue() - uint32(sysConn.Overhead()) - DataSegmentOverhead,
+		output:     NewRetryableWriter(NewSegmentWriter(writer)),
+		mss:        config.GetMTUValue() - uint32(writer.Overhead()) - DataSegmentOverhead,
 		roundTrip: &RoundTripInfo{
 			rto:    100,
 			minRtt: config.GetTTIValue(),
 		},
 	}
-	sysConn.Reset(conn.Input)
 
 	conn.receivingWorker = NewReceivingWorker(conn)
 	conn.sendingWorker = NewSendingWorker(conn)
@@ -413,7 +415,7 @@ func (v *Connection) Close() error {
 	if state.Is(StateReadyToClose, StateTerminating, StateTerminated) {
 		return ErrClosedConnection
 	}
-	log.Trace(newError("closing connection to ", v.conn.RemoteAddr()))
+	log.Trace(newError("closing connection to ", v.meta.RemoteAddr))
 
 	if state == StateActive {
 		v.SetState(StateReadyToClose)
@@ -433,7 +435,7 @@ func (v *Connection) LocalAddr() net.Addr {
 	if v == nil {
 		return nil
 	}
-	return v.conn.LocalAddr()
+	return v.meta.LocalAddr
 }
 
 // RemoteAddr returns the remote network address. The Addr returned is shared by all invocations of RemoteAddr, so do not modify it.
@@ -441,7 +443,7 @@ func (v *Connection) RemoteAddr() net.Addr {
 	if v == nil {
 		return nil
 	}
-	return v.conn.RemoteAddr()
+	return v.meta.RemoteAddr
 }
 
 // SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline.
@@ -488,7 +490,7 @@ func (v *Connection) Terminate() {
 	v.OnDataInput()
 	v.OnDataOutput()
 
-	v.conn.Close()
+	v.closer.Close()
 	v.sendingWorker.Release()
 	v.receivingWorker.Release()
 }

+ 15 - 39
transport/internet/kcp/connection_test.go

@@ -1,59 +1,27 @@
 package kcp_test
 
 import (
-	"net"
+	"io"
 	"testing"
 	"time"
 
+	"v2ray.com/core/common/buf"
 	. "v2ray.com/core/transport/internet/kcp"
 	. "v2ray.com/ext/assert"
 )
 
-type NoOpConn struct{}
+type NoOpCloser int
 
-func (o *NoOpConn) Overhead() int {
-	return 0
-}
-
-// Write implements io.Writer.
-func (o *NoOpConn) Write(b []byte) (int, error) {
-	return len(b), nil
-}
-
-func (o *NoOpConn) Close() error {
-	return nil
-}
-
-func (o *NoOpConn) Read([]byte) (int, error) {
-	panic("Should not be called.")
-}
-
-func (o *NoOpConn) LocalAddr() net.Addr {
-	return nil
-}
-
-func (o *NoOpConn) RemoteAddr() net.Addr {
-	return nil
-}
-
-func (o *NoOpConn) SetDeadline(time.Time) error {
+func (NoOpCloser) Close() error {
 	return nil
 }
 
-func (o *NoOpConn) SetReadDeadline(time.Time) error {
-	return nil
-}
-
-func (o *NoOpConn) SetWriteDeadline(time.Time) error {
-	return nil
-}
-
-func (o *NoOpConn) Reset(input func([]Segment)) {}
-
 func TestConnectionReadTimeout(t *testing.T) {
 	assert := With(t)
 
-	conn := NewConnection(1, &NoOpConn{}, &Config{})
+	conn := NewConnection(1, &ConnMetadata{}, &KCPPacketWriter{
+		Writer: buf.DiscardBytes,
+	}, NoOpCloser(0), &Config{})
 	conn.SetReadDeadline(time.Now().Add(time.Second))
 
 	b := make([]byte, 1024)
@@ -63,3 +31,11 @@ func TestConnectionReadTimeout(t *testing.T) {
 
 	conn.Terminate()
 }
+
+func TestConnectionInterface(t *testing.T) {
+	assert := With(t)
+
+	assert((*Connection)(nil), Implements, (*io.Writer)(nil))
+	assert((*Connection)(nil), Implements, (*io.Reader)(nil))
+	assert((*Connection)(nil), Implements, (*buf.Reader)(nil))
+}

+ 22 - 77
transport/internet/kcp/dialer.go

@@ -2,9 +2,8 @@ package kcp
 
 import (
 	"context"
-	"crypto/cipher"
 	"crypto/tls"
-	"sync"
+	"io"
 	"sync/atomic"
 
 	"v2ray.com/core/app/log"
@@ -20,84 +19,20 @@ var (
 	globalConv = uint32(dice.RollUint16())
 )
 
-type ClientConnection struct {
-	sync.RWMutex
-	net.Conn
-	input  func([]Segment)
-	reader PacketReader
-	writer PacketWriter
-}
-
-func (c *ClientConnection) Overhead() int {
-	c.RLock()
-	defer c.RUnlock()
-	if c.writer == nil {
-		return 0
-	}
-	return c.writer.Overhead()
-}
-
-// Write implements io.Writer.
-func (c *ClientConnection) Write(b []byte) (int, error) {
-	c.RLock()
-	defer c.RUnlock()
-
-	if c.writer == nil {
-		return len(b), nil
-	}
-
-	return c.writer.Write(b)
-}
-
-func (*ClientConnection) Read([]byte) (int, error) {
-	panic("KCP|ClientConnection: Read should not be called.")
-}
-
-func (c *ClientConnection) Close() error {
-	return c.Conn.Close()
-}
-
-func (c *ClientConnection) Reset(inputCallback func([]Segment)) {
-	c.Lock()
-	c.input = inputCallback
-	c.Unlock()
-}
-
-func (c *ClientConnection) ResetSecurity(header internet.PacketHeader, security cipher.AEAD) {
-	c.Lock()
-	if c.reader == nil {
-		c.reader = new(KCPPacketReader)
-	}
-	c.reader.(*KCPPacketReader).Header = header
-	c.reader.(*KCPPacketReader).Security = security
-	if c.writer == nil {
-		c.writer = new(KCPPacketWriter)
-	}
-	c.writer.(*KCPPacketWriter).Header = header
-	c.writer.(*KCPPacketWriter).Security = security
-	c.writer.(*KCPPacketWriter).Writer = c.Conn
-
-	c.Unlock()
-}
-
-func (c *ClientConnection) Run() {
+func fetchInput(ctx context.Context, input io.Reader, reader PacketReader, conn *Connection) {
 	payload := buf.New()
 	defer payload.Release()
 
 	for {
-		err := payload.Reset(buf.ReadFrom(c.Conn))
+		err := payload.Reset(buf.ReadFrom(input))
 		if err != nil {
 			payload.Release()
 			return
 		}
-		c.RLock()
-		if c.input != nil {
-			segments := c.reader.Read(payload.Bytes())
-			if len(segments) > 0 {
-				c.input(segments)
-			}
+		segments := reader.Read(payload.Bytes())
+		if len(segments) > 0 {
+			conn.Input(segments)
 		}
-		c.RUnlock()
 	}
 }
 
@@ -110,10 +45,6 @@ func DialKCP(ctx context.Context, dest net.Destination) (internet.Connection, er
 	if err != nil {
 		return nil, newError("failed to dial to dest: ", err).AtWarning().Base(err)
 	}
-	conn := &ClientConnection{
-		Conn: rawConn,
-	}
-	go conn.Run()
 
 	kcpSettings := internet.TransportSettingsFromContext(ctx).(*Config)
 
@@ -125,9 +56,23 @@ func DialKCP(ctx context.Context, dest net.Destination) (internet.Connection, er
 	if err != nil {
 		return nil, newError("failed to create security").Base(err)
 	}
-	conn.ResetSecurity(header, security)
+	reader := &KCPPacketReader{
+		Header:   header,
+		Security: security,
+	}
+	writer := &KCPPacketWriter{
+		Header:   header,
+		Security: security,
+		Writer:   rawConn,
+	}
+
 	conv := uint16(atomic.AddUint32(&globalConv, 1))
-	session := NewConnection(conv, conn, kcpSettings)
+	session := NewConnection(conv, &ConnMetadata{
+		LocalAddr:  rawConn.LocalAddr(),
+		RemoteAddr: rawConn.RemoteAddr(),
+	}, writer, rawConn, kcpSettings)
+
+	go fetchInput(ctx, rawConn, reader, session)
 
 	var iConn internet.Connection = session
 

+ 8 - 59
transport/internet/kcp/listener.go

@@ -4,9 +4,7 @@ import (
 	"context"
 	"crypto/cipher"
 	"crypto/tls"
-	"io"
 	"sync"
-	"time"
 
 	"v2ray.com/core/app/log"
 	"v2ray.com/core/common"
@@ -23,52 +21,6 @@ type ConnectionID struct {
 	Conv   uint16
 }
 
-type ServerConnection struct {
-	local  net.Addr
-	remote net.Addr
-	writer PacketWriter
-	closer io.Closer
-}
-
-func (c *ServerConnection) Overhead() int {
-	return c.writer.Overhead()
-}
-
-func (*ServerConnection) Read([]byte) (int, error) {
-	panic("KCP|ServerConnection: Read should not be called.")
-}
-
-func (c *ServerConnection) Write(b []byte) (int, error) {
-	return c.writer.Write(b)
-}
-
-func (c *ServerConnection) Close() error {
-	return c.closer.Close()
-}
-
-func (*ServerConnection) Reset(input func([]Segment)) {
-}
-
-func (c *ServerConnection) LocalAddr() net.Addr {
-	return c.local
-}
-
-func (c *ServerConnection) RemoteAddr() net.Addr {
-	return c.remote
-}
-
-func (*ServerConnection) SetDeadline(time.Time) error {
-	return nil
-}
-
-func (*ServerConnection) SetReadDeadline(time.Time) error {
-	return nil
-}
-
-func (*ServerConnection) SetWriteDeadline(time.Time) error {
-	return nil
-}
-
 // Listener defines a server listening for connections
 type Listener struct {
 	sync.Mutex
@@ -172,17 +124,14 @@ func (v *Listener) OnReceive(payload *buf.Buffer, src net.Destination, originalD
 			Port: int(src.Port),
 		}
 		localAddr := v.hub.Addr()
-		sConn := &ServerConnection{
-			local:  localAddr,
-			remote: remoteAddr,
-			writer: &KCPPacketWriter{
-				Header:   v.header,
-				Writer:   writer,
-				Security: v.security,
-			},
-			closer: writer,
-		}
-		conn = NewConnection(conv, sConn, v.config)
+		conn = NewConnection(conv, &ConnMetadata{
+			LocalAddr:  localAddr,
+			RemoteAddr: remoteAddr,
+		}, &KCPPacketWriter{
+			Header:   v.header,
+			Security: v.security,
+			Writer:   writer,
+		}, writer, v.config)
 		var netConn internet.Connection = conn
 		if v.tlsConfig != nil {
 			tlsConn := tls.Server(conn, v.tlsConfig)