Pārlūkot izejas kodu

fix udp issue in socks proxy

Darien Raymond 9 gadi atpakaļ
vecāks
revīzija
0d77139c24

+ 1 - 1
proxy/freedom/freedom.go

@@ -1,7 +1,7 @@
 package freedom
 
 import (
-  "io"
+	"io"
 	"net"
 	"sync"
 

+ 1 - 0
proxy/socks/socks.go

@@ -32,6 +32,7 @@ type SocksServer struct {
 	tcpListener      *hub.TCPHub
 	udpConn          *net.UDPConn
 	udpAddress       v2net.Destination
+	udpServer        *hub.UDPServer
 	listeningPort    v2net.Port
 }
 

+ 31 - 31
proxy/socks/udp.go

@@ -7,6 +7,7 @@ import (
 	"github.com/v2ray/v2ray-core/common/log"
 	v2net "github.com/v2ray/v2ray-core/common/net"
 	"github.com/v2ray/v2ray-core/proxy/socks/protocol"
+	"github.com/v2ray/v2ray-core/transport/hub"
 )
 
 func (this *SocksServer) ListenUDP(port v2net.Port) error {
@@ -23,6 +24,7 @@ func (this *SocksServer) ListenUDP(port v2net.Port) error {
 	this.udpMutex.Lock()
 	this.udpAddress = v2net.UDPDestination(this.config.Address, port)
 	this.udpConn = conn
+	this.udpServer = hub.NewUDPServer(this.packetDispatcher)
 	this.udpMutex.Unlock()
 
 	go this.AcceptPackets()
@@ -63,38 +65,36 @@ func (this *SocksServer) AcceptPackets() error {
 
 		udpPacket := v2net.NewPacket(request.Destination(), request.Data, false)
 		log.Info("Socks: Send packet to ", udpPacket.Destination(), " with ", request.Data.Len(), " bytes")
-		go this.handlePacket(udpPacket, addr, request.Address, request.Port)
-	}
-	return nil
-}
+		this.udpServer.Dispatch(
+			v2net.UDPDestination(v2net.IPAddress(addr.IP), v2net.Port(addr.Port)), udpPacket,
+			func(packet v2net.Packet) {
+				response := &protocol.Socks5UDPRequest{
+					Fragment: 0,
+					Address:  udpPacket.Destination().Address(),
+					Port:     udpPacket.Destination().Port(),
+					Data:     packet.Chunk(),
+				}
+				log.Info("Socks: Writing back UDP response with ", response.Data.Len(), " bytes to ", packet.Destination())
 
-func (this *SocksServer) handlePacket(packet v2net.Packet, clientAddr *net.UDPAddr, targetAddr v2net.Address, port v2net.Port) {
-	ray := this.packetDispatcher.DispatchToOutbound(packet)
-	close(ray.InboundInput())
+				udpMessage := alloc.NewSmallBuffer().Clear()
+				response.Write(udpMessage)
 
-	for data := range ray.InboundOutput() {
-		response := &protocol.Socks5UDPRequest{
-			Fragment: 0,
-			Address:  targetAddr,
-			Port:     port,
-			Data:     data,
-		}
-		log.Info("Socks: Writing back UDP response with ", data.Len(), " bytes from ", targetAddr, " to ", clientAddr)
-
-		udpMessage := alloc.NewSmallBuffer().Clear()
-		response.Write(udpMessage)
-
-		this.udpMutex.RLock()
-		if !this.accepting {
-			this.udpMutex.RUnlock()
-			return
-		}
-		nBytes, err := this.udpConn.WriteToUDP(udpMessage.Value, clientAddr)
-		this.udpMutex.RUnlock()
-		udpMessage.Release()
-		response.Data.Release()
-		if err != nil {
-			log.Error("Socks: failed to write UDP message (", nBytes, " bytes) to ", clientAddr, ": ", err)
-		}
+				this.udpMutex.RLock()
+				if !this.accepting {
+					this.udpMutex.RUnlock()
+					return
+				}
+				nBytes, err := this.udpConn.WriteToUDP(udpMessage.Value, &net.UDPAddr{
+					IP:   packet.Destination().Address().IP(),
+					Port: int(packet.Destination().Port()),
+				})
+				this.udpMutex.RUnlock()
+				udpMessage.Release()
+				response.Data.Release()
+				if err != nil {
+					log.Error("Socks: failed to write UDP message (", nBytes, " bytes) to ", packet.Destination(), ": ", err)
+				}
+			})
 	}
+	return nil
 }

+ 12 - 10
testing/scenarios/socks_end_test.go

@@ -189,18 +189,20 @@ func TestUDPAssociate(t *testing.T) {
 	})
 	assert.Error(err).IsNil()
 
-	udpPayload := "UDP request to udp server."
-	udpRequest := socks5UDPRequest(v2net.UDPDestination(v2net.IPAddress([]byte{127, 0, 0, 1}), targetPort), []byte(udpPayload))
+	for i := 0; i < 100; i++ {
+		udpPayload := "UDP request to udp server."
+		udpRequest := socks5UDPRequest(v2net.UDPDestination(v2net.IPAddress([]byte{127, 0, 0, 1}), targetPort), []byte(udpPayload))
 
-	nBytes, err = udpConn.Write(udpRequest)
-	assert.Int(nBytes).Equals(len(udpRequest))
-	assert.Error(err).IsNil()
+		nBytes, err = udpConn.Write(udpRequest)
+		assert.Int(nBytes).Equals(len(udpRequest))
+		assert.Error(err).IsNil()
 
-	udpResponse := make([]byte, 1024)
-	nBytes, err = udpConn.Read(udpResponse)
-	assert.Error(err).IsNil()
-	assert.Bytes(udpResponse[:nBytes]).Equals(
-		socks5UDPRequest(v2net.UDPDestination(v2net.IPAddress([]byte{127, 0, 0, 1}), targetPort), []byte("Processed: UDP request to udp server.")))
+		udpResponse := make([]byte, 1024)
+		nBytes, err = udpConn.Read(udpResponse)
+		assert.Error(err).IsNil()
+		assert.Bytes(udpResponse[:nBytes]).Equals(
+			socks5UDPRequest(v2net.UDPDestination(v2net.IPAddress([]byte{127, 0, 0, 1}), targetPort), []byte("Processed: UDP request to udp server.")))
+	}
 
 	udpConn.Close()
 	conn.Close()

+ 64 - 0
transport/hub/udp_server.go

@@ -0,0 +1,64 @@
+package hub
+
+import (
+	"sync"
+
+	"github.com/v2ray/v2ray-core/app/dispatcher"
+	v2net "github.com/v2ray/v2ray-core/common/net"
+	"github.com/v2ray/v2ray-core/transport/ray"
+)
+
+type UDPResponseCallback func(packet v2net.Packet)
+
+type connEntry struct {
+	inboundRay ray.InboundRay
+	callback   UDPResponseCallback
+}
+
+type UDPServer struct {
+	sync.RWMutex
+	conns            map[string]*connEntry
+	packetDispatcher dispatcher.PacketDispatcher
+}
+
+func NewUDPServer(packetDispatcher dispatcher.PacketDispatcher) *UDPServer {
+	return &UDPServer{
+		conns:            make(map[string]*connEntry),
+		packetDispatcher: packetDispatcher,
+	}
+}
+
+func (this *UDPServer) locateExistingAndDispatch(dest string, packet v2net.Packet) bool {
+	this.RLock()
+	defer this.RUnlock()
+	if entry, found := this.conns[dest]; found {
+		entry.inboundRay.InboundInput() <- packet.Chunk()
+		return true
+	}
+	return false
+}
+
+func (this *UDPServer) Dispatch(source v2net.Destination, packet v2net.Packet, callback UDPResponseCallback) {
+	destString := source.String() + "-" + packet.Destination().String()
+	if this.locateExistingAndDispatch(destString, packet) {
+		return
+	}
+
+	this.Lock()
+	inboundRay := this.packetDispatcher.DispatchToOutbound(v2net.NewPacket(packet.Destination(), packet.Chunk(), true))
+	this.conns[destString] = &connEntry{
+		inboundRay: inboundRay,
+		callback:   callback,
+	}
+	this.Unlock()
+	go this.handleConnection(destString, inboundRay, source, callback)
+}
+
+func (this *UDPServer) handleConnection(destString string, inboundRay ray.InboundRay, source v2net.Destination, callback UDPResponseCallback) {
+	for buffer := range inboundRay.InboundOutput() {
+		callback(v2net.NewPacket(source, buffer, false))
+	}
+	this.Lock()
+	delete(this.conns, destString)
+	this.Unlock()
+}