Browse Source

rewrite connection interface

Shelikhoo 3 years ago
parent
commit
cdfef7e66b

+ 19 - 10
common/net/packetaddr/connection_adaptor.go

@@ -69,16 +69,17 @@ func (c *packetConnectionAdaptor) ReadFrom(p []byte) (n int, addr gonet.Addr, er
 		}
 		}
 	}
 	}
 	c.readerBuffer, n = buf.SplitFirstBytes(c.readerBuffer, p)
 	c.readerBuffer, n = buf.SplitFirstBytes(c.readerBuffer, p)
-	var w []byte
-	w, addr = ExtractAddressFromPacket(p[:n])
-	n = copy(p, w)
+	var w *buf.Buffer
+	w, addr, err = ExtractAddressFromPacket(buf.FromBytes(p[:n]))
+	n = copy(p, w.Bytes())
+	w.Release()
 	return
 	return
 }
 }
 
 
 func (c *packetConnectionAdaptor) WriteTo(p []byte, addr gonet.Addr) (n int, err error) {
 func (c *packetConnectionAdaptor) WriteTo(p []byte, addr gonet.Addr) (n int, err error) {
 	payloadLen := len(p)
 	payloadLen := len(p)
-	p = AttachAddressToPacket(p, addr)
-	buffer := buf.FromBytes(p)
+	var buffer *buf.Buffer
+	buffer, err = AttachAddressToPacket(buf.FromBytes(p), addr)
 	mb := buf.MultiBuffer{buffer}
 	mb := buf.MultiBuffer{buffer}
 	err = c.link.Writer.WriteMultiBuffer(mb)
 	err = c.link.Writer.WriteMultiBuffer(mb)
 	if err != nil {
 	if err != nil {
@@ -134,18 +135,26 @@ func (pc *packetConnWrapper) Read(p []byte) (n int, err error) {
 	if err != nil {
 	if err != nil {
 		return 0, err
 		return 0, err
 	}
 	}
-	result := AttachAddressToPacket(recbuf.Bytes()[0:n], addr)
-	n = copy(p, result)
-	recbuf.Release()
+	recbuf.Resize(0, int32(n))
+	result, err := AttachAddressToPacket(&recbuf, addr)
+	if err != nil {
+		return 0, err
+	}
+	n = copy(p, result.Bytes())
+	result.Release()
 	return n, nil
 	return n, nil
 }
 }
 
 
 func (pc *packetConnWrapper) Write(p []byte) (n int, err error) {
 func (pc *packetConnWrapper) Write(p []byte) (n int, err error) {
-	data, addr := ExtractAddressFromPacket(p)
-	_, err = pc.PacketConn.WriteTo(data, addr)
+	data, addr, err := ExtractAddressFromPacket(buf.FromBytes(p))
+	if err != nil {
+		return 0, err
+	}
+	_, err = pc.PacketConn.WriteTo(data.Bytes(), addr)
 	if err != nil {
 	if err != nil {
 		return 0, err
 		return 0, err
 	}
 	}
+	data.Release()
 	return len(p), nil
 	return len(p), nil
 }
 }
 
 

+ 22 - 14
common/net/packetaddr/packetaddr.go

@@ -13,35 +13,43 @@ var addrParser = protocol.NewAddressParser(
 	protocol.AddressFamilyByte(0x02, net.AddressFamilyIPv6),
 	protocol.AddressFamilyByte(0x02, net.AddressFamilyIPv6),
 )
 )
 
 
-func AttachAddressToPacket(data []byte, address gonet.Addr) []byte {
-	packetBuf := buf.StackNew()
+// AttachAddressToPacket
+// relinquish ownership of data
+// gain ownership of the returning value
+func AttachAddressToPacket(data *buf.Buffer, address gonet.Addr) (*buf.Buffer, error) {
+	packetBuf := buf.New()
 	udpaddr := address.(*gonet.UDPAddr)
 	udpaddr := address.(*gonet.UDPAddr)
 	port, err := net.PortFromInt(uint32(udpaddr.Port))
 	port, err := net.PortFromInt(uint32(udpaddr.Port))
 	if err != nil {
 	if err != nil {
-		panic(err)
+		return nil, err
+	}
+	err = addrParser.WriteAddressPort(packetBuf, net.IPAddress(udpaddr.IP), port)
+	if err != nil {
+		return nil, err
 	}
 	}
-	err = addrParser.WriteAddressPort(&packetBuf, net.IPAddress(udpaddr.IP), port)
+	_, err = packetBuf.Write(data.Bytes())
 	if err != nil {
 	if err != nil {
-		panic(err)
+		return nil, err
 	}
 	}
-	//Incorrect buffer reuse
-	data = append(packetBuf.Bytes(), data...)
-	//packetBuf.Release()
-	return data
+	data.Release()
+	return packetBuf, nil
 }
 }
 
 
-func ExtractAddressFromPacket(data []byte) ([]byte, gonet.Addr) {
+// ExtractAddressFromPacket
+// relinquish ownership of data
+// gain ownership of the returning value
+func ExtractAddressFromPacket(data *buf.Buffer) (*buf.Buffer, gonet.Addr, error) {
 	packetBuf := buf.StackNew()
 	packetBuf := buf.StackNew()
-	address, port, err := addrParser.ReadAddressPort(&packetBuf, bytes.NewReader(data))
+	address, port, err := addrParser.ReadAddressPort(&packetBuf, bytes.NewReader(data.Bytes()))
 	if err != nil {
 	if err != nil {
-		panic(err)
+		return nil, nil, err
 	}
 	}
 	var addr = &gonet.UDPAddr{
 	var addr = &gonet.UDPAddr{
 		IP:   address.IP(),
 		IP:   address.IP(),
 		Port: int(port.Value()),
 		Port: int(port.Value()),
 		Zone: "",
 		Zone: "",
 	}
 	}
-	payload := data[int(packetBuf.Len()):]
+	data.Advance(packetBuf.Len())
 	packetBuf.Release()
 	packetBuf.Release()
-	return payload, addr
+	return data, addr, nil
 }
 }

+ 11 - 6
common/net/packetaddr/packetaddr_test.go

@@ -2,6 +2,7 @@ package packetaddr
 
 
 import (
 import (
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
+	"github.com/v2fly/v2ray-core/v4/common/buf"
 	sysnet "net"
 	sysnet "net"
 	"testing"
 	"testing"
 )
 )
@@ -12,11 +13,13 @@ func TestPacketEncodingIPv4(t *testing.T) {
 		Port: 1234,
 		Port: 1234,
 	}
 	}
 	var packetData [256]byte
 	var packetData [256]byte
-	wrapped := AttachAddressToPacket(packetData[:], packetAddress)
+	wrapped, err := AttachAddressToPacket(buf.FromBytes(packetData[:]), packetAddress)
+	assert.NoError(t, err)
 
 
-	packetPayload, decodedAddress := ExtractAddressFromPacket(wrapped)
+	packetPayload, decodedAddress, err := ExtractAddressFromPacket(wrapped)
+	assert.NoError(t, err)
 
 
-	assert.Equal(t, packetPayload, packetData[:])
+	assert.Equal(t, packetPayload.Bytes(), packetData[:])
 	assert.Equal(t, packetAddress, decodedAddress)
 	assert.Equal(t, packetAddress, decodedAddress)
 }
 }
 
 
@@ -26,10 +29,12 @@ func TestPacketEncodingIPv6(t *testing.T) {
 		Port: 1234,
 		Port: 1234,
 	}
 	}
 	var packetData [256]byte
 	var packetData [256]byte
-	wrapped := AttachAddressToPacket(packetData[:], packetAddress)
+	wrapped, err := AttachAddressToPacket(buf.FromBytes(packetData[:]), packetAddress)
+	assert.NoError(t, err)
 
 
-	packetPayload, decodedAddress := ExtractAddressFromPacket(wrapped)
+	packetPayload, decodedAddress, err := ExtractAddressFromPacket(wrapped)
+	assert.NoError(t, err)
 
 
-	assert.Equal(t, packetPayload, packetData[:])
+	assert.Equal(t, packetPayload.Bytes(), packetData[:])
 	assert.Equal(t, packetAddress, decodedAddress)
 	assert.Equal(t, packetAddress, decodedAddress)
 }
 }