Browse Source

improve address serialization performance

Darien Raymond 7 years ago
parent
commit
16102271dd
2 changed files with 240 additions and 93 deletions
  1. 135 93
      common/protocol/address.go
  2. 105 0
      common/protocol/address_test.go

+ 135 - 93
common/protocol/address.go

@@ -7,19 +7,21 @@ import (
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/serial"
-	"v2ray.com/core/common/task"
 )
 
-type AddressOption func(*AddressParser)
+type AddressOption func(*option)
 
 func PortThenAddress() AddressOption {
-	return func(p *AddressParser) {
+	return func(p *option) {
 		p.portFirst = true
 	}
 }
 
 func AddressFamilyByte(b byte, f net.AddressFamily) AddressOption {
-	return func(p *AddressParser) {
+	if b >= 16 {
+		panic("address family byte too big")
+	}
+	return func(p *option) {
 		p.addrTypeMap[b] = f
 		p.addrByteMap[f] = b
 	}
@@ -28,38 +30,127 @@ func AddressFamilyByte(b byte, f net.AddressFamily) AddressOption {
 type AddressTypeParser func(byte) byte
 
 func WithAddressTypeParser(atp AddressTypeParser) AddressOption {
-	return func(p *AddressParser) {
+	return func(p *option) {
 		p.typeParser = atp
 	}
 }
 
-// AddressParser is a utility for reading and writer addresses.
-type AddressParser struct {
-	addrTypeMap map[byte]net.AddressFamily
-	addrByteMap map[net.AddressFamily]byte
+type AddressSerializer interface {
+	ReadAddressPort(buffer *buf.Buffer, input io.Reader) (net.Address, net.Port, error)
+
+	WriteAddressPort(writer io.Writer, addr net.Address, port net.Port) error
+}
+
+const afInvalid = 255
+
+type option struct {
+	addrTypeMap [16]net.AddressFamily
+	addrByteMap [16]byte
 	portFirst   bool
 	typeParser  AddressTypeParser
 }
 
 // NewAddressParser creates a new AddressParser
-func NewAddressParser(options ...AddressOption) *AddressParser {
-	p := &AddressParser{
-		addrTypeMap: make(map[byte]net.AddressFamily, 8),
-		addrByteMap: make(map[net.AddressFamily]byte, 8),
+func NewAddressParser(options ...AddressOption) AddressSerializer {
+	var o option
+	for i := range o.addrByteMap {
+		o.addrByteMap[i] = afInvalid
+	}
+	for i := range o.addrTypeMap {
+		o.addrTypeMap[i] = net.AddressFamily(afInvalid)
 	}
 	for _, opt := range options {
-		opt(p)
+		opt(&o)
+	}
+
+	ap := &addressParser{
+		addrByteMap: o.addrByteMap,
+		addrTypeMap: o.addrTypeMap,
+	}
+
+	if o.typeParser != nil {
+		ap.typeParser = o.typeParser
+	}
+
+	if o.portFirst {
+		return portFirstAddressParser{ap: ap}
+	}
+
+	return portLastAddressParser{ap: ap}
+}
+
+type portFirstAddressParser struct {
+	ap *addressParser
+}
+
+func (p portFirstAddressParser) ReadAddressPort(buffer *buf.Buffer, input io.Reader) (net.Address, net.Port, error) {
+	if buffer == nil {
+		buffer = buf.New()
+		defer buffer.Release()
+	}
+
+	port, err := readPort(buffer, input)
+	if err != nil {
+		return nil, 0, err
+	}
+
+	addr, err := p.ap.readAddress(buffer, input)
+	if err != nil {
+		return nil, 0, err
+	}
+	return addr, port, nil
+}
+
+func (p portFirstAddressParser) WriteAddressPort(writer io.Writer, addr net.Address, port net.Port) error {
+	if err := writePort(writer, port); err != nil {
+		return err
+	}
+
+	return p.ap.writeAddress(writer, addr)
+}
+
+type portLastAddressParser struct {
+	ap *addressParser
+}
+
+func (p portLastAddressParser) ReadAddressPort(buffer *buf.Buffer, input io.Reader) (net.Address, net.Port, error) {
+	if buffer == nil {
+		buffer = buf.New()
+		defer buffer.Release()
+	}
+
+	addr, err := p.ap.readAddress(buffer, input)
+	if err != nil {
+		return nil, 0, err
+	}
+
+	port, err := readPort(buffer, input)
+	if err != nil {
+		return nil, 0, err
+	}
+
+	return addr, port, nil
+}
+
+func (p portLastAddressParser) WriteAddressPort(writer io.Writer, addr net.Address, port net.Port) error {
+	if err := p.ap.writeAddress(writer, addr); err != nil {
+		return err
 	}
-	return p
+
+	return writePort(writer, port)
 }
 
-func (p *AddressParser) readPort(b *buf.Buffer, reader io.Reader) (net.Port, error) {
+func readPort(b *buf.Buffer, reader io.Reader) (net.Port, error) {
 	if _, err := b.ReadFullFrom(reader, 2); err != nil {
 		return 0, err
 	}
 	return net.PortFromBytes(b.BytesFrom(-2)), nil
 }
 
+func writePort(writer io.Writer, port net.Port) error {
+	return common.Error2(serial.WriteUint16(writer, port.Value()))
+}
+
 func maybeIPPrefix(b byte) bool {
 	return b == '[' || (b >= '0' && b <= '9')
 }
@@ -73,7 +164,13 @@ func isValidDomain(d string) bool {
 	return true
 }
 
-func (p *AddressParser) readAddress(b *buf.Buffer, reader io.Reader) (net.Address, error) {
+type addressParser struct {
+	addrTypeMap [16]net.AddressFamily
+	addrByteMap [16]byte
+	typeParser  AddressTypeParser
+}
+
+func (p *addressParser) readAddress(b *buf.Buffer, reader io.Reader) (net.Address, error) {
 	if _, err := b.ReadFullFrom(reader, 1); err != nil {
 		return nil, err
 	}
@@ -83,8 +180,12 @@ func (p *AddressParser) readAddress(b *buf.Buffer, reader io.Reader) (net.Addres
 		addrType = p.typeParser(addrType)
 	}
 
-	addrFamily, valid := p.addrTypeMap[addrType]
-	if !valid {
+	if addrType >= 16 {
+		return nil, newError("unknown address type: ", addrType)
+	}
+
+	addrFamily := p.addrTypeMap[addrType]
+	if addrFamily == net.AddressFamily(afInvalid) {
 		return nil, newError("unknown address type: ", addrType)
 	}
 
@@ -123,93 +224,34 @@ func (p *AddressParser) readAddress(b *buf.Buffer, reader io.Reader) (net.Addres
 	}
 }
 
-// ReadAddressPort reads address and port from the given input.
-func (p *AddressParser) ReadAddressPort(buffer *buf.Buffer, input io.Reader) (net.Address, net.Port, error) {
-	if buffer == nil {
-		buffer = buf.New()
-		defer buffer.Release()
+func (p *addressParser) writeAddress(writer io.Writer, address net.Address) error {
+	tb := p.addrByteMap[address.Family()]
+	if tb == afInvalid {
+		return newError("unknown address family", address.Family())
 	}
 
-	var addr net.Address
-	var port net.Port
-
-	pTask := func() error {
-		lp, err := p.readPort(buffer, input)
-		if err != nil {
+	switch address.Family() {
+	case net.AddressFamilyIPv4, net.AddressFamilyIPv6:
+		if _, err := writer.Write([]byte{tb}); err != nil {
 			return err
 		}
-		port = lp
-		return nil
-	}
-
-	aTask := func() error {
-		a, err := p.readAddress(buffer, input)
-		if err != nil {
+		if _, err := writer.Write(address.IP()); err != nil {
 			return err
 		}
-		addr = a
-		return nil
-	}
-
-	var err error
-
-	if p.portFirst {
-		err = task.Run(task.Sequential(pTask, aTask))()
-	} else {
-		err = task.Run(task.Sequential(aTask, pTask))()
-	}
-
-	if err != nil {
-		return nil, 0, err
-	}
-
-	return addr, port, nil
-}
-
-func (p *AddressParser) writePort(writer io.Writer, port net.Port) error {
-	return common.Error2(serial.WriteUint16(writer, port.Value()))
-}
-
-func (p *AddressParser) writeAddress(writer io.Writer, address net.Address) error {
-	tb, valid := p.addrByteMap[address.Family()]
-	if !valid {
-		return newError("unknown address family", address.Family())
-	}
-
-	switch address.Family() {
-	case net.AddressFamilyIPv4, net.AddressFamilyIPv6:
-		return task.Run(task.Sequential(func() error {
-			return common.Error2(writer.Write([]byte{tb}))
-		}, func() error {
-			return common.Error2(writer.Write(address.IP()))
-		}))()
 	case net.AddressFamilyDomain:
 		domain := address.Domain()
 		if isDomainTooLong(domain) {
 			return newError("Super long domain is not supported: ", domain)
 		}
-		return task.Run(task.Sequential(func() error {
-			return common.Error2(writer.Write([]byte{tb, byte(len(domain))}))
-		}, func() error {
-			return common.Error2(writer.Write([]byte(domain)))
-		}))()
+		if _, err := writer.Write([]byte{tb, byte(len(domain))}); err != nil {
+			return err
+		}
+		if _, err := writer.Write([]byte(domain)); err != nil {
+			return err
+		}
 	default:
 		panic("Unknown family type.")
 	}
-}
-
-// WriteAddressPort writes address and port into the given writer.
-func (p *AddressParser) WriteAddressPort(writer io.Writer, addr net.Address, port net.Port) error {
-	pTask := func() error {
-		return p.writePort(writer, port)
-	}
-	aTask := func() error {
-		return p.writeAddress(writer, addr)
-	}
-
-	if p.portFirst {
-		return task.Run(task.Sequential(pTask, aTask))()
-	}
 
-	return task.Run(task.Sequential(aTask, pTask))()
+	return nil
 }

+ 105 - 0
common/protocol/address_test.go

@@ -36,6 +36,12 @@ func TestAddressReading(t *testing.T) {
 			Port:    net.Port(53),
 		},
 		{
+			Options: []AddressOption{AddressFamilyByte(0x01, net.AddressFamilyIPv4), PortThenAddress()},
+			Input:   []byte{0, 53, 1, 0, 0, 0, 0},
+			Address: net.IPAddress([]byte{0, 0, 0, 0}),
+			Port:    net.Port(53),
+		},
+		{
 			Options: []AddressOption{AddressFamilyByte(0x01, net.AddressFamilyIPv4)},
 			Input:   []byte{1, 0, 0, 0, 0},
 			Error:   true,
@@ -134,3 +140,102 @@ func TestAddressWriting(t *testing.T) {
 		}
 	}
 }
+
+func BenchmarkAddressReadingIPv4(b *testing.B) {
+	parser := NewAddressParser(AddressFamilyByte(0x01, net.AddressFamilyIPv4))
+	cache := buf.New()
+	defer cache.Release()
+
+	payload := buf.New()
+	defer payload.Release()
+
+	raw := []byte{1, 0, 0, 0, 0, 0, 53}
+	payload.Write(raw)
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_, _, err := parser.ReadAddressPort(cache, payload)
+		common.Must(err)
+		cache.Clear()
+		payload.Clear()
+		payload.Extend(int32(len(raw)))
+	}
+}
+
+func BenchmarkAddressReadingIPv6(b *testing.B) {
+	parser := NewAddressParser(AddressFamilyByte(0x04, net.AddressFamilyIPv6))
+	cache := buf.New()
+	defer cache.Release()
+
+	payload := buf.New()
+	defer payload.Release()
+
+	raw := []byte{4, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 0, 80}
+	payload.Write(raw)
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_, _, err := parser.ReadAddressPort(cache, payload)
+		common.Must(err)
+		cache.Clear()
+		payload.Clear()
+		payload.Extend(int32(len(raw)))
+	}
+}
+
+func BenchmarkAddressReadingDomain(b *testing.B) {
+	parser := NewAddressParser(AddressFamilyByte(0x03, net.AddressFamilyDomain))
+	cache := buf.New()
+	defer cache.Release()
+
+	payload := buf.New()
+	defer payload.Release()
+
+	raw := []byte{3, 9, 118, 50, 114, 97, 121, 46, 99, 111, 109, 0, 80}
+	payload.Write(raw)
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_, _, err := parser.ReadAddressPort(cache, payload)
+		common.Must(err)
+		cache.Clear()
+		payload.Clear()
+		payload.Extend(int32(len(raw)))
+	}
+}
+
+func BenchmarkAddressWritingIPv4(b *testing.B) {
+	parser := NewAddressParser(AddressFamilyByte(0x01, net.AddressFamilyIPv4))
+	writer := buf.New()
+	defer writer.Release()
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		common.Must(parser.WriteAddressPort(writer, net.LocalHostIP, net.Port(80)))
+		writer.Clear()
+	}
+}
+
+func BenchmarkAddressWritingIPv6(b *testing.B) {
+	parser := NewAddressParser(AddressFamilyByte(0x04, net.AddressFamilyIPv6))
+	writer := buf.New()
+	defer writer.Release()
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		common.Must(parser.WriteAddressPort(writer, net.LocalHostIPv6, net.Port(80)))
+		writer.Clear()
+	}
+}
+
+func BenchmarkAddressWritingDomain(b *testing.B) {
+	parser := NewAddressParser(AddressFamilyByte(0x02, net.AddressFamilyDomain))
+	writer := buf.New()
+	defer writer.Release()
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		common.Must(parser.WriteAddressPort(writer, net.DomainAddress("www.v2ray.com"), net.Port(80)))
+		writer.Clear()
+	}
+}