Explorar o código

reuse buffered writer in auth writer

Darien Raymond %!s(int64=8) %!d(string=hai) anos
pai
achega
ade88fd5c7

+ 17 - 14
common/buf/buffered_writer.go

@@ -11,10 +11,14 @@ type BufferedWriter struct {
 }
 
 // NewBufferedWriter creates a new BufferedWriter.
-func NewBufferedWriter(rawWriter io.Writer) *BufferedWriter {
+func NewBufferedWriter(writer io.Writer) *BufferedWriter {
+	return NewBufferedWriterSize(writer, 1024)
+}
+
+func NewBufferedWriterSize(writer io.Writer, size uint32) *BufferedWriter {
 	return &BufferedWriter{
-		writer:   rawWriter,
-		buffer:   NewLocal(1024),
+		writer:   writer,
+		buffer:   NewLocal(int(size)),
 		buffered: true,
 	}
 }
@@ -24,21 +28,20 @@ func (w *BufferedWriter) Write(b []byte) (int, error) {
 	if !w.buffered || w.buffer == nil {
 		return w.writer.Write(b)
 	}
-	nBytes, err := w.buffer.Write(b)
-	if err != nil {
-		return 0, err
-	}
-	if w.buffer.IsFull() {
-		if err := w.Flush(); err != nil {
-			return 0, err
+	bytesWritten := 0
+	for bytesWritten < len(b) {
+		nBytes, err := w.buffer.Write(b[bytesWritten:])
+		if err != nil {
+			return bytesWritten, err
 		}
-		if nBytes < len(b) {
-			if _, err := w.writer.Write(b[nBytes:]); err != nil {
-				return nBytes, err
+		bytesWritten += nBytes
+		if w.buffer.IsFull() {
+			if err := w.Flush(); err != nil {
+				return bytesWritten, err
 			}
 		}
 	}
-	return len(b), nil
+	return bytesWritten, nil
 }
 
 // Flush writes all buffered content into underlying writer, if any.

+ 1 - 0
common/buf/buffered_writer_test.go

@@ -47,6 +47,7 @@ func TestBufferedWriterLargePayload(t *testing.T) {
 
 	nBytes, err = writer.Write(payload[512:])
 	assert.Error(err).IsNil()
+	assert.Error(writer.Flush()).IsNil()
 	assert.Int(nBytes).Equals(64*1024 - 512)
 	assert.Bytes(content.Bytes()).Equals(payload)
 }

+ 30 - 40
common/crypto/auth.go

@@ -4,6 +4,7 @@ import (
 	"crypto/cipher"
 	"io"
 
+	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/protocol"
 )
@@ -123,7 +124,12 @@ func (r *AuthenticationReader) readChunk(waitForData bool) ([]byte, error) {
 		if !waitForData {
 			return nil, io.ErrNoProgress
 		}
-		r.buffer.Reset(buf.ReadFrom(r.buffer))
+
+		if r.buffer.IsEmpty() {
+			r.buffer.Clear()
+		} else {
+			common.Must(r.buffer.Reset(buf.ReadFrom(r.buffer)))
+		}
 
 		delta := r.size - r.buffer.Len()
 		if err := r.buffer.AppendSupplier(buf.ReadAtLeastFrom(r.reader, delta)); err != nil {
@@ -184,42 +190,39 @@ func (r *AuthenticationReader) Read() (buf.MultiBuffer, error) {
 
 type AuthenticationWriter struct {
 	auth         Authenticator
+	buffer       []byte
 	payload      []byte
-	buffer       *buf.Buffer
-	writer       io.Writer
+	writer       *buf.BufferedWriter
 	sizeParser   ChunkSizeEncoder
 	transferType protocol.TransferType
 }
 
 func NewAuthenticationWriter(auth Authenticator, sizeParser ChunkSizeEncoder, writer io.Writer, transferType protocol.TransferType) *AuthenticationWriter {
+	const payloadSize = 1024
 	return &AuthenticationWriter{
 		auth:         auth,
-		payload:      make([]byte, 1024),
-		buffer:       buf.NewLocal(readerBufferSize),
-		writer:       writer,
+		buffer:       make([]byte, payloadSize+sizeParser.SizeBytes()+auth.Overhead()),
+		payload:      make([]byte, payloadSize),
+		writer:       buf.NewBufferedWriterSize(writer, readerBufferSize),
 		sizeParser:   sizeParser,
 		transferType: transferType,
 	}
 }
 
-func (w *AuthenticationWriter) append(b []byte) {
+func (w *AuthenticationWriter) append(b []byte) error {
 	encryptedSize := len(b) + w.auth.Overhead()
+	buffer := w.sizeParser.Encode(uint16(encryptedSize), w.buffer[:0])
 
-	w.buffer.AppendSupplier(func(bb []byte) (int, error) {
-		w.sizeParser.Encode(uint16(encryptedSize), bb[:0])
-		return w.sizeParser.SizeBytes(), nil
-	})
+	buffer, err := w.auth.Seal(buffer, b)
+	if err != nil {
+		return err
+	}
 
-	w.buffer.AppendSupplier(func(bb []byte) (int, error) {
-		w.auth.Seal(bb[:0], b)
-		return encryptedSize, nil
-	})
-}
+	if _, err := w.writer.Write(buffer); err != nil {
+		return err
+	}
 
-func (w *AuthenticationWriter) flush() error {
-	_, err := w.writer.Write(w.buffer.Bytes())
-	w.buffer.Clear()
-	return err
+	return nil
 }
 
 func (w *AuthenticationWriter) writeStream(mb buf.MultiBuffer) error {
@@ -227,21 +230,15 @@ func (w *AuthenticationWriter) writeStream(mb buf.MultiBuffer) error {
 
 	for {
 		n, _ := mb.Read(w.payload)
-		w.append(w.payload[:n])
-		if w.buffer.Len() > readerBufferSize-2*1024 {
-			if err := w.flush(); err != nil {
-				return err
-			}
+		if err := w.append(w.payload[:n]); err != nil {
+			return err
 		}
 		if mb.IsEmpty() {
 			break
 		}
 	}
 
-	if !w.buffer.IsEmpty() {
-		return w.flush()
-	}
-	return nil
+	return w.writer.Flush()
 }
 
 func (w *AuthenticationWriter) writePacket(mb buf.MultiBuffer) error {
@@ -252,24 +249,17 @@ func (w *AuthenticationWriter) writePacket(mb buf.MultiBuffer) error {
 		if b == nil {
 			b = buf.New()
 		}
-		if w.buffer.Len() > readerBufferSize-b.Len()-128 {
-			if err := w.flush(); err != nil {
-				b.Release()
-				return err
-			}
+		if err := w.append(b.Bytes()); err != nil {
+			b.Release()
+			return err
 		}
-		w.append(b.Bytes())
 		b.Release()
 		if mb.IsEmpty() {
 			break
 		}
 	}
 
-	if !w.buffer.IsEmpty() {
-		return w.flush()
-	}
-
-	return nil
+	return w.writer.Flush()
 }
 
 func (w *AuthenticationWriter) Write(mb buf.MultiBuffer) error {

+ 2 - 2
proxy/vmess/inbound/inbound.go

@@ -181,7 +181,7 @@ func (v *Handler) Process(ctx context.Context, network net.Network, connection i
 	if err != nil {
 		if errors.Cause(err) != io.EOF {
 			log.Access(connection.RemoteAddr(), "", log.AccessRejected, err)
-			log.Trace(newError("invalid request from ", connection.RemoteAddr(), ": ", err))
+			log.Trace(newError("invalid request from ", connection.RemoteAddr(), ": ", err).AtInfo())
 		}
 		return err
 	}
@@ -194,7 +194,7 @@ func (v *Handler) Process(ctx context.Context, network net.Network, connection i
 	log.Access(connection.RemoteAddr(), request.Destination(), log.AccessAccepted, "")
 	log.Trace(newError("received request for ", request.Destination()))
 
-	connection.SetReadDeadline(time.Time{})
+	common.Must(connection.SetReadDeadline(time.Time{}))
 
 	userSettings := request.User.GetSettings()
 

+ 13 - 0
testing/scenarios/dokodemo_test.go

@@ -5,6 +5,7 @@ import (
 	"testing"
 
 	"v2ray.com/core"
+	"v2ray.com/core/app/log"
 	"v2ray.com/core/app/proxyman"
 	v2net "v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
@@ -55,6 +56,12 @@ func TestDokodemoTCP(t *testing.T) {
 				ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
 			},
 		},
+		App: []*serial.TypedMessage{
+			serial.ToTypedMessage(&log.Config{
+				ErrorLogLevel: log.LogLevel_Debug,
+				ErrorLogType:  log.LogType_Console,
+			}),
+		},
 	}
 
 	clientPort := uint32(pickPort())
@@ -94,6 +101,12 @@ func TestDokodemoTCP(t *testing.T) {
 				}),
 			},
 		},
+		App: []*serial.TypedMessage{
+			serial.ToTypedMessage(&log.Config{
+				ErrorLogLevel: log.LogLevel_Debug,
+				ErrorLogType:  log.LogType_Console,
+			}),
+		},
 	}
 
 	servers, err := InitializeServerConfigs(serverConfig, clientConfig)