Selaa lähdekoodia

simplify udp handling

V2Ray 10 vuotta sitten
vanhempi
commit
76563cb7c7

+ 14 - 40
common/net/packet.go

@@ -6,54 +6,28 @@ type Packet interface {
 	MoreChunks() bool
 }
 
-func NewTCPPacket(dest Destination) *TCPPacket {
-	return &TCPPacket{
-		basePacket: basePacket{destination: dest},
+func NewPacket(dest Destination, firstChunk []byte, moreChunks bool) Packet {
+	return &packetImpl{
+		dest:     dest,
+		data:     firstChunk,
+		moreData: moreChunks,
 	}
 }
 
-func NewUDPPacket(dest Destination, data []byte, token uint16) *UDPPacket {
-	return &UDPPacket{
-		basePacket: basePacket{destination: dest},
-		data:       data,
-		token:      token,
-	}
-}
-
-type basePacket struct {
-	destination Destination
-}
-
-func (base basePacket) Destination() Destination {
-	return base.destination
-}
-
-type TCPPacket struct {
-	basePacket
-}
-
-func (packet *TCPPacket) Chunk() []byte {
-	return nil
-}
-
-func (packet *TCPPacket) MoreChunks() bool {
-	return true
-}
-
-type UDPPacket struct {
-	basePacket
-	data  []byte
-	token uint16
+type packetImpl struct {
+	dest     Destination
+	data     []byte
+	moreData bool
 }
 
-func (packet *UDPPacket) Token() uint16 {
-	return packet.token
+func (packet *packetImpl) Destination() Destination {
+	return packet.dest
 }
 
-func (packet *UDPPacket) Chunk() []byte {
+func (packet *packetImpl) Chunk() []byte {
 	return packet.data
 }
 
-func (packet *UDPPacket) MoreChunks() bool {
-	return false
+func (packet *packetImpl) MoreChunks() bool {
+	return packet.moreData
 }

+ 0 - 4
point.go

@@ -96,7 +96,3 @@ func (p *Point) DispatchToOutbound(packet v2net.Packet) InboundRay {
 	_ = och.Start(ray)
 	return ray
 }
-
-func (p *Point) DispatchToInbound(packet v2net.Packet) {
-	return
-}

+ 9 - 8
proxy/freedom/freedom.go

@@ -29,6 +29,12 @@ func (vconn *FreedomConnection) Start(ray core.OutboundRay) error {
 		return log.Error("Freedom: Failed to open connection: %s : %v", vconn.packet.Destination().String(), err)
 	}
 
+	input := ray.OutboundInput()
+	output := ray.OutboundOutput()
+	var readMutex, writeMutex sync.Mutex
+	readMutex.Lock()
+	writeMutex.Lock()
+
 	if chunk := vconn.packet.Chunk(); chunk != nil {
 		conn.Write(chunk)
 	}
@@ -37,16 +43,11 @@ func (vconn *FreedomConnection) Start(ray core.OutboundRay) error {
 		if ray != nil {
 			close(ray.OutboundOutput())
 		}
-		return nil
+		writeMutex.Unlock()
+	} else {
+		go dumpInput(conn, input, &writeMutex)
 	}
 
-	input := ray.OutboundInput()
-	output := ray.OutboundOutput()
-	var readMutex, writeMutex sync.Mutex
-	readMutex.Lock()
-	writeMutex.Lock()
-
-	go dumpInput(conn, input, &writeMutex)
 	go dumpOutput(conn, output, &readMutex)
 
 	go func() {

+ 1 - 1
proxy/socks/socks.go

@@ -168,7 +168,7 @@ func (server *SocksServer) HandleConnection(connection net.Conn) error {
 		dest = request.Destination()
 	}
 
-	ray := server.vPoint.DispatchToOutbound(v2net.NewTCPPacket(dest))
+	ray := server.vPoint.DispatchToOutbound(v2net.NewPacket(dest, nil, true))
 	input := ray.InboundInput()
 	output := ray.InboundOutput()
 	var readFinish, writeFinish sync.Mutex

+ 8 - 76
proxy/socks/udp.go

@@ -1,12 +1,8 @@
 package socks
 
 import (
-	"math"
-	"math/rand"
 	"net"
-	"sync"
 
-	"github.com/v2ray/v2ray-core/common/collect"
 	"github.com/v2ray/v2ray-core/common/log"
 	v2net "github.com/v2ray/v2ray-core/common/net"
 	"github.com/v2ray/v2ray-core/proxy/socks/protocol"
@@ -16,66 +12,7 @@ const (
 	bufferSize = 2 * 1024
 )
 
-type portMap struct {
-	access       sync.Mutex
-	data         map[uint16]*net.UDPAddr
-	removedPorts *collect.TimedQueue
-}
-
-func newPortMap() *portMap {
-	m := &portMap{
-		access:       sync.Mutex{},
-		data:         make(map[uint16]*net.UDPAddr),
-		removedPorts: collect.NewTimedQueue(1),
-	}
-	go m.removePorts(m.removedPorts.RemovedEntries())
-	return m
-}
-
-func (m *portMap) assignAddressToken(addr *net.UDPAddr) uint16 {
-	for {
-		token := uint16(rand.Intn(math.MaxUint16))
-		if _, used := m.data[token]; !used {
-			m.access.Lock()
-			if _, used = m.data[token]; !used {
-				m.data[token] = addr
-				m.access.Unlock()
-				return token
-			}
-			m.access.Unlock()
-		}
-	}
-}
-
-func (m *portMap) removePorts(removedPorts <-chan interface{}) {
-	for {
-		rawToken := <-removedPorts
-		m.access.Lock()
-		delete(m.data, rawToken.(uint16))
-		m.access.Unlock()
-	}
-}
-
-func (m *portMap) popPort(token uint16) *net.UDPAddr {
-	m.access.Lock()
-	defer m.access.Unlock()
-	addr, exists := m.data[token]
-	if !exists {
-		return nil
-	}
-	delete(m.data, token)
-	return addr
-}
-
-var (
-	ports *portMap
-
-	udpConn *net.UDPConn
-)
-
 func (server *SocksServer) ListenUDP(port uint16) error {
-	ports = newPortMap()
-
 	addr := &net.UDPAddr{
 		IP:   net.IP{0, 0, 0, 0},
 		Port: int(port),
@@ -88,7 +25,6 @@ func (server *SocksServer) ListenUDP(port uint16) error {
 	}
 
 	go server.AcceptPackets(conn)
-	udpConn = conn
 	return nil
 }
 
@@ -110,20 +46,16 @@ func (server *SocksServer) AcceptPackets(conn *net.UDPConn) error {
 			continue
 		}
 
-		token := ports.assignAddressToken(addr)
-
-		udpPacket := v2net.NewUDPPacket(request.Destination(), request.Data, token)
-		server.vPoint.DispatchToOutbound(udpPacket)
+		udpPacket := v2net.NewPacket(request.Destination(), request.Data, false)
+		go server.handlePacket(conn, udpPacket, addr)
 	}
 }
 
-func (server *SocksServer) Dispatch(packet v2net.Packet) {
-	if udpPacket, ok := packet.(*v2net.UDPPacket); ok {
-		token := udpPacket.Token()
-		addr := ports.popPort(token)
-		if udpConn != nil {
-			udpConn.WriteToUDP(udpPacket.Chunk(), addr)
-		}
+func (server *SocksServer) handlePacket(conn *net.UDPConn, packet v2net.Packet, clientAddr *net.UDPAddr) {
+	ray := server.vPoint.DispatchToOutbound(packet)
+	close(ray.InboundInput())
+
+	if data, ok := <-ray.InboundOutput(); ok {
+		conn.WriteToUDP(data, clientAddr)
 	}
-	// We don't expect TCP Packets here
 }

+ 1 - 1
proxy/vmess/vmess_test.go

@@ -69,7 +69,7 @@ func TestVMessInAndOut(t *testing.T) {
 	assert.Error(err).IsNil()
 
 	dest := v2net.NewTCPDestination(v2net.IPAddress([]byte{1, 2, 3, 4}, 80))
-	ich.Communicate(v2net.NewTCPPacket(dest))
+	ich.Communicate(v2net.NewPacket(dest, nil, true))
 	assert.Bytes([]byte(data2Send)).Equals(och.Data2Send.Bytes())
 	assert.Bytes(ich.DataReturned.Bytes()).Equals(och.Data2Return)
 }

+ 1 - 1
proxy/vmess/vmessin.go

@@ -72,7 +72,7 @@ func (handler *VMessInboundHandler) HandleConnection(connection net.Conn) error
 	}
 	log.Debug("VMessIn: Received request for %s", request.Address.String())
 
-	ray := handler.vPoint.DispatchToOutbound(v2net.NewTCPPacket(request.Destination()))
+	ray := handler.vPoint.DispatchToOutbound(v2net.NewPacket(request.Destination(), nil, true))
 	input := ray.InboundInput()
 	output := ray.InboundOutput()
 	var readFinish, writeFinish sync.Mutex

+ 14 - 17
proxy/vmess/vmessout.go

@@ -97,24 +97,13 @@ func startCommunicate(request *protocol.VMessRequest, dest v2net.Destination, ra
 
 	defer conn.Close()
 
-	if chunk := firstPacket.Chunk(); chunk != nil {
-		conn.Write(chunk)
-	}
-
-	if !firstPacket.MoreChunks() {
-		if ray != nil {
-			close(ray.OutboundOutput())
-		}
-		return nil
-	}
-
 	input := ray.OutboundInput()
 	output := ray.OutboundOutput()
 	var requestFinish, responseFinish sync.Mutex
 	requestFinish.Lock()
 	responseFinish.Lock()
 
-	go handleRequest(conn, request, input, &requestFinish)
+	go handleRequest(conn, request, firstPacket, input, &requestFinish)
 	go handleResponse(conn, request, output, &responseFinish)
 
 	requestFinish.Lock()
@@ -123,7 +112,7 @@ func startCommunicate(request *protocol.VMessRequest, dest v2net.Destination, ra
 	return nil
 }
 
-func handleRequest(conn *net.TCPConn, request *protocol.VMessRequest, input <-chan []byte, finish *sync.Mutex) {
+func handleRequest(conn *net.TCPConn, request *protocol.VMessRequest, firstPacket v2net.Packet, input <-chan []byte, finish *sync.Mutex) {
 	defer finish.Unlock()
 	encryptRequestWriter, err := v2io.NewAesEncryptWriter(request.RequestKey[:], request.RequestIV[:], conn)
 	if err != nil {
@@ -139,17 +128,25 @@ func handleRequest(conn *net.TCPConn, request *protocol.VMessRequest, input <-ch
 	}
 
 	// Send first packet of payload together with request, in favor of small requests.
-	payload, open := <-input
-	if open {
-		encryptRequestWriter.Crypt(payload)
-		buffer = append(buffer, payload...)
+	firstChunk := firstPacket.Chunk()
+	moreChunks := firstPacket.MoreChunks()
+
+	if firstChunk == nil && moreChunks {
+		firstChunk, moreChunks = <-input
+	}
+
+	if firstChunk != nil {
+		encryptRequestWriter.Crypt(firstChunk)
+		buffer = append(buffer, firstChunk...)
 
 		_, err = conn.Write(buffer)
 		if err != nil {
 			log.Error("VMessOut: Failed to write VMess request: %v", err)
 			return
 		}
+	}
 
+	if moreChunks {
 		v2net.ChanToWriter(encryptRequestWriter, input)
 	}
 	return