浏览代码

simplify udp hub

Darien Raymond 7 年之前
父节点
当前提交
7391b2439e
共有 3 个文件被更改,包括 47 次插入73 次删除
  1. 1 4
      app/proxyman/inbound/worker.go
  2. 1 1
      transport/internet/kcp/listener.go
  3. 45 68
      transport/internet/udp/hub.go

+ 1 - 4
app/proxyman/inbound/worker.go

@@ -244,10 +244,7 @@ func (w *udpWorker) removeConn(id connID) {
 func (w *udpWorker) Start() error {
 	w.activeConn = make(map[connID]*udpConn, 16)
 	w.done = signal.NewDone()
-	h, err := udp.ListenUDP(w.address, w.port, udp.ListenOption{
-		Callback:            w.callback,
-		ReceiveOriginalDest: w.recvOrigDest,
-	})
+	h, err := udp.ListenUDP(w.address, w.port, w.callback, udp.HubReceiveOriginalDestination(w.recvOrigDest))
 	if err != nil {
 		return err
 	}

+ 1 - 1
transport/internet/kcp/listener.go

@@ -61,7 +61,7 @@ func NewListener(ctx context.Context, address net.Address, port net.Port, addCon
 		l.tlsConfig = config.GetTLSConfig()
 	}
 
-	hub, err := udp.ListenUDP(address, port, udp.ListenOption{Callback: l.OnReceive, Concurrency: 2})
+	hub, err := udp.ListenUDP(address, port, l.OnReceive, udp.HubCapacity(64))
 	if err != nil {
 		return nil, err
 	}

+ 45 - 68
transport/internet/udp/hub.go

@@ -2,7 +2,6 @@ package udp
 
 import (
 	"v2ray.com/core/common/buf"
-	"v2ray.com/core/common/dice"
 	"v2ray.com/core/common/net"
 )
 
@@ -16,71 +15,28 @@ type Payload struct {
 // PayloadHandler is function to handle Payload.
 type PayloadHandler func(payload *buf.Buffer, source net.Destination, originalDest net.Destination)
 
-// PayloadQueue is a queue of Payload.
-type PayloadQueue struct {
-	queue    []chan Payload
-	callback PayloadHandler
-}
-
-// NewPayloadQueue returns a new PayloadQueue.
-func NewPayloadQueue(option ListenOption) *PayloadQueue {
-	queue := &PayloadQueue{
-		callback: option.Callback,
-		queue:    make([]chan Payload, option.Concurrency),
-	}
-	for i := range queue.queue {
-		queue.queue[i] = make(chan Payload, 64)
-		go queue.Dequeue(queue.queue[i])
-	}
-	return queue
-}
+type HubOption func(h *Hub)
 
-// Enqueue adds the payload to the end of this queue.
-func (q *PayloadQueue) Enqueue(payload Payload) {
-	size := len(q.queue)
-	idx := 0
-	if size > 1 {
-		idx = dice.Roll(size)
-	}
-	for i := 0; i < size; i++ {
-		select {
-		case q.queue[idx%size] <- payload:
-			return
-		default:
-			idx++
-		}
+func HubCapacity(cap int) HubOption {
+	return func(h *Hub) {
+		h.capacity = cap
 	}
 }
 
-func (q *PayloadQueue) Dequeue(queue <-chan Payload) {
-	for payload := range queue {
-		q.callback(payload.payload, payload.source, payload.originalDest)
+func HubReceiveOriginalDestination(r bool) HubOption {
+	return func(h *Hub) {
+		h.recvOrigDest = r
 	}
 }
 
-func (q *PayloadQueue) Close() error {
-	for _, queue := range q.queue {
-		close(queue)
-	}
-	return nil
-}
-
-type ListenOption struct {
-	Callback            PayloadHandler
-	ReceiveOriginalDest bool
-	Concurrency         int
-}
-
 type Hub struct {
-	conn   *net.UDPConn
-	queue  *PayloadQueue
-	option ListenOption
+	conn         *net.UDPConn
+	callback     PayloadHandler
+	capacity     int
+	recvOrigDest bool
 }
 
-func ListenUDP(address net.Address, port net.Port, option ListenOption) (*Hub, error) {
-	if option.Concurrency < 1 {
-		option.Concurrency = 1
-	}
+func ListenUDP(address net.Address, port net.Port, callback PayloadHandler, options ...HubOption) (*Hub, error) {
 	udpConn, err := net.ListenUDP("udp", &net.UDPAddr{
 		IP:   address.IP(),
 		Port: int(port),
@@ -89,7 +45,17 @@ func ListenUDP(address net.Address, port net.Port, option ListenOption) (*Hub, e
 		return nil, err
 	}
 	newError("listening UDP on ", address, ":", port).WriteToLog()
-	if option.ReceiveOriginalDest {
+	hub := &Hub{
+		conn:         udpConn,
+		capacity:     16,
+		callback:     callback,
+		recvOrigDest: false,
+	}
+	for _, opt := range options {
+		opt(hub)
+	}
+
+	if hub.recvOrigDest {
 		rawConn, err := udpConn.SyscallConn()
 		if err != nil {
 			return nil, newError("failed to get fd").Base(err)
@@ -103,12 +69,11 @@ func ListenUDP(address net.Address, port net.Port, option ListenOption) (*Hub, e
 			return nil, newError("failed to control socket").Base(err)
 		}
 	}
-	hub := &Hub{
-		conn:   udpConn,
-		queue:  NewPayloadQueue(option),
-		option: option,
-	}
-	go hub.start()
+
+	c := make(chan *Payload, hub.capacity)
+
+	go hub.start(c)
+	go hub.process(c)
 	return hub, nil
 }
 
@@ -125,7 +90,15 @@ func (h *Hub) WriteTo(payload []byte, dest net.Destination) (int, error) {
 	})
 }
 
-func (h *Hub) start() {
+func (h *Hub) process(c <-chan *Payload) {
+	for p := range c {
+		h.callback(p.payload, p.source, p.originalDest)
+	}
+}
+
+func (h *Hub) start(c chan<- *Payload) {
+	defer close(c)
+
 	oobBytes := make([]byte, 256)
 
 	for {
@@ -145,11 +118,11 @@ func (h *Hub) start() {
 			break
 		}
 
-		payload := Payload{
+		payload := &Payload{
 			payload: buffer,
 		}
 		payload.source = net.UDPDestination(net.IPAddress(addr.IP), net.Port(addr.Port))
-		if h.option.ReceiveOriginalDest && noob > 0 {
+		if h.recvOrigDest && noob > 0 {
 			payload.originalDest = RetrieveOriginalDest(oobBytes[:noob])
 			if payload.originalDest.IsValid() {
 				newError("UDP original destination: ", payload.originalDest).AtDebug().WriteToLog()
@@ -157,9 +130,13 @@ func (h *Hub) start() {
 				newError("failed to read UDP original destination").WriteToLog()
 			}
 		}
-		h.queue.Enqueue(payload)
+
+		select {
+		case c <- payload:
+		default:
+		}
+
 	}
-	h.queue.Close()
 }
 
 // Addr implements net.Listener.