Darien Raymond 8 лет назад
Родитель
Сommit
5f4acaa9ea

+ 22 - 0
common/buf/io.go

@@ -87,6 +87,17 @@ func NewReader(reader io.Reader) Reader {
 	}
 }
 
+func NewMergingReader(reader io.Reader) Reader {
+	return NewMergingReaderSize(reader, 32*1024)
+}
+
+func NewMergingReaderSize(reader io.Reader, size uint32) Reader {
+	return &BytesToBufferReader{
+		reader: reader,
+		buffer: make([]byte, size),
+	}
+}
+
 // ToBytesReader converts a Reaaer to io.Reader.
 func ToBytesReader(stream Reader) io.Reader {
 	return &bufferToBytesReader{
@@ -107,6 +118,17 @@ func NewWriter(writer io.Writer) Writer {
 	}
 }
 
+func NewMergingWriter(writer io.Writer) Writer {
+	return NewMergingWriterSize(writer, 4096)
+}
+
+func NewMergingWriterSize(writer io.Writer, size uint32) Writer {
+	return &mergingWriter{
+		writer: writer,
+		buffer: make([]byte, size),
+	}
+}
+
 // ToBytesWriter converts a Writer to io.Writer
 func ToBytesWriter(writer Writer) io.Writer {
 	return &bytesToBufferWriter{

+ 17 - 0
common/buf/writer.go

@@ -25,6 +25,23 @@ func (w *writerAdapter) Write(mb MultiBuffer) error {
 	return err
 }
 
+type mergingWriter struct {
+	writer io.Writer
+	buffer []byte
+}
+
+func (w *mergingWriter) Write(mb MultiBuffer) error {
+	defer mb.Release()
+
+	for !mb.IsEmpty() {
+		nBytes, _ := mb.Read(w.buffer)
+		if _, err := w.writer.Write(w.buffer[:nBytes]); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 type bytesToBufferWriter struct {
 	writer Writer
 }

+ 9 - 16
transport/internet/kcp/connection.go

@@ -168,6 +168,10 @@ type SystemConnection interface {
 	Overhead() int
 }
 
+var (
+	_ buf.MultiBufferWriter = (*Connection)(nil)
+)
+
 // Connection is a KCP connection over UDP.
 type Connection struct {
 	conn       SystemConnection
@@ -194,6 +198,8 @@ type Connection struct {
 
 	dataUpdater *Updater
 	pingUpdater *Updater
+
+	mergingWriter buf.Writer
 }
 
 // NewConnection create a new KCP connection between local and remote.
@@ -332,23 +338,10 @@ func (v *Connection) Write(b []byte) (int, error) {
 }
 
 func (c *Connection) WriteMultiBuffer(mb buf.MultiBuffer) (int, error) {
-	defer mb.Release()
-
-	buffer := buf.New()
-	defer buffer.Release()
-
-	totalBytes := 0
-	for !mb.IsEmpty() {
-		buffer.Reset(func(b []byte) (int, error) {
-			return mb.Read(b[:c.mss])
-		})
-		nBytes, err := c.Write(buffer.Bytes())
-		totalBytes += nBytes
-		if err != nil {
-			return totalBytes, err
-		}
+	if c.mergingWriter == nil {
+		c.mergingWriter = buf.NewMergingWriterSize(c, c.mss)
 	}
-	return totalBytes, nil
+	return mb.Len(), c.mergingWriter.Write(mb)
 }
 
 func (v *Connection) SetState(state State) {

+ 2 - 3
transport/internet/tcp/dialer.go

@@ -2,13 +2,12 @@ package tcp
 
 import (
 	"context"
-	"crypto/tls"
 
 	"v2ray.com/core/app/log"
 	"v2ray.com/core/common"
 	v2net "v2ray.com/core/common/net"
 	"v2ray.com/core/transport/internet"
-	v2tls "v2ray.com/core/transport/internet/tls"
+	"v2ray.com/core/transport/internet/tls"
 )
 
 func getTCPSettingsFromContext(ctx context.Context) *Config {
@@ -28,7 +27,7 @@ func Dial(ctx context.Context, dest v2net.Destination) (internet.Connection, err
 		return nil, err
 	}
 	if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
-		tlsConfig, ok := securitySettings.(*v2tls.Config)
+		tlsConfig, ok := securitySettings.(*tls.Config)
 		if ok {
 			config := tlsConfig.GetTLSConfig()
 			if dest.Address.Family().IsDomain() {

+ 4 - 4
transport/internet/tcp/hub.go

@@ -2,7 +2,7 @@ package tcp
 
 import (
 	"context"
-	"crypto/tls"
+	gotls "crypto/tls"
 	"net"
 	"time"
 
@@ -11,7 +11,7 @@ import (
 	v2net "v2ray.com/core/common/net"
 	"v2ray.com/core/common/retry"
 	"v2ray.com/core/transport/internet"
-	v2tls "v2ray.com/core/transport/internet/tls"
+	"v2ray.com/core/transport/internet/tls"
 )
 
 var (
@@ -21,7 +21,7 @@ var (
 type TCPListener struct {
 	ctx        context.Context
 	listener   *net.TCPListener
-	tlsConfig  *tls.Config
+	tlsConfig  *gotls.Config
 	authConfig internet.ConnectionAuthenticator
 	config     *Config
 	conns      chan<- internet.Connection
@@ -46,7 +46,7 @@ func ListenTCP(ctx context.Context, address v2net.Address, port v2net.Port, conn
 		conns:    conns,
 	}
 	if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
-		tlsConfig, ok := securitySettings.(*v2tls.Config)
+		tlsConfig, ok := securitySettings.(*tls.Config)
 		if ok {
 			l.tlsConfig = tlsConfig.GetTLSConfig()
 		}

+ 38 - 0
transport/internet/tls/tls.go

@@ -1,3 +1,41 @@
 package tls
 
+import (
+	"crypto/tls"
+	"net"
+
+	"v2ray.com/core/common/buf"
+)
+
 //go:generate go run $GOPATH/src/v2ray.com/core/tools/generrorgen/main.go -pkg tls -path Transport,Internet,TLS
+
+type conn struct {
+	net.Conn
+
+	mergingReader buf.Reader
+	mergingWriter buf.Writer
+}
+
+func (c *conn) ReadMultiBuffer() (buf.MultiBuffer, error) {
+	if c.mergingReader == nil {
+		c.mergingReader = buf.NewMergingReaderSize(c.Conn, 16*1024)
+	}
+	return c.mergingReader.Read()
+}
+
+func (c *conn) WriteMultiBuffer(mb buf.MultiBuffer) (int, error) {
+	if c.mergingWriter == nil {
+		c.mergingWriter = buf.NewMergingWriter(c.Conn)
+	}
+	return mb.Len(), c.mergingWriter.Write(mb)
+}
+
+func Client(c net.Conn, config *tls.Config) net.Conn {
+	tlsConn := tls.Client(c, config)
+	return &conn{Conn: tlsConn}
+}
+
+func Server(c net.Conn, config *tls.Config) net.Conn {
+	tlsConn := tls.Server(c, config)
+	return &conn{Conn: tlsConn}
+}

+ 20 - 16
transport/internet/websocket/connection.go

@@ -10,11 +10,18 @@ import (
 	"v2ray.com/core/common/errors"
 )
 
+var (
+	_ buf.MultiBufferReader = (*connection)(nil)
+	_ buf.MultiBufferWriter = (*connection)(nil)
+)
+
 // connection is a wrapper for net.Conn over WebSocket connection.
 type connection struct {
-	wsc         *websocket.Conn
-	reader      io.Reader
-	writeBuffer []byte
+	wsc    *websocket.Conn
+	reader io.Reader
+
+	mergingReader buf.Reader
+	mergingWriter buf.Writer
 }
 
 // Read implements net.Conn.Read()
@@ -34,6 +41,13 @@ func (c *connection) Read(b []byte) (int, error) {
 	}
 }
 
+func (c *connection) ReadMultiBuffer() (buf.MultiBuffer, error) {
+	if c.mergingReader == nil {
+		c.mergingReader = buf.NewMergingReader(c)
+	}
+	return c.mergingReader.Read()
+}
+
 func (c *connection) getReader() (io.Reader, error) {
 	if c.reader != nil {
 		return c.reader, nil
@@ -55,20 +69,10 @@ func (c *connection) Write(b []byte) (int, error) {
 }
 
 func (c *connection) WriteMultiBuffer(mb buf.MultiBuffer) (int, error) {
-	defer mb.Release()
-
-	if c.writeBuffer == nil {
-		c.writeBuffer = make([]byte, 4096)
-	}
-	totalBytes := 0
-	for !mb.IsEmpty() {
-		nBytes, _ := mb.Read(c.writeBuffer)
-		totalBytes += nBytes
-		if _, err := c.Write(c.writeBuffer[:nBytes]); err != nil {
-			return totalBytes, err
-		}
+	if c.mergingWriter == nil {
+		c.mergingWriter = buf.NewMergingWriter(c)
 	}
-	return totalBytes, nil
+	return mb.Len(), c.mergingWriter.Write(mb)
 }
 
 func (c *connection) Close() error {