Browse Source

use buffer for reading user id in socks

Darien Raymond 7 years ago
parent
commit
5c5816072e
2 changed files with 94 additions and 12 deletions
  1. 20 12
      proxy/socks/protocol.go
  2. 74 0
      proxy/socks/protocol_test.go

+ 20 - 12
proxy/socks/protocol.go

@@ -65,11 +65,11 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol
 		}
 		port := net.PortFromBytes(buffer.BytesRange(2, 4))
 		address := net.IPAddress(buffer.BytesRange(4, 8))
-		if _, err := readUntilNull(reader); /* user id */ err != nil {
+		if _, err := ReadUntilNull(reader); /* user id */ err != nil {
 			return nil, err
 		}
 		if address.IP()[0] == 0x00 {
-			domain, err := readUntilNull(reader)
+			domain, err := ReadUntilNull(reader)
 			if err != nil {
 				return nil, newError("failed to read domain for socks 4a").Base(err)
 			}
@@ -113,7 +113,7 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol
 		}
 
 		if expectedAuth == authPassword {
-			username, password, err := readUsernamePassword(reader)
+			username, password, err := ReadUsernamePassword(reader)
 			if err != nil {
 				return nil, newError("failed to read username and password for authentication").Base(err)
 			}
@@ -183,7 +183,13 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol
 	return nil, newError("unknown Socks version: ", version)
 }
 
-func readUsernamePassword(reader io.Reader) (string, string, error) {
+// ReadUsernamePassword reads Socks 5 username/password message from the given reader.
+// +----+------+----------+------+----------+
+// |VER | ULEN |  UNAME   | PLEN |  PASSWD  |
+// +----+------+----------+------+----------+
+// | 1  |  1   | 1 to 255 |  1   | 1 to 255 |
+// +----+------+----------+------+----------+
+func ReadUsernamePassword(reader io.Reader) (string, string, error) {
 	buffer := buf.New()
 	defer buffer.Release()
 
@@ -212,19 +218,21 @@ func readUsernamePassword(reader io.Reader) (string, string, error) {
 	return username, password, nil
 }
 
-func readUntilNull(reader io.Reader) (string, error) {
-	var b [256]byte
-	size := 0
+// ReadUntilNull reads content from given reader, until a null (0x00) byte.
+func ReadUntilNull(reader io.Reader) (string, error) {
+	b := buf.New()
+	defer b.Release()
+
 	for {
-		_, err := reader.Read(b[size : size+1])
+		_, err := b.ReadFullFrom(reader, 1)
 		if err != nil {
 			return "", err
 		}
-		if b[size] == 0x00 {
-			return string(b[:size]), nil
+		if b.Byte(b.Len()-1) == 0x00 {
+			b.Resize(0, b.Len()-1)
+			return b.String(), nil
 		}
-		size++
-		if size == 256 {
+		if b.IsFull() {
 			return "", newError("buffer overrun")
 		}
 	}

+ 74 - 0
proxy/socks/protocol_test.go

@@ -1,6 +1,7 @@
 package socks_test
 
 import (
+	"bytes"
 	"testing"
 
 	"v2ray.com/core/common/buf"
@@ -33,3 +34,76 @@ func TestUDPEncoding(t *testing.T) {
 	assert(err, IsNil)
 	assert(decodedPayload[0].Bytes(), Equals, content)
 }
+
+func TestReadUsernamePassword(t *testing.T) {
+	testCases := []struct {
+		Input    []byte
+		Username string
+		Password string
+		Error    bool
+	}{
+		{
+			Input:    []byte{0x05, 0x01, 'a', 0x02, 'b', 'c'},
+			Username: "a",
+			Password: "bc",
+		},
+		{
+			Input: []byte{0x05, 0x18, 'a', 0x02, 'b', 'c'},
+			Error: true,
+		},
+	}
+
+	for _, testCase := range testCases {
+		reader := bytes.NewReader(testCase.Input)
+		username, password, err := ReadUsernamePassword(reader)
+		if testCase.Error {
+			if err == nil {
+				t.Error("for input: ", testCase.Input, " expect error, but actually nil")
+			}
+		} else {
+			if err != nil {
+				t.Error("for input: ", testCase.Input, " expect no error, but actually ", err.Error())
+			}
+			if testCase.Username != username {
+				t.Error("for input: ", testCase.Input, " expect username ", testCase.Username, " but actually ", username)
+			}
+			if testCase.Password != password {
+				t.Error("for input: ", testCase.Input, " expect passowrd ", testCase.Password, " but actually ", password)
+			}
+		}
+	}
+}
+
+func TestReadUntilNull(t *testing.T) {
+	testCases := []struct {
+		Input  []byte
+		Output string
+		Error  bool
+	}{
+		{
+			Input:  []byte{'a', 'b', 0x00},
+			Output: "ab",
+		},
+		{
+			Input: []byte{'a'},
+			Error: true,
+		},
+	}
+
+	for _, testCase := range testCases {
+		reader := bytes.NewReader(testCase.Input)
+		value, err := ReadUntilNull(reader)
+		if testCase.Error {
+			if err == nil {
+				t.Error("for input: ", testCase.Input, " expect error, but actually nil")
+			}
+		} else {
+			if err != nil {
+				t.Error("for input: ", testCase.Input, " expect no error, but actually ", err.Error())
+			}
+			if testCase.Output != value {
+				t.Error("for input: ", testCase.Input, " expect output ", testCase.Output, " but actually ", value)
+			}
+		}
+	}
+}