|
|
@@ -65,11 +65,11 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol
|
|
|
}
|
|
|
port := net.PortFromBytes(buffer.BytesRange(2, 4))
|
|
|
address := net.IPAddress(buffer.BytesRange(4, 8))
|
|
|
- if _, err := readUntilNull(reader); /* user id */ err != nil {
|
|
|
+ if _, err := ReadUntilNull(reader); /* user id */ err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
if address.IP()[0] == 0x00 {
|
|
|
- domain, err := readUntilNull(reader)
|
|
|
+ domain, err := ReadUntilNull(reader)
|
|
|
if err != nil {
|
|
|
return nil, newError("failed to read domain for socks 4a").Base(err)
|
|
|
}
|
|
|
@@ -113,7 +113,7 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol
|
|
|
}
|
|
|
|
|
|
if expectedAuth == authPassword {
|
|
|
- username, password, err := readUsernamePassword(reader)
|
|
|
+ username, password, err := ReadUsernamePassword(reader)
|
|
|
if err != nil {
|
|
|
return nil, newError("failed to read username and password for authentication").Base(err)
|
|
|
}
|
|
|
@@ -183,7 +183,13 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol
|
|
|
return nil, newError("unknown Socks version: ", version)
|
|
|
}
|
|
|
|
|
|
-func readUsernamePassword(reader io.Reader) (string, string, error) {
|
|
|
+// ReadUsernamePassword reads Socks 5 username/password message from the given reader.
|
|
|
+// +----+------+----------+------+----------+
|
|
|
+// |VER | ULEN | UNAME | PLEN | PASSWD |
|
|
|
+// +----+------+----------+------+----------+
|
|
|
+// | 1 | 1 | 1 to 255 | 1 | 1 to 255 |
|
|
|
+// +----+------+----------+------+----------+
|
|
|
+func ReadUsernamePassword(reader io.Reader) (string, string, error) {
|
|
|
buffer := buf.New()
|
|
|
defer buffer.Release()
|
|
|
|
|
|
@@ -212,19 +218,21 @@ func readUsernamePassword(reader io.Reader) (string, string, error) {
|
|
|
return username, password, nil
|
|
|
}
|
|
|
|
|
|
-func readUntilNull(reader io.Reader) (string, error) {
|
|
|
- var b [256]byte
|
|
|
- size := 0
|
|
|
+// ReadUntilNull reads content from given reader, until a null (0x00) byte.
|
|
|
+func ReadUntilNull(reader io.Reader) (string, error) {
|
|
|
+ b := buf.New()
|
|
|
+ defer b.Release()
|
|
|
+
|
|
|
for {
|
|
|
- _, err := reader.Read(b[size : size+1])
|
|
|
+ _, err := b.ReadFullFrom(reader, 1)
|
|
|
if err != nil {
|
|
|
return "", err
|
|
|
}
|
|
|
- if b[size] == 0x00 {
|
|
|
- return string(b[:size]), nil
|
|
|
+ if b.Byte(b.Len()-1) == 0x00 {
|
|
|
+ b.Resize(0, b.Len()-1)
|
|
|
+ return b.String(), nil
|
|
|
}
|
|
|
- size++
|
|
|
- if size == 256 {
|
|
|
+ if b.IsFull() {
|
|
|
return "", newError("buffer overrun")
|
|
|
}
|
|
|
}
|