Browse Source

simplify metadata reader

Darien Raymond 8 years ago
parent
commit
27c099dd37
3 changed files with 20 additions and 32 deletions
  1. 4 6
      app/proxyman/mux/mux.go
  2. 9 10
      app/proxyman/mux/mux_test.go
  3. 7 16
      app/proxyman/mux/reader.go

+ 4 - 6
app/proxyman/mux/mux.go

@@ -235,10 +235,9 @@ func (m *Client) fetchOutput() {
 	defer m.cancel()
 
 	reader := buf.ToBytesReader(m.inboundRay.InboundOutput())
-	metaReader := NewMetadataReader(reader)
 
 	for {
-		meta, err := metaReader.Read()
+		meta, err := ReadMetadata(reader)
 		if err != nil {
 			if errors.Cause(err) != io.EOF {
 				log.Trace(newError("failed to read metadata").Base(err))
@@ -370,8 +369,8 @@ func (w *ServerWorker) handleStatusEnd(meta *FrameMetadata, reader io.Reader) er
 	return nil
 }
 
-func (w *ServerWorker) handleFrame(ctx context.Context, reader io.Reader, metaReader *MetadataReader) error {
-	meta, err := metaReader.Read()
+func (w *ServerWorker) handleFrame(ctx context.Context, reader io.Reader) error {
+	meta, err := ReadMetadata(reader)
 	if err != nil {
 		return newError("failed to read metadata").Base(err)
 	}
@@ -398,7 +397,6 @@ func (w *ServerWorker) handleFrame(ctx context.Context, reader io.Reader, metaRe
 func (w *ServerWorker) run(ctx context.Context) {
 	input := w.outboundRay.OutboundInput()
 	reader := buf.ToBytesReader(input)
-	metaReader := NewMetadataReader(reader)
 
 	defer w.sessionManager.Close()
 
@@ -407,7 +405,7 @@ func (w *ServerWorker) run(ctx context.Context) {
 		case <-ctx.Done():
 			return
 		default:
-			err := w.handleFrame(ctx, reader, metaReader)
+			err := w.handleFrame(ctx, reader)
 			if err != nil {
 				if errors.Cause(err) != io.EOF {
 					log.Trace(newError("unexpected EOF").Base(err))

+ 9 - 10
app/proxyman/mux/mux_test.go

@@ -61,10 +61,9 @@ func TestReaderWriter(t *testing.T) {
 	writer2.Close()
 
 	bytesReader := buf.ToBytesReader(stream)
-	metaReader := NewMetadataReader(bytesReader)
 	streamReader := NewStreamReader(bytesReader)
 
-	meta, err := metaReader.Read()
+	meta, err := ReadMetadata(bytesReader)
 	assert(err, IsNil)
 	assert(meta.SessionID, Equals, uint16(1))
 	assert(byte(meta.SessionStatus), Equals, byte(SessionStatusNew))
@@ -76,14 +75,14 @@ func TestReaderWriter(t *testing.T) {
 	assert(len(data), Equals, 1)
 	assert(data[0].String(), Equals, "abcd")
 
-	meta, err = metaReader.Read()
+	meta, err = ReadMetadata(bytesReader)
 	assert(err, IsNil)
 	assert(byte(meta.SessionStatus), Equals, byte(SessionStatusNew))
 	assert(meta.SessionID, Equals, uint16(2))
 	assert(byte(meta.Option), Equals, byte(0))
 	assert(meta.Target, Equals, dest2)
 
-	meta, err = metaReader.Read()
+	meta, err = ReadMetadata(bytesReader)
 	assert(err, IsNil)
 	assert(byte(meta.SessionStatus), Equals, byte(SessionStatusKeep))
 	assert(meta.SessionID, Equals, uint16(1))
@@ -94,7 +93,7 @@ func TestReaderWriter(t *testing.T) {
 	assert(len(data), Equals, 1)
 	assert(data[0].String(), Equals, "efgh")
 
-	meta, err = metaReader.Read()
+	meta, err = ReadMetadata(bytesReader)
 	assert(err, IsNil)
 	assert(byte(meta.SessionStatus), Equals, byte(SessionStatusNew))
 	assert(meta.SessionID, Equals, uint16(3))
@@ -106,19 +105,19 @@ func TestReaderWriter(t *testing.T) {
 	assert(len(data), Equals, 1)
 	assert(data[0].String(), Equals, "x")
 
-	meta, err = metaReader.Read()
+	meta, err = ReadMetadata(bytesReader)
 	assert(err, IsNil)
 	assert(byte(meta.SessionStatus), Equals, byte(SessionStatusEnd))
 	assert(meta.SessionID, Equals, uint16(1))
 	assert(byte(meta.Option), Equals, byte(0))
 
-	meta, err = metaReader.Read()
+	meta, err = ReadMetadata(bytesReader)
 	assert(err, IsNil)
 	assert(byte(meta.SessionStatus), Equals, byte(SessionStatusEnd))
 	assert(meta.SessionID, Equals, uint16(3))
 	assert(byte(meta.Option), Equals, byte(0))
 
-	meta, err = metaReader.Read()
+	meta, err = ReadMetadata(bytesReader)
 	assert(err, IsNil)
 	assert(byte(meta.SessionStatus), Equals, byte(SessionStatusKeep))
 	assert(meta.SessionID, Equals, uint16(2))
@@ -129,7 +128,7 @@ func TestReaderWriter(t *testing.T) {
 	assert(len(data), Equals, 1)
 	assert(data[0].String(), Equals, "y")
 
-	meta, err = metaReader.Read()
+	meta, err = ReadMetadata(bytesReader)
 	assert(err, IsNil)
 	assert(byte(meta.SessionStatus), Equals, byte(SessionStatusEnd))
 	assert(meta.SessionID, Equals, uint16(2))
@@ -137,7 +136,7 @@ func TestReaderWriter(t *testing.T) {
 
 	stream.Close()
 
-	meta, err = metaReader.Read()
+	meta, err = ReadMetadata(bytesReader)
 	assert(err, IsNotNil)
 	assert(meta, IsNil)
 }

+ 7 - 16
app/proxyman/mux/reader.go

@@ -7,20 +7,8 @@ import (
 	"v2ray.com/core/common/serial"
 )
 
-type MetadataReader struct {
-	reader io.Reader
-	buffer []byte
-}
-
-func NewMetadataReader(reader io.Reader) *MetadataReader {
-	return &MetadataReader{
-		reader: reader,
-		buffer: make([]byte, 1024),
-	}
-}
-
-func (r *MetadataReader) Read() (*FrameMetadata, error) {
-	metaLen, err := serial.ReadUint16(r.reader)
+func ReadMetadata(reader io.Reader) (*FrameMetadata, error) {
+	metaLen, err := serial.ReadUint16(reader)
 	if err != nil {
 		return nil, err
 	}
@@ -28,10 +16,13 @@ func (r *MetadataReader) Read() (*FrameMetadata, error) {
 		return nil, newError("invalid metalen ", metaLen).AtWarning()
 	}
 
-	if _, err := io.ReadFull(r.reader, r.buffer[:metaLen]); err != nil {
+	b := buf.New()
+	defer b.Release()
+
+	if err := b.Reset(buf.ReadFullFrom(reader, int(metaLen))); err != nil {
 		return nil, err
 	}
-	return ReadFrameFrom(r.buffer)
+	return ReadFrameFrom(b.Bytes())
 }
 
 type PacketReader struct {