Selaa lähdekoodia

unify all address reading and writing

Darien Raymond 7 vuotta sitten
vanhempi
commit
af1abf687c

+ 183 - 0
common/protocol/address.go

@@ -0,0 +1,183 @@
+package protocol
+
+import (
+	"io"
+
+	"v2ray.com/core/common/buf"
+	"v2ray.com/core/common/net"
+)
+
+type AddressOption func(*AddressParser)
+
+func PortThenAddress() AddressOption {
+	return func(p *AddressParser) {
+		p.portFirst = true
+	}
+}
+
+func AddressFamilyByte(b byte, f net.AddressFamily) AddressOption {
+	return func(p *AddressParser) {
+		p.addrTypeMap[b] = f
+		p.addrByteMap[f] = b
+	}
+}
+
+type AddressTypeParser func(byte) byte
+
+func WithAddressTypeParser(atp AddressTypeParser) AddressOption {
+	return func(p *AddressParser) {
+		p.typeParser = atp
+	}
+}
+
+type AddressParser struct {
+	addrTypeMap map[byte]net.AddressFamily
+	addrByteMap map[net.AddressFamily]byte
+	portFirst   bool
+	typeParser  AddressTypeParser
+}
+
+func NewAddressParser(options ...AddressOption) *AddressParser {
+	p := &AddressParser{
+		addrTypeMap: make(map[byte]net.AddressFamily, 8),
+		addrByteMap: make(map[net.AddressFamily]byte, 8),
+	}
+	for _, opt := range options {
+		opt(p)
+	}
+	return p
+}
+
+func (p *AddressParser) readPort(b *buf.Buffer, reader io.Reader) (net.Port, error) {
+	if err := b.AppendSupplier(buf.ReadFullFrom(reader, 2)); err != nil {
+		return 0, err
+	}
+	return net.PortFromBytes(b.BytesFrom(-2)), nil
+}
+
+func (p *AddressParser) readAddress(b *buf.Buffer, reader io.Reader) (net.Address, error) {
+	if err := b.AppendSupplier(buf.ReadFullFrom(reader, 1)); err != nil {
+		return nil, err
+	}
+
+	addrType := b.Byte(b.Len() - 1)
+	if p.typeParser != nil {
+		addrType = p.typeParser(addrType)
+	}
+
+	addrFamily, valid := p.addrTypeMap[addrType]
+	if !valid {
+		return nil, newError("unknown address type: ", addrType)
+	}
+
+	switch addrFamily {
+	case net.AddressFamilyIPv4:
+		if err := b.AppendSupplier(buf.ReadFullFrom(reader, 4)); err != nil {
+			return nil, err
+		}
+		return net.IPAddress(b.BytesFrom(-4)), nil
+	case net.AddressFamilyIPv6:
+		if err := b.AppendSupplier(buf.ReadFullFrom(reader, 16)); err != nil {
+			return nil, err
+		}
+		return net.IPAddress(b.BytesFrom(-16)), nil
+	case net.AddressFamilyDomain:
+		if err := b.AppendSupplier(buf.ReadFullFrom(reader, 1)); err != nil {
+			return nil, err
+		}
+		domainLength := int(b.Byte(b.Len() - 1))
+		if err := b.AppendSupplier(buf.ReadFullFrom(reader, domainLength)); err != nil {
+			return nil, err
+		}
+		return net.DomainAddress(string(b.BytesFrom(-domainLength))), nil
+	default:
+		panic("impossible case")
+	}
+}
+
+func (p *AddressParser) ReadAddressPort(buffer *buf.Buffer, input io.Reader) (net.Address, net.Port, error) {
+	if buffer == nil {
+		buffer = buf.New()
+		defer buffer.Release()
+	}
+
+	if p.portFirst {
+		port, err := p.readPort(buffer, input)
+		if err != nil {
+			return nil, 0, err
+		}
+		addr, err := p.readAddress(buffer, input)
+		if err != nil {
+			return nil, 0, err
+		}
+		return addr, port, nil
+	}
+
+	addr, err := p.readAddress(buffer, input)
+	if err != nil {
+		return nil, 0, err
+	}
+
+	port, err := p.readPort(buffer, input)
+	if err != nil {
+		return nil, 0, err
+	}
+
+	return addr, port, nil
+}
+
+func (p *AddressParser) writePort(writer io.Writer, port net.Port) error {
+	if _, err := writer.Write(port.Bytes(nil)); err != nil {
+		return err
+	}
+	return nil
+}
+
+func (p *AddressParser) writeAddress(writer io.Writer, address net.Address) error {
+	tb, valid := p.addrByteMap[address.Family()]
+	if !valid {
+		return newError("unknown address family", address.Family())
+	}
+
+	switch address.Family() {
+	case net.AddressFamilyIPv4, net.AddressFamilyIPv6:
+		if _, err := writer.Write([]byte{tb}); err != nil {
+			return err
+		}
+		if _, err := writer.Write(address.IP()); err != nil {
+			return err
+		}
+	case net.AddressFamilyDomain:
+		domain := address.Domain()
+		if IsDomainTooLong(domain) {
+			return newError("Super long domain is not supported: ", domain)
+		}
+		if _, err := writer.Write([]byte{tb, byte(len(domain))}); err != nil {
+			return err
+		}
+		if _, err := writer.Write([]byte(domain)); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (p *AddressParser) WriteAddressPort(writer io.Writer, addr net.Address, port net.Port) error {
+	if p.portFirst {
+		if err := p.writePort(writer, port); err != nil {
+			return err
+		}
+		if err := p.writeAddress(writer, addr); err != nil {
+			return err
+		}
+		return nil
+	}
+
+	if err := p.writeAddress(writer, addr); err != nil {
+		return err
+	}
+	if err := p.writePort(writer, port); err != nil {
+		return err
+	}
+	return nil
+}

+ 70 - 0
common/protocol/address_test.go

@@ -0,0 +1,70 @@
+package protocol_test
+
+import (
+	"bytes"
+	"testing"
+
+	"v2ray.com/core/common/buf"
+	"v2ray.com/core/common/net"
+	. "v2ray.com/core/common/protocol"
+	. "v2ray.com/ext/assert"
+)
+
+func TestAddressParser(t *testing.T) {
+	assert := With(t)
+
+	data := []struct {
+		Options []AddressOption
+		Input   []byte
+		Address net.Address
+		Port    net.Port
+		Error   bool
+	}{
+		{
+			Options: []AddressOption{},
+			Input:   []byte{0, 0, 0, 0, 0},
+			Error:   true,
+		},
+		{
+			Options: []AddressOption{AddressFamilyByte(0x01, net.AddressFamilyIPv4)},
+			Input:   []byte{1, 0, 0, 0, 0, 0, 53},
+			Address: net.IPAddress([]byte{0, 0, 0, 0}),
+			Port:    net.Port(53),
+		},
+		{
+			Options: []AddressOption{AddressFamilyByte(0x01, net.AddressFamilyIPv4)},
+			Input:   []byte{1, 0, 0, 0, 0},
+			Error:   true,
+		},
+		{
+			Options: []AddressOption{AddressFamilyByte(0x04, net.AddressFamilyIPv6)},
+			Input:   []byte{4, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 0, 80},
+			Address: net.IPAddress([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}),
+			Port:    net.Port(80),
+		},
+		{
+			Options: []AddressOption{AddressFamilyByte(0x03, net.AddressFamilyDomain)},
+			Input:   []byte{3, 9, 118, 50, 114, 97, 121, 46, 99, 111, 109, 0, 80},
+			Address: net.DomainAddress("v2ray.com"),
+			Port:    net.Port(80),
+		},
+		{
+			Options: []AddressOption{AddressFamilyByte(0x03, net.AddressFamilyDomain)},
+			Input:   []byte{3, 9, 118, 50, 114, 97, 121, 46, 99, 111, 109, 0},
+			Error:   true,
+		},
+	}
+
+	for _, tc := range data {
+		b := buf.New()
+		parser := NewAddressParser(tc.Options...)
+		addr, port, err := parser.ReadAddressPort(b, bytes.NewReader(tc.Input))
+		b.Release()
+		if tc.Error {
+			assert(err, IsNotNil)
+		} else {
+			assert(addr, Equals, tc.Address)
+			assert(port, Equals, tc.Port)
+		}
+	}
+}

+ 29 - 39
proxy/shadowsocks/protocol.go

@@ -11,16 +11,20 @@ import (
 	"v2ray.com/core/common/dice"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
-	"v2ray.com/core/proxy/socks"
 )
 
 const (
 	Version                               = 1
 	RequestOptionOneTimeAuth bitmask.Byte = 0x01
+)
 
-	AddrTypeIPv4   = 1
-	AddrTypeIPv6   = 4
-	AddrTypeDomain = 3
+var addrParser = protocol.NewAddressParser(
+	protocol.AddressFamilyByte(0x01, net.AddressFamilyIPv4),
+	protocol.AddressFamilyByte(0x04, net.AddressFamilyIPv6),
+	protocol.AddressFamilyByte(0x03, net.AddressFamilyDomain),
+	protocol.WithAddressTypeParser(func(b byte) byte {
+		return b & 0x0F
+	}),
 )
 
 // ReadTCPSession reads a Shadowsocks TCP session from the given reader, returns its header and remaining parts.
@@ -58,10 +62,21 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea
 		Command: protocol.RequestCommandTCP,
 	}
 
-	if err := buffer.Reset(buf.ReadFullFrom(br, 1)); err != nil {
-		return nil, nil, newError("failed to read address type").Base(err)
+	buffer.Clear()
+
+	addr, port, err := addrParser.ReadAddressPort(buffer, br)
+
+	if err != nil {
+		// Invalid address. Continue to read some bytes to confuse client.
+		nBytes := dice.Roll(32)
+		buffer.Clear()
+		buffer.AppendSupplier(buf.ReadFullFrom(br, nBytes))
+		return nil, nil, newError("failed to read address").Base(err)
 	}
 
+	request.Address = addr
+	request.Port = port
+
 	if !account.Cipher.IsAEAD() {
 		if (buffer.Byte(0) & 0x10) == 0x10 {
 			request.Option.Set(RequestOptionOneTimeAuth)
@@ -76,20 +91,6 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea
 		}
 	}
 
-	addrType := (buffer.Byte(0) & 0x0F)
-
-	addr, port, err := socks.ReadAddress(buffer, addrType, br)
-	if err != nil {
-		// Invalid address. Continue to read some bytes to confuse client.
-		nBytes := dice.Roll(32)
-		buffer.Clear()
-		buffer.AppendSupplier(buf.ReadFullFrom(br, nBytes))
-		return nil, nil, newError("failed to read address").Base(err)
-	}
-
-	request.Address = addr
-	request.Port = port
-
 	if request.Option.Has(RequestOptionOneTimeAuth) {
 		actualAuth := make([]byte, AuthSize)
 		authenticator.Authenticate(buffer.Bytes())(actualAuth)
@@ -150,7 +151,7 @@ func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Wri
 
 	header := buf.NewLocal(512)
 
-	if err := socks.AppendAddress(header, request.Address, request.Port); err != nil {
+	if err := addrParser.WriteAddressPort(header, request.Address, request.Port); err != nil {
 		return nil, newError("failed to write address").Base(err)
 	}
 
@@ -230,7 +231,7 @@ func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buff
 	}
 	iv := buffer.Bytes()
 
-	if err := socks.AppendAddress(buffer, request.Address, request.Port); err != nil {
+	if err := addrParser.WriteAddressPort(buffer, request.Address, request.Port); err != nil {
 		return nil, newError("failed to write address").Base(err)
 	}
 
@@ -301,26 +302,15 @@ func DecodeUDPPacket(user *protocol.User, payload *buf.Buffer) (*protocol.Reques
 		}
 	}
 
-	addrType := (payload.Byte(0) & 0x0F)
-	payload.SliceFrom(1)
+	payload.SetByte(0, payload.Byte(0)&0x0F)
 
-	switch addrType {
-	case AddrTypeIPv4:
-		request.Address = net.IPAddress(payload.BytesTo(4))
-		payload.SliceFrom(4)
-	case AddrTypeIPv6:
-		request.Address = net.IPAddress(payload.BytesTo(16))
-		payload.SliceFrom(16)
-	case AddrTypeDomain:
-		domainLength := int(payload.Byte(0))
-		request.Address = net.DomainAddress(string(payload.BytesRange(1, 1+domainLength)))
-		payload.SliceFrom(1 + domainLength)
-	default:
-		return nil, nil, newError("unknown address type: ", addrType).AtError()
+	addr, port, err := addrParser.ReadAddressPort(nil, payload)
+	if err != nil {
+		return nil, nil, newError("failed to parse address").Base(err)
 	}
 
-	request.Port = net.PortFromBytes(payload.BytesTo(2))
-	payload.SliceFrom(2)
+	request.Address = addr
+	request.Port = port
 
 	return request, payload, nil
 }

+ 26 - 106
proxy/socks/protocol.go

@@ -34,6 +34,12 @@ const (
 	statusCmdNotSupport = 0x07
 )
 
+var addrParser = protocol.NewAddressParser(
+	protocol.AddressFamilyByte(0x01, net.AddressFamilyIPv4),
+	protocol.AddressFamilyByte(0x04, net.AddressFamilyIPv6),
+	protocol.AddressFamilyByte(0x03, net.AddressFamilyDomain),
+)
+
 type ServerSession struct {
 	config *ServerConfig
 	port   net.Port
@@ -122,7 +128,7 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol
 				return nil, newError("failed to write auth response").Base(err)
 			}
 		}
-		if err := buffer.Reset(buf.ReadFullFrom(reader, 4)); err != nil {
+		if err := buffer.Reset(buf.ReadFullFrom(reader, 3)); err != nil {
 			return nil, newError("failed to read request").Base(err)
 		}
 
@@ -139,13 +145,11 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol
 			request.Command = protocol.RequestCommandUDP
 		}
 
-		addrType := buffer.Byte(3)
-
 		buffer.Clear()
 
 		request.Version = socks5Version
 
-		addr, port, err := ReadAddress(buffer, addrType, reader)
+		addr, port, err := addrParser.ReadAddressPort(buffer, reader)
 		if err != nil {
 			return nil, newError("failed to read address").Base(err)
 		}
@@ -229,30 +233,10 @@ func writeSocks5AuthenticationResponse(writer io.Writer, version byte, auth byte
 	return err
 }
 
-// AppendAddress appends Socks address into the given buffer.
-func AppendAddress(buffer *buf.Buffer, address net.Address, port net.Port) error {
-	switch address.Family() {
-	case net.AddressFamilyIPv4:
-		buffer.AppendBytes(addrTypeIPv4)
-		buffer.Append(address.IP())
-	case net.AddressFamilyIPv6:
-		buffer.AppendBytes(addrTypeIPv6)
-		buffer.Append(address.IP())
-	case net.AddressFamilyDomain:
-		if protocol.IsDomainTooLong(address.Domain()) {
-			return newError("Super long domain is not supported in Socks protocol: ", address.Domain())
-		}
-		buffer.AppendBytes(addrTypeDomain, byte(len(address.Domain())))
-		common.Must(buffer.AppendSupplier(serial.WriteString(address.Domain())))
-	}
-	common.Must(buffer.AppendSupplier(serial.WriteUint16(port.Value())))
-	return nil
-}
-
 func writeSocks5Response(writer io.Writer, errCode byte, address net.Address, port net.Port) error {
 	buffer := buf.NewLocal(64)
 	buffer.AppendBytes(socks5Version, errCode, 0x00 /* reserved */)
-	if err := AppendAddress(buffer, address, port); err != nil {
+	if err := addrParser.WriteAddressPort(buffer, address, port); err != nil {
 		return err
 	}
 
@@ -269,9 +253,9 @@ func writeSocks4Response(writer io.Writer, errCode byte, address net.Address, po
 	return err
 }
 
-func DecodeUDPPacket(packet []byte) (*protocol.RequestHeader, []byte, error) {
-	if len(packet) < 5 {
-		return nil, nil, newError("insufficient length of packet.")
+func DecodeUDPPacket(packet *buf.Buffer) (*protocol.RequestHeader, error) {
+	if packet.Len() < 5 {
+		return nil, newError("insufficient length of packet.")
 	}
 	request := &protocol.RequestHeader{
 		Version: socks5Version,
@@ -279,50 +263,25 @@ func DecodeUDPPacket(packet []byte) (*protocol.RequestHeader, []byte, error) {
 	}
 
 	// packet[0] and packet[1] are reserved
-	if packet[2] != 0 /* fragments */ {
-		return nil, nil, newError("discarding fragmented payload.")
+	if packet.Byte(2) != 0 /* fragments */ {
+		return nil, newError("discarding fragmented payload.")
 	}
 
-	addrType := packet[3]
-	var dataBegin int
+	packet.SliceFrom(3)
 
-	switch addrType {
-	case addrTypeIPv4:
-		if len(packet) < 10 {
-			return nil, nil, newError("insufficient length of packet")
-		}
-		ip := packet[4:8]
-		request.Port = net.PortFromBytes(packet[8:10])
-		request.Address = net.IPAddress(ip)
-		dataBegin = 10
-	case addrTypeIPv6:
-		if len(packet) < 22 {
-			return nil, nil, newError("insufficient length of packet")
-		}
-		ip := packet[4:20]
-		request.Port = net.PortFromBytes(packet[20:22])
-		request.Address = net.IPAddress(ip)
-		dataBegin = 22
-	case addrTypeDomain:
-		domainLength := int(packet[4])
-		if len(packet) < 5+domainLength+2 {
-			return nil, nil, newError("insufficient length of packet")
-		}
-		domain := string(packet[5 : 5+domainLength])
-		request.Port = net.PortFromBytes(packet[5+domainLength : 5+domainLength+2])
-		request.Address = net.ParseAddress(domain)
-		dataBegin = 5 + domainLength + 2
-	default:
-		return nil, nil, newError("unknown address type ", addrType)
+	addr, port, err := addrParser.ReadAddressPort(nil, packet)
+	if err != nil {
+		return nil, newError("failed to read UDP header").Base(err)
 	}
-
-	return request, packet[dataBegin:], nil
+	request.Address = addr
+	request.Port = port
+	return request, nil
 }
 
 func EncodeUDPPacket(request *protocol.RequestHeader, data []byte) (*buf.Buffer, error) {
 	b := buf.New()
 	b.AppendBytes(0, 0, 0 /* Fragment */)
-	if err := AppendAddress(b, request.Address, request.Port); err != nil {
+	if err := addrParser.WriteAddressPort(b, request.Address, request.Port); err != nil {
 		return nil, err
 	}
 	b.Append(data)
@@ -342,12 +301,9 @@ func (r *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
 	if err := b.AppendSupplier(buf.ReadFrom(r.reader)); err != nil {
 		return nil, err
 	}
-	_, data, err := DecodeUDPPacket(b.Bytes())
-	if err != nil {
+	if _, err := DecodeUDPPacket(b); err != nil {
 		return nil, err
 	}
-	b.Clear()
-	b.Append(data)
 	return buf.NewMultiBufferValue(b), nil
 }
 
@@ -376,40 +332,6 @@ func (w *UDPWriter) Write(b []byte) (int, error) {
 	return len(b), nil
 }
 
-func ReadAddress(b *buf.Buffer, addrType byte, reader io.Reader) (net.Address, net.Port, error) {
-	var address net.Address
-	switch addrType {
-	case addrTypeIPv4:
-		if err := b.AppendSupplier(buf.ReadFullFrom(reader, 4)); err != nil {
-			return nil, 0, err
-		}
-		address = net.IPAddress(b.BytesFrom(-4))
-	case addrTypeIPv6:
-		if err := b.AppendSupplier(buf.ReadFullFrom(reader, 16)); err != nil {
-			return nil, 0, err
-		}
-		address = net.IPAddress(b.BytesFrom(-16))
-	case addrTypeDomain:
-		if err := b.AppendSupplier(buf.ReadFullFrom(reader, 1)); err != nil {
-			return nil, 0, err
-		}
-		domainLength := int(b.Byte(b.Len() - 1))
-		if err := b.AppendSupplier(buf.ReadFullFrom(reader, domainLength)); err != nil {
-			return nil, 0, err
-		}
-		address = net.DomainAddress(string(b.BytesFrom(-domainLength)))
-	default:
-		return nil, 0, newError("unknown address type: ", addrType)
-	}
-
-	if err := b.AppendSupplier(buf.ReadFullFrom(reader, 2)); err != nil {
-		return nil, 0, err
-	}
-	port := net.PortFromBytes(b.BytesFrom(-2))
-
-	return address, port, nil
-}
-
 func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer io.Writer) (*protocol.RequestHeader, error) {
 	authByte := byte(authNotRequired)
 	if request.User != nil {
@@ -462,7 +384,7 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
 		command = byte(cmdUDPPort)
 	}
 	b.AppendBytes(socks5Version, command, 0x00 /* reserved */)
-	if err := AppendAddress(b, request.Address, request.Port); err != nil {
+	if err := addrParser.WriteAddressPort(b, request.Address, request.Port); err != nil {
 		return nil, err
 	}
 
@@ -471,7 +393,7 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
 	}
 
 	b.Clear()
-	if err := b.AppendSupplier(buf.ReadFullFrom(reader, 4)); err != nil {
+	if err := b.AppendSupplier(buf.ReadFullFrom(reader, 3)); err != nil {
 		return nil, err
 	}
 
@@ -480,11 +402,9 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
 		return nil, newError("server rejects request: ", resp)
 	}
 
-	addrType := b.Byte(3)
-
 	b.Clear()
 
-	address, port, err := ReadAddress(b, addrType, reader)
+	address, port, err := addrParser.ReadAddressPort(b, reader)
 	if err != nil {
 		return nil, err
 	}

+ 0 - 54
proxy/socks/protocol_test.go

@@ -1,7 +1,6 @@
 package socks_test
 
 import (
-	"bytes"
 	"testing"
 
 	"v2ray.com/core/common/buf"
@@ -34,56 +33,3 @@ func TestUDPEncoding(t *testing.T) {
 	assert(err, IsNil)
 	assert(decodedPayload[0].Bytes(), Equals, content)
 }
-
-func TestReadAddress(t *testing.T) {
-	assert := With(t)
-
-	data := []struct {
-		AddrType byte
-		Input    []byte
-		Address  net.Address
-		Port     net.Port
-		Error    bool
-	}{
-		{
-			AddrType: 0,
-			Input:    []byte{0, 0, 0, 0},
-			Error:    true,
-		},
-		{
-			AddrType: 1,
-			Input:    []byte{0, 0, 0, 0, 0, 53},
-			Address:  net.IPAddress([]byte{0, 0, 0, 0}),
-			Port:     net.Port(53),
-		},
-		{
-			AddrType: 4,
-			Input:    []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 0, 80},
-			Address:  net.IPAddress([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}),
-			Port:     net.Port(80),
-		},
-		{
-			AddrType: 3,
-			Input:    []byte{9, 118, 50, 114, 97, 121, 46, 99, 111, 109, 0, 80},
-			Address:  net.DomainAddress("v2ray.com"),
-			Port:     net.Port(80),
-		},
-		{
-			AddrType: 3,
-			Input:    []byte{9, 118, 50, 114, 97, 121, 46, 99, 111, 109, 0},
-			Error:    true,
-		},
-	}
-
-	for _, tc := range data {
-		b := buf.New()
-		addr, port, err := ReadAddress(b, tc.AddrType, bytes.NewBuffer(tc.Input))
-		b.Release()
-		if tc.Error {
-			assert(err, IsNotNil)
-		} else {
-			assert(addr, Equals, tc.Address)
-			assert(port, Equals, tc.Port)
-		}
-	}
-}

+ 6 - 6
proxy/socks/server.go

@@ -185,18 +185,20 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection,
 		}
 
 		for _, payload := range mpayload {
-			request, data, err := DecodeUDPPacket(payload.Bytes())
+			request, err := DecodeUDPPacket(payload)
 
 			if err != nil {
 				newError("failed to parse UDP request").Base(err).WithContext(ctx).WriteToLog()
+				payload.Release()
 				continue
 			}
 
-			if len(data) == 0 {
+			if payload.IsEmpty() {
+				payload.Release()
 				continue
 			}
 
-			newError("send packet to ", request.Destination(), " with ", len(data), " bytes").AtDebug().WithContext(ctx).WriteToLog()
+			newError("send packet to ", request.Destination(), " with ", payload.Len(), " bytes").AtDebug().WithContext(ctx).WriteToLog()
 			if source, ok := proxy.SourceFromContext(ctx); ok {
 				log.Record(&log.AccessMessage{
 					From:   source,
@@ -206,9 +208,7 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection,
 				})
 			}
 
-			dataBuf := buf.New()
-			dataBuf.Append(data)
-			udpServer.Dispatch(ctx, request.Destination(), dataBuf, func(payload *buf.Buffer) {
+			udpServer.Dispatch(ctx, request.Destination(), payload, func(payload *buf.Buffer) {
 				defer payload.Release()
 
 				newError("writing back UDP response with ", payload.Len(), " bytes").AtDebug().WithContext(ctx).WriteToLog()

+ 2 - 18
proxy/vmess/encoding/client.go

@@ -15,7 +15,6 @@ import (
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/crypto"
 	"v2ray.com/core/common/dice"
-	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/common/serial"
 	"v2ray.com/core/proxy/vmess"
@@ -82,23 +81,8 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ
 	buffer.AppendBytes(security, byte(0), byte(header.Command))
 
 	if header.Command != protocol.RequestCommandMux {
-		common.Must(buffer.AppendSupplier(serial.WriteUint16(header.Port.Value())))
-
-		switch header.Address.Family() {
-		case net.AddressFamilyIPv4:
-			buffer.AppendBytes(byte(protocol.AddressTypeIPv4))
-			buffer.Append(header.Address.IP())
-		case net.AddressFamilyIPv6:
-			buffer.AppendBytes(byte(protocol.AddressTypeIPv6))
-			buffer.Append(header.Address.IP())
-		case net.AddressFamilyDomain:
-			domain := header.Address.Domain()
-			if protocol.IsDomainTooLong(domain) {
-				return newError("long domain not supported: ", domain)
-			}
-			nDomain := len(domain)
-			buffer.AppendBytes(byte(protocol.AddressTypeDomain), byte(nDomain))
-			common.Must(buffer.AppendSupplier(serial.WriteString(domain)))
+		if err := addrParser.WriteAddressPort(buffer, header.Address, header.Port); err != nil {
+			return newError("failed to writer address and port").Base(err)
 		}
 	}
 

+ 0 - 5
proxy/vmess/encoding/const.go

@@ -1,5 +0,0 @@
-package encoding
-
-const (
-	Version = byte(1)
-)

+ 16 - 0
proxy/vmess/encoding/encoding.go

@@ -1,3 +1,19 @@
 package encoding
 
+import (
+	"v2ray.com/core/common/net"
+	"v2ray.com/core/common/protocol"
+)
+
 //go:generate go run $GOPATH/src/v2ray.com/core/common/errors/errorgen/main.go -pkg encoding -path Proxy,VMess,Encoding
+
+const (
+	Version = byte(1)
+)
+
+var addrParser = protocol.NewAddressParser(
+	protocol.AddressFamilyByte(0x01, net.AddressFamilyIPv4),
+	protocol.AddressFamilyByte(0x02, net.AddressFamilyDomain),
+	protocol.AddressFamilyByte(0x03, net.AddressFamilyIPv6),
+	protocol.PortThenAddress(),
+)

+ 1 - 39
proxy/vmess/encoding/server.go

@@ -105,44 +105,6 @@ func NewServerSession(validator protocol.UserValidator, sessionHistory *SessionH
 	}
 }
 
-func readAddress(buffer *buf.Buffer, reader io.Reader) (net.Address, net.Port, error) {
-	var address net.Address
-	var port net.Port
-	if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 3)); err != nil {
-		return address, port, newError("failed to read port and address type").Base(err)
-	}
-	port = net.PortFromBytes(buffer.BytesRange(-3, -1))
-
-	addressType := protocol.AddressType(buffer.Byte(buffer.Len() - 1))
-	switch addressType {
-	case protocol.AddressTypeIPv4:
-		if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 4)); err != nil {
-			return address, port, newError("failed to read IPv4 address").Base(err)
-		}
-		address = net.IPAddress(buffer.BytesFrom(-4))
-	case protocol.AddressTypeIPv6:
-		if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 16)); err != nil {
-			return address, port, newError("failed to read IPv6 address").Base(err)
-		}
-		address = net.IPAddress(buffer.BytesFrom(-16))
-	case protocol.AddressTypeDomain:
-		if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 1)); err != nil {
-			return address, port, newError("failed to read domain address").Base(err)
-		}
-		domainLength := int(buffer.Byte(buffer.Len() - 1))
-		if domainLength == 0 {
-			return address, port, newError("zero length domain")
-		}
-		if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, domainLength)); err != nil {
-			return address, port, newError("failed to read domain address").Base(err)
-		}
-		address = net.DomainAddress(string(buffer.BytesFrom(-domainLength)))
-	default:
-		return address, port, newError("invalid address type", addressType)
-	}
-	return address, port, nil
-}
-
 func parseSecurityType(b byte) protocol.SecurityType {
 	if _, f := protocol.SecurityType_name[int32(b)]; f {
 		return protocol.SecurityType(b)
@@ -221,7 +183,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
 		request.Address = net.DomainAddress("v1.mux.cool")
 		request.Port = 0
 	case protocol.RequestCommandTCP, protocol.RequestCommandUDP:
-		if addr, port, err := readAddress(buffer, decryptor); err == nil {
+		if addr, port, err := addrParser.ReadAddressPort(buffer, decryptor); err == nil {
 			request.Address = addr
 			request.Port = port
 		} else {