Browse Source

fix connection reading in UDP

Darien Raymond 6 years ago
parent
commit
4e77570f36
5 changed files with 97 additions and 3 deletions
  1. 11 0
      common/buf/multi_buffer.go
  2. 9 2
      common/buf/reader.go
  3. 9 0
      common/net/connection.go
  4. 7 1
      functions.go
  5. 61 0
      functions_test.go

+ 11 - 0
common/buf/multi_buffer.go

@@ -122,6 +122,17 @@ func SplitBytes(mb MultiBuffer, b []byte) (MultiBuffer, int) {
 	return mb, totalBytes
 	return mb, totalBytes
 }
 }
 
 
+// SplitFirstBytes splits the first buffer from MultiBuffer, and then copy its content into the given slice.
+func SplitFirstBytes(mb MultiBuffer, p []byte) (MultiBuffer, int) {
+	mb, b := SplitFirst(mb)
+	if b == nil {
+		return mb, 0
+	}
+	n := copy(p, b.Bytes())
+	b.Release()
+	return mb, n
+}
+
 // Compact returns another MultiBuffer by merging all content of the given one together.
 // Compact returns another MultiBuffer by merging all content of the given one together.
 func Compact(mb MultiBuffer) MultiBuffer {
 func Compact(mb MultiBuffer) MultiBuffer {
 	if len(mb) == 0 {
 	if len(mb) == 0 {

+ 9 - 2
common/buf/reader.go

@@ -58,6 +58,8 @@ type BufferedReader struct {
 	Reader Reader
 	Reader Reader
 	// Buffer is the internal buffer to be read from first
 	// Buffer is the internal buffer to be read from first
 	Buffer MultiBuffer
 	Buffer MultiBuffer
+	// Spliter is a function to read bytes from MultiBuffer
+	Spliter func(MultiBuffer, []byte) (MultiBuffer, int)
 }
 }
 
 
 // BufferedBytes returns the number of bytes that is cached in this reader.
 // BufferedBytes returns the number of bytes that is cached in this reader.
@@ -74,8 +76,13 @@ func (r *BufferedReader) ReadByte() (byte, error) {
 
 
 // Read implements io.Reader. It reads from internal buffer first (if available) and then reads from the underlying reader.
 // Read implements io.Reader. It reads from internal buffer first (if available) and then reads from the underlying reader.
 func (r *BufferedReader) Read(b []byte) (int, error) {
 func (r *BufferedReader) Read(b []byte) (int, error) {
+	spliter := r.Spliter
+	if spliter == nil {
+		spliter = SplitBytes
+	}
+
 	if !r.Buffer.IsEmpty() {
 	if !r.Buffer.IsEmpty() {
-		buffer, nBytes := SplitBytes(r.Buffer, b)
+		buffer, nBytes := spliter(r.Buffer, b)
 		r.Buffer = buffer
 		r.Buffer = buffer
 		if r.Buffer.IsEmpty() {
 		if r.Buffer.IsEmpty() {
 			r.Buffer = nil
 			r.Buffer = nil
@@ -88,7 +95,7 @@ func (r *BufferedReader) Read(b []byte) (int, error) {
 		return 0, err
 		return 0, err
 	}
 	}
 
 
-	mb, nBytes := SplitBytes(mb, b)
+	mb, nBytes := spliter(mb, b)
 	if !mb.IsEmpty() {
 	if !mb.IsEmpty() {
 		r.Buffer = mb
 		r.Buffer = mb
 	}
 	}

+ 9 - 0
common/net/connection.go

@@ -48,6 +48,15 @@ func ConnectionOutputMulti(reader buf.Reader) ConnectionOption {
 	}
 	}
 }
 }
 
 
+func ConnectionOutputMultiUDP(reader buf.Reader) ConnectionOption {
+	return func(c *connection) {
+		c.reader = &buf.BufferedReader{
+			Reader:  reader,
+			Spliter: buf.SplitFirstBytes,
+		}
+	}
+}
+
 func ConnectionOnClose(n io.Closer) ConnectionOption {
 func ConnectionOnClose(n io.Closer) ConnectionOption {
 	return func(c *connection) {
 	return func(c *connection) {
 		c.onClose = n
 		c.onClose = n

+ 7 - 1
functions.go

@@ -53,7 +53,13 @@ func Dial(ctx context.Context, v *Instance, dest net.Destination) (net.Conn, err
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	return net.NewConnection(net.ConnectionInputMulti(r.Writer), net.ConnectionOutputMulti(r.Reader)), nil
+	var readerOpt net.ConnectionOption
+	if dest.Network == net.Network_TCP {
+		readerOpt = net.ConnectionOutputMulti(r.Reader)
+	} else {
+		readerOpt = net.ConnectionOutputMultiUDP(r.Reader)
+	}
+	return net.NewConnection(net.ConnectionInputMulti(r.Writer), readerOpt), nil
 }
 }
 
 
 // DialUDP provides a way to exchange UDP packets through V2Ray instance to remote servers.
 // DialUDP provides a way to exchange UDP packets through V2Ray instance to remote servers.

+ 61 - 0
functions_test.go

@@ -5,6 +5,7 @@ import (
 	"crypto/rand"
 	"crypto/rand"
 	"io"
 	"io"
 	"testing"
 	"testing"
+	"time"
 
 
 	"github.com/golang/protobuf/proto"
 	"github.com/golang/protobuf/proto"
 	"github.com/google/go-cmp/cmp"
 	"github.com/google/go-cmp/cmp"
@@ -86,6 +87,66 @@ func TestV2RayDial(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestV2RayDialUDPConn(t *testing.T) {
+	udpServer := udp.Server{
+		MsgProcessor: xor,
+	}
+	dest, err := udpServer.Start()
+	common.Must(err)
+	defer udpServer.Close()
+
+	config := &core.Config{
+		App: []*serial.TypedMessage{
+			serial.ToTypedMessage(&dispatcher.Config{}),
+			serial.ToTypedMessage(&proxyman.InboundConfig{}),
+			serial.ToTypedMessage(&proxyman.OutboundConfig{}),
+		},
+		Outbound: []*core.OutboundHandlerConfig{
+			{
+				ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
+			},
+		},
+	}
+
+	cfgBytes, err := proto.Marshal(config)
+	common.Must(err)
+
+	server, err := core.StartInstance("protobuf", cfgBytes)
+	common.Must(err)
+	defer server.Close()
+
+	conn, err := core.Dial(context.Background(), server, dest)
+	common.Must(err)
+	defer conn.Close()
+
+	const size = 1024
+	payload := make([]byte, size)
+	common.Must2(rand.Read(payload))
+
+	for i := 0; i < 2; i++ {
+		if _, err := conn.Write(payload); err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	time.Sleep(time.Millisecond * 500)
+
+	receive := make([]byte, size*2)
+	for i := 0; i < 2; i++ {
+		n, err := conn.Read(receive)
+		if err != nil {
+			t.Fatal("expect no error, but got ", err)
+		}
+		if n != size {
+			t.Fatal("expect read size ", size, " but got ", n)
+		}
+
+		if r := cmp.Diff(xor(receive[:n]), payload); r != "" {
+			t.Fatal(r)
+		}
+	}
+}
+
 func TestV2RayDialUDP(t *testing.T) {
 func TestV2RayDialUDP(t *testing.T) {
 	udpServer1 := udp.Server{
 	udpServer1 := udp.Server{
 		MsgProcessor: xor,
 		MsgProcessor: xor,