Forráskód Böngészése

refine udp hub api

Darien Raymond 7 éve
szülő
commit
053fc38d38

+ 9 - 1
app/proxyman/inbound/worker.go

@@ -317,9 +317,16 @@ func (w *udpWorker) removeConn(id connID) {
 	w.Unlock()
 }
 
+func (w *udpWorker) handlePackets() {
+	receive := w.hub.Receive()
+	for payload := range receive {
+		w.callback(payload.Content, payload.Source, payload.OriginalDestination)
+	}
+}
+
 func (w *udpWorker) Start() error {
 	w.activeConn = make(map[connID]*udpConn, 16)
-	h, err := udp.ListenUDP(w.address, w.port, w.callback, udp.HubReceiveOriginalDestination(w.recvOrigDest), udp.HubCapacity(256))
+	h, err := udp.ListenUDP(w.address, w.port, udp.HubReceiveOriginalDestination(w.recvOrigDest), udp.HubCapacity(256))
 	if err != nil {
 		return err
 	}
@@ -352,6 +359,7 @@ func (w *udpWorker) Start() error {
 		return err
 	}
 	w.hub = h
+	go w.handlePackets()
 	return nil
 }
 

+ 16 - 9
transport/internet/kcp/listener.go

@@ -61,7 +61,7 @@ func NewListener(ctx context.Context, address net.Address, port net.Port, addCon
 		l.tlsConfig = config.GetTLSConfig()
 	}
 
-	hub, err := udp.ListenUDP(address, port, l.OnReceive, udp.HubCapacity(1024))
+	hub, err := udp.ListenUDP(address, port, udp.HubCapacity(1024))
 	if err != nil {
 		return nil, err
 	}
@@ -69,10 +69,20 @@ func NewListener(ctx context.Context, address net.Address, port net.Port, addCon
 	l.hub = hub
 	l.Unlock()
 	newError("listening on ", address, ":", port).WriteToLog()
+
+	go l.handlePackets()
+
 	return l, nil
 }
 
-func (l *Listener) OnReceive(payload *buf.Buffer, src net.Destination, originalDest net.Destination) {
+func (l *Listener) handlePackets() {
+	receive := l.hub.Receive()
+	for payload := range receive {
+		l.OnReceive(payload.Content, payload.Source)
+	}
+}
+
+func (l *Listener) OnReceive(payload *buf.Buffer, src net.Destination) {
 	segments := l.reader.Read(payload.Bytes())
 	payload.Release()
 
@@ -81,13 +91,6 @@ func (l *Listener) OnReceive(payload *buf.Buffer, src net.Destination, originalD
 		return
 	}
 
-	l.Lock()
-	defer l.Unlock()
-
-	if l.hub == nil {
-		return
-	}
-
 	conv := segments[0].Conversation()
 	cmd := segments[0].Command()
 
@@ -96,6 +99,10 @@ func (l *Listener) OnReceive(payload *buf.Buffer, src net.Destination, originalD
 		Port:   src.Port,
 		Conv:   conv,
 	}
+
+	l.Lock()
+	defer l.Unlock()
+
 	conn, found := l.sessions[id]
 
 	if !found {

+ 19 - 25
transport/internet/udp/hub.go

@@ -7,14 +7,11 @@ import (
 
 // Payload represents a single UDP payload.
 type Payload struct {
-	payload      *buf.Buffer
-	source       net.Destination
-	originalDest net.Destination
+	Content             *buf.Buffer
+	Source              net.Destination
+	OriginalDestination net.Destination
 }
 
-// PayloadHandler is function to handle Payload.
-type PayloadHandler func(payload *buf.Buffer, source net.Destination, originalDest net.Destination)
-
 type HubOption func(h *Hub)
 
 func HubCapacity(cap int) HubOption {
@@ -31,12 +28,12 @@ func HubReceiveOriginalDestination(r bool) HubOption {
 
 type Hub struct {
 	conn         *net.UDPConn
-	callback     PayloadHandler
+	cache        chan *Payload
 	capacity     int
 	recvOrigDest bool
 }
 
-func ListenUDP(address net.Address, port net.Port, callback PayloadHandler, options ...HubOption) (*Hub, error) {
+func ListenUDP(address net.Address, port net.Port, options ...HubOption) (*Hub, error) {
 	udpConn, err := net.ListenUDP("udp", &net.UDPAddr{
 		IP:   address.IP(),
 		Port: int(port),
@@ -48,13 +45,14 @@ func ListenUDP(address net.Address, port net.Port, callback PayloadHandler, opti
 	hub := &Hub{
 		conn:         udpConn,
 		capacity:     256,
-		callback:     callback,
 		recvOrigDest: false,
 	}
 	for _, opt := range options {
 		opt(hub)
 	}
 
+	hub.cache = make(chan *Payload, hub.capacity)
+
 	if hub.recvOrigDest {
 		rawConn, err := udpConn.SyscallConn()
 		if err != nil {
@@ -70,10 +68,7 @@ func ListenUDP(address net.Address, port net.Port, callback PayloadHandler, opti
 		}
 	}
 
-	c := make(chan *Payload, hub.capacity)
-
-	go hub.start(c)
-	go hub.process(c)
+	go hub.start()
 	return hub, nil
 }
 
@@ -90,13 +85,8 @@ func (h *Hub) WriteTo(payload []byte, dest net.Destination) (int, error) {
 	})
 }
 
-func (h *Hub) process(c <-chan *Payload) {
-	for p := range c {
-		h.callback(p.payload, p.source, p.originalDest)
-	}
-}
-
-func (h *Hub) start(c chan<- *Payload) {
+func (h *Hub) start() {
+	c := h.cache
 	defer close(c)
 
 	oobBytes := make([]byte, 256)
@@ -119,13 +109,13 @@ func (h *Hub) start(c chan<- *Payload) {
 		}
 
 		payload := &Payload{
-			payload: buffer,
+			Content: buffer,
+			Source:  net.UDPDestination(net.IPAddress(addr.IP), net.Port(addr.Port)),
 		}
-		payload.source = net.UDPDestination(net.IPAddress(addr.IP), net.Port(addr.Port))
 		if h.recvOrigDest && noob > 0 {
-			payload.originalDest = RetrieveOriginalDest(oobBytes[:noob])
-			if payload.originalDest.IsValid() {
-				newError("UDP original destination: ", payload.originalDest).AtDebug().WriteToLog()
+			payload.OriginalDestination = RetrieveOriginalDest(oobBytes[:noob])
+			if payload.OriginalDestination.IsValid() {
+				newError("UDP original destination: ", payload.OriginalDestination).AtDebug().WriteToLog()
 			} else {
 				newError("failed to read UDP original destination").WriteToLog()
 			}
@@ -143,3 +133,7 @@ func (h *Hub) start(c chan<- *Payload) {
 func (h *Hub) Addr() net.Addr {
 	return h.conn.LocalAddr()
 }
+
+func (h *Hub) Receive() <-chan *Payload {
+	return h.cache
+}