Browse Source

fix reader/writer for packet conn

Darien Raymond 7 years ago
parent
commit
3340f81d03
2 changed files with 56 additions and 2 deletions
  1. 25 2
      common/buf/io.go
  2. 31 0
      common/buf/reader.go

+ 25 - 2
common/buf/io.go

@@ -2,6 +2,7 @@ package buf
 
 import (
 	"io"
+	"net"
 	"syscall"
 	"time"
 )
@@ -38,6 +39,11 @@ func WriteAllBytes(writer io.Writer, payload []byte) error {
 	return nil
 }
 
+func isPacketReader(reader io.Reader) bool {
+	_, ok := reader.(net.PacketConn)
+	return ok
+}
+
 // NewReader creates a new Reader.
 // The Reader instance doesn't take the ownership of reader.
 func NewReader(reader io.Reader) Reader {
@@ -45,6 +51,12 @@ func NewReader(reader io.Reader) Reader {
 		return mr
 	}
 
+	if isPacketReader(reader) {
+		return &PacketReader{
+			Reader: reader,
+		}
+	}
+
 	if useReadv {
 		if sc, ok := reader.(syscall.Conn); ok {
 			rawConn, err := sc.SyscallConn()
@@ -61,14 +73,25 @@ func NewReader(reader io.Reader) Reader {
 	}
 }
 
+func isPacketWriter(writer io.Writer) bool {
+	if _, ok := writer.(net.PacketConn); ok {
+		return true
+	}
+
+	// If the writer doesn't implement syscall.Conn, it is probably not a TCP connection.
+	if _, ok := writer.(syscall.Conn); !ok {
+		return true
+	}
+	return false
+}
+
 // NewWriter creates a new Writer.
 func NewWriter(writer io.Writer) Writer {
 	if mw, ok := writer.(Writer); ok {
 		return mw
 	}
 
-	if _, ok := writer.(syscall.Conn); !ok {
-		// If the writer doesn't implement syscall.Conn, it is probably not a TCP connection.
+	if isPacketWriter(writer) {
 		return &SequentialWriter{
 			Writer: writer,
 		}

+ 31 - 0
common/buf/reader.go

@@ -7,6 +7,23 @@ import (
 	"v2ray.com/core/common/errors"
 )
 
+func readOneUDP(r io.Reader) (*Buffer, error) {
+	b := New()
+	for i := 0; i < 64; i++ {
+		_, err := b.ReadFrom(r)
+		if !b.IsEmpty() {
+			return b, nil
+		}
+		if err != nil {
+			b.Release()
+			return nil, err
+		}
+	}
+
+	b.Release()
+	return nil, newError("Reader returns too many empty payloads.")
+}
+
 func readOne(r io.Reader) (*Buffer, error) {
 	// Use an one-byte buffer to wait for incoming payload.
 	var firstByte [1]byte
@@ -152,3 +169,17 @@ func (r *SingleReader) ReadMultiBuffer() (MultiBuffer, error) {
 	}
 	return MultiBuffer{b}, nil
 }
+
+// PacketReader is a Reader that read one Buffer every time.
+type PacketReader struct {
+	io.Reader
+}
+
+// ReadMultiBuffer implements Reader.
+func (r *PacketReader) ReadMultiBuffer() (MultiBuffer, error) {
+	b, err := readOneUDP(r.Reader)
+	if err != nil {
+		return nil, err
+	}
+	return MultiBuffer{b}, nil
+}