Ver código fonte

clean udp writer

Darien Raymond 8 anos atrás
pai
commit
498c7dafdf

+ 6 - 0
common/buf/io.go

@@ -128,6 +128,12 @@ func NewMergingWriterSize(writer io.Writer, size uint32) Writer {
 	}
 }
 
+func NewSequentialWriter(writer io.Writer) Writer {
+	return &seqWriter{
+		writer: writer,
+	}
+}
+
 // ToBytesWriter converts a Writer to io.Writer
 func ToBytesWriter(writer Writer) io.Writer {
 	return &bytesToBufferWriter{

+ 19 - 0
common/buf/writer.go

@@ -42,6 +42,25 @@ func (w *mergingWriter) Write(mb MultiBuffer) error {
 	return nil
 }
 
+type seqWriter struct {
+	writer io.Writer
+}
+
+func (w *seqWriter) Write(mb MultiBuffer) error {
+	defer mb.Release()
+
+	for _, b := range mb {
+		if b.IsEmpty() {
+			continue
+		}
+		if _, err := w.writer.Write(b.Bytes()); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
 type bytesToBufferWriter struct {
 	writer Writer
 }

+ 1 - 18
proxy/freedom/freedom.go

@@ -4,7 +4,6 @@ package freedom
 
 import (
 	"context"
-	"io"
 	"runtime"
 	"time"
 
@@ -117,7 +116,7 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
 		if destination.Network == net.Network_TCP {
 			writer = buf.NewWriter(conn)
 		} else {
-			writer = &seqWriter{writer: conn}
+			writer = buf.NewSequentialWriter(conn)
 		}
 		if err := buf.Copy(timer, input, writer); err != nil {
 			return newError("failed to process request").Base(err)
@@ -151,19 +150,3 @@ func init() {
 		return New(ctx, config.(*Config))
 	}))
 }
-
-type seqWriter struct {
-	writer io.Writer
-}
-
-func (w *seqWriter) Write(mb buf.MultiBuffer) error {
-	defer mb.Release()
-
-	for _, b := range mb {
-		if _, err := w.writer.Write(b.Bytes()); err != nil {
-			return err
-		}
-	}
-
-	return nil
-}

+ 2 - 2
proxy/shadowsocks/client.go

@@ -135,10 +135,10 @@ func (v *Client) Process(ctx context.Context, outboundRay ray.OutboundRay, diale
 
 	if request.Command == protocol.RequestCommandUDP {
 
-		writer := &UDPWriter{
+		writer := buf.NewSequentialWriter(&UDPWriter{
 			Writer:  conn,
 			Request: request,
-		}
+		})
 
 		requestDone := signal.ExecuteAsync(func() error {
 			if err := buf.Copy(timer, outboundRay.OutboundInput(), writer); err != nil {

+ 8 - 19
proxy/shadowsocks/protocol.go

@@ -238,7 +238,7 @@ func WriteTCPResponse(request *protocol.RequestHeader, writer io.Writer) (buf.Wr
 	return buf.NewWriter(crypto.NewCryptionWriter(stream, writer)), nil
 }
 
-func EncodeUDPPacket(request *protocol.RequestHeader, payload *buf.Buffer) (*buf.Buffer, error) {
+func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buffer, error) {
 	user := request.User
 	rawAccount, err := user.GetTypedAccount()
 	if err != nil {
@@ -266,7 +266,7 @@ func EncodeUDPPacket(request *protocol.RequestHeader, payload *buf.Buffer) (*buf
 	}
 
 	buffer.AppendSupplier(serial.WriteUint16(uint16(request.Port)))
-	buffer.Append(payload.Bytes())
+	buffer.Append(payload)
 
 	if request.Option.Has(RequestOptionOneTimeAuth) {
 		authenticator := NewAuthenticator(HeaderKeyGenerator(account.Key, iv))
@@ -382,23 +382,12 @@ type UDPWriter struct {
 	Request *protocol.RequestHeader
 }
 
-func (w *UDPWriter) Write(mb buf.MultiBuffer) error {
-	defer mb.Release()
-
-	for _, b := range mb {
-		if err := w.writeInternal(b); err != nil {
-			return err
-		}
-	}
-	return nil
-}
-
-func (w *UDPWriter) writeInternal(buffer *buf.Buffer) error {
-	payload, err := EncodeUDPPacket(w.Request, buffer)
+func (w *UDPWriter) Write(payload []byte) (int, error) {
+	packet, err := EncodeUDPPacket(w.Request, payload)
 	if err != nil {
-		return err
+		return 0, err
 	}
-	_, err = w.Writer.Write(payload.Bytes())
-	payload.Release()
-	return err
+	_, err = w.Writer.Write(packet.Bytes())
+	packet.Release()
+	return len(payload), err
 }

+ 3 - 3
proxy/shadowsocks/protocol_test.go

@@ -31,7 +31,7 @@ func TestUDPEncoding(t *testing.T) {
 
 	data := buf.NewLocal(256)
 	data.AppendSupplier(serial.WriteString("test string"))
-	encodedData, err := EncodeUDPPacket(request, data)
+	encodedData, err := EncodeUDPPacket(request, data.Bytes())
 	assert.Error(err).IsNil()
 
 	decodedRequest, decodedData, err := DecodeUDPPacket(request.User, encodedData)
@@ -88,7 +88,7 @@ func TestUDPReaderWriter(t *testing.T) {
 		}),
 	}
 	cache := buf.New()
-	writer := &UDPWriter{
+	writer := buf.NewSequentialWriter(&UDPWriter{
 		Writer: cache,
 		Request: &protocol.RequestHeader{
 			Version: Version,
@@ -97,7 +97,7 @@ func TestUDPReaderWriter(t *testing.T) {
 			User:    user,
 			Option:  RequestOptionOneTimeAuth,
 		},
-	}
+	})
 
 	reader := &UDPReader{
 		Reader: cache,

+ 1 - 1
proxy/shadowsocks/server.go

@@ -113,7 +113,7 @@ func (v *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
 			udpServer.Dispatch(ctx, dest, data, func(payload *buf.Buffer) {
 				defer payload.Release()
 
-				data, err := EncodeUDPPacket(request, payload)
+				data, err := EncodeUDPPacket(request, payload.Bytes())
 				if err != nil {
 					log.Trace(newError("failed to encode UDP packet").Base(err).AtWarning())
 					return

+ 1 - 1
proxy/socks/client.go

@@ -103,7 +103,7 @@ func (c *Client) Process(ctx context.Context, ray ray.OutboundRay, dialer proxy.
 		}
 		defer udpConn.Close()
 		requestFunc = func() error {
-			return buf.Copy(timer, ray.OutboundInput(), &UDPWriter{request: request, writer: udpConn})
+			return buf.Copy(timer, ray.OutboundInput(), buf.NewSequentialWriter(NewUDPWriter(request, udpConn)))
 		}
 		responseFunc = func() error {
 			defer ray.OutboundOutput().Close()

+ 6 - 10
proxy/socks/protocol.go

@@ -369,17 +369,13 @@ func NewUDPWriter(request *protocol.RequestHeader, writer io.Writer) *UDPWriter
 	}
 }
 
-func (w *UDPWriter) Write(mb buf.MultiBuffer) error {
-	defer mb.Release()
-
-	for _, b := range mb {
-		eb := EncodeUDPPacket(w.request, b.Bytes())
-		defer eb.Release()
-		if _, err := w.writer.Write(eb.Bytes()); err != nil {
-			return err
-		}
+func (w *UDPWriter) Write(b []byte) (int, error) {
+	eb := EncodeUDPPacket(w.request, b)
+	defer eb.Release()
+	if _, err := w.writer.Write(eb.Bytes()); err != nil {
+		return 0, err
 	}
-	return nil
+	return len(b), nil
 }
 
 func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer io.Writer) (*protocol.RequestHeader, error) {

+ 1 - 1
proxy/socks/protocol_test.go

@@ -19,7 +19,7 @@ func TestUDPEncoding(t *testing.T) {
 		Address: net.IPAddress([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}),
 		Port:    1024,
 	}
-	writer := NewUDPWriter(request, b)
+	writer := buf.NewSequentialWriter(NewUDPWriter(request, b))
 
 	content := []byte{'a'}
 	payload := buf.New()