瀏覽代碼

unify address reading in socks and shadowsocks

Darien Raymond 7 年之前
父節點
當前提交
1077e33d62
共有 3 個文件被更改,包括 107 次插入56 次删除
  1. 11 27
      proxy/shadowsocks/protocol.go
  2. 41 29
      proxy/socks/protocol.go
  3. 55 0
      proxy/socks/protocol_test.go

+ 11 - 27
proxy/shadowsocks/protocol.go

@@ -8,6 +8,7 @@ import (
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/bitmask"
 	"v2ray.com/core/common/buf"
+	"v2ray.com/core/common/dice"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/proxy/socks"
@@ -76,36 +77,18 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea
 	}
 
 	addrType := (buffer.Byte(0) & 0x0F)
-	switch addrType {
-	case AddrTypeIPv4:
-		if err := buffer.AppendSupplier(buf.ReadFullFrom(br, 4)); err != nil {
-			return nil, nil, newError("failed to read IPv4 address").Base(err)
-		}
-		request.Address = net.IPAddress(buffer.BytesFrom(-4))
-	case AddrTypeIPv6:
-		if err := buffer.AppendSupplier(buf.ReadFullFrom(br, 16)); err != nil {
-			return nil, nil, newError("failed to read IPv6 address").Base(err)
-		}
-		request.Address = net.IPAddress(buffer.BytesFrom(-16))
-	case AddrTypeDomain:
-		if err := buffer.AppendSupplier(buf.ReadFullFrom(br, 1)); err != nil {
-			return nil, nil, newError("failed to read domain lenth.").Base(err)
-		}
-		domainLength := int(buffer.BytesFrom(-1)[0])
-		err = buffer.AppendSupplier(buf.ReadFullFrom(br, domainLength))
-		if err != nil {
-			return nil, nil, newError("failed to read domain").Base(err)
-		}
-		request.Address = net.DomainAddress(string(buffer.BytesFrom(-domainLength)))
-	default:
-		// Check address validity after OTA verification.
-	}
 
-	err = buffer.AppendSupplier(buf.ReadFullFrom(br, 2))
+	addr, port, err := socks.ReadAddress(buffer, addrType, br)
 	if err != nil {
-		return nil, nil, newError("failed to read port").Base(err)
+		// Invalid address. Continue to read some bytes to confuse client.
+		nBytes := dice.Roll(32)
+		buffer.Clear()
+		buffer.AppendSupplier(buf.ReadFullFrom(br, nBytes))
+		return nil, nil, newError("failed to read address").Base(err)
 	}
-	request.Port = net.PortFromBytes(buffer.BytesFrom(-2))
+
+	request.Address = addr
+	request.Port = port
 
 	if request.Option.Has(RequestOptionOneTimeAuth) {
 		actualAuth := make([]byte, AuthSize)
@@ -320,6 +303,7 @@ func DecodeUDPPacket(user *protocol.User, payload *buf.Buffer) (*protocol.Reques
 
 	addrType := (payload.Byte(0) & 0x0F)
 	payload.SliceFrom(1)
+
 	switch addrType {
 	case AddrTypeIPv4:
 		request.Address = net.IPAddress(payload.BytesTo(4))

+ 41 - 29
proxy/socks/protocol.go

@@ -279,7 +279,7 @@ func writeSocks5Response(writer io.Writer, errCode byte, address net.Address, po
 func writeSocks4Response(writer io.Writer, errCode byte, address net.Address, port net.Port) error {
 	buffer := buf.NewLocal(32)
 	buffer.AppendBytes(0x00, errCode)
-	buffer.AppendSupplier(serial.WriteUint16(port.Value()))
+	common.Must(buffer.AppendSupplier(serial.WriteUint16(port.Value())))
 	buffer.Append(address.IP())
 	_, err := writer.Write(buffer.Bytes())
 	return err
@@ -392,6 +392,40 @@ func (w *UDPWriter) Write(b []byte) (int, error) {
 	return len(b), nil
 }
 
+func ReadAddress(b *buf.Buffer, addrType byte, reader io.Reader) (net.Address, net.Port, error) {
+	var address net.Address
+	switch addrType {
+	case addrTypeIPv4:
+		if err := b.AppendSupplier(buf.ReadFullFrom(reader, 4)); err != nil {
+			return nil, 0, err
+		}
+		address = net.IPAddress(b.BytesFrom(-4))
+	case addrTypeIPv6:
+		if err := b.AppendSupplier(buf.ReadFullFrom(reader, 16)); err != nil {
+			return nil, 0, err
+		}
+		address = net.IPAddress(b.BytesFrom(-16))
+	case addrTypeDomain:
+		if err := b.AppendSupplier(buf.ReadFullFrom(reader, 1)); err != nil {
+			return nil, 0, err
+		}
+		domainLength := int(b.Byte(b.Len() - 1))
+		if err := b.AppendSupplier(buf.ReadFullFrom(reader, domainLength)); err != nil {
+			return nil, 0, err
+		}
+		address = net.DomainAddress(string(b.BytesFrom(-domainLength)))
+	default:
+		return nil, 0, newError("unknown address type: ", addrType)
+	}
+
+	if err := b.AppendSupplier(buf.ReadFullFrom(reader, 2)); err != nil {
+		return nil, 0, err
+	}
+	port := net.PortFromBytes(b.BytesFrom(-2))
+
+	return address, port, nil
+}
+
 func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer io.Writer) (*protocol.RequestHeader, error) {
 	authByte := byte(authNotRequired)
 	if request.User != nil {
@@ -444,7 +478,10 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
 		command = byte(cmdUDPPort)
 	}
 	b.AppendBytes(socks5Version, command, 0x00 /* reserved */)
-	AppendAddress(b, request.Address, request.Port)
+	if err := AppendAddress(b, request.Address, request.Port); err != nil {
+		return nil, err
+	}
+
 	if _, err := writer.Write(b.Bytes()); err != nil {
 		return nil, err
 	}
@@ -463,35 +500,10 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
 
 	b.Clear()
 
-	var address net.Address
-	switch addrType {
-	case addrTypeIPv4:
-		if err := b.AppendSupplier(buf.ReadFullFrom(reader, 4)); err != nil {
-			return nil, err
-		}
-		address = net.IPAddress(b.Bytes())
-	case addrTypeIPv6:
-		if err := b.AppendSupplier(buf.ReadFullFrom(reader, 16)); err != nil {
-			return nil, err
-		}
-		address = net.IPAddress(b.Bytes())
-	case addrTypeDomain:
-		if err := b.AppendSupplier(buf.ReadFullFrom(reader, 1)); err != nil {
-			return nil, err
-		}
-		domainLength := int(b.Byte(0))
-		if err := b.AppendSupplier(buf.ReadFullFrom(reader, domainLength)); err != nil {
-			return nil, err
-		}
-		address = net.DomainAddress(string(b.BytesFrom(-domainLength)))
-	default:
-		return nil, newError("unknown address type: ", addrType)
-	}
-
-	if err := b.AppendSupplier(buf.ReadFullFrom(reader, 2)); err != nil {
+	address, port, err := ReadAddress(b, addrType, reader)
+	if err != nil {
 		return nil, err
 	}
-	port := net.PortFromBytes(b.BytesFrom(-2))
 
 	if request.Command == protocol.RequestCommandUDP {
 		udpRequest := &protocol.RequestHeader{

+ 55 - 0
proxy/socks/protocol_test.go

@@ -1,10 +1,12 @@
 package socks_test
 
 import (
+	"bytes"
 	"testing"
 
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/net"
+	_ "v2ray.com/core/common/net/testing"
 	"v2ray.com/core/common/protocol"
 	. "v2ray.com/core/proxy/socks"
 	. "v2ray.com/ext/assert"
@@ -32,3 +34,56 @@ func TestUDPEncoding(t *testing.T) {
 	assert(err, IsNil)
 	assert(decodedPayload[0].Bytes(), Equals, content)
 }
+
+func TestReadAddress(t *testing.T) {
+	assert := With(t)
+
+	data := []struct {
+		AddrType byte
+		Input    []byte
+		Address  net.Address
+		Port     net.Port
+		Error    bool
+	}{
+		{
+			AddrType: 0,
+			Input:    []byte{0, 0, 0, 0},
+			Error:    true,
+		},
+		{
+			AddrType: 1,
+			Input:    []byte{0, 0, 0, 0, 0, 53},
+			Address:  net.IPAddress([]byte{0, 0, 0, 0}),
+			Port:     net.Port(53),
+		},
+		{
+			AddrType: 4,
+			Input:    []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 0, 80},
+			Address:  net.IPAddress([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}),
+			Port:     net.Port(80),
+		},
+		{
+			AddrType: 3,
+			Input:    []byte{9, 118, 50, 114, 97, 121, 46, 99, 111, 109, 0, 80},
+			Address:  net.DomainAddress("v2ray.com"),
+			Port:     net.Port(80),
+		},
+		{
+			AddrType: 3,
+			Input:    []byte{9, 118, 50, 114, 97, 121, 46, 99, 111, 109, 0},
+			Error:    true,
+		},
+	}
+
+	for _, tc := range data {
+		b := buf.New()
+		addr, port, err := ReadAddress(b, tc.AddrType, bytes.NewBuffer(tc.Input))
+		b.Release()
+		if tc.Error {
+			assert(err, IsNotNil)
+		} else {
+			assert(addr, Equals, tc.Address)
+			assert(port, Equals, tc.Port)
+		}
+	}
+}