Browse Source

refactor UDP dispatcher to support fullcone dispatcher

Shelikhoo 4 years ago
parent
commit
ac65036808

+ 1 - 1
app/dns/nameserver_udp.go

@@ -56,7 +56,7 @@ func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher
 		Interval: time.Minute,
 		Execute:  s.Cleanup,
 	}
-	s.udpServer = udp.NewDispatcher(dispatcher, s.HandleResponse)
+	s.udpServer = udp.NewSplitDispatcher(dispatcher, s.HandleResponse)
 	newError("DNS: created UDP client initialized for ", address.NetAddr()).AtInfo().WriteToLog()
 	return s
 }

+ 1 - 1
proxy/shadowsocks/server.go

@@ -70,7 +70,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
 }
 
 func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection, dispatcher routing.Dispatcher) error {
-	udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
+	udpServer := udp.NewSplitDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
 		request := protocol.RequestHeaderFromContext(ctx)
 		if request == nil {
 			return

+ 1 - 1
proxy/socks/server.go

@@ -186,7 +186,7 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
 }
 
 func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection, dispatcher routing.Dispatcher) error {
-	udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
+	udpServer := udp.NewSplitDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
 		payload := packet.Payload
 		newError("writing back UDP response with ", payload.Len(), " bytes").AtDebug().WriteToLog(session.ExportIDToError(ctx))
 

+ 1 - 1
proxy/trojan/server.go

@@ -204,7 +204,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
 }
 
 func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error {
-	udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
+	udpServer := udp.NewSplitDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
 		if err := clientWriter.WriteMultiBufferWithMetadata(buf.MultiBuffer{packet.Payload}, packet.Source); err != nil {
 			newError("failed to write response").Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx))
 		}

+ 2 - 188
transport/internet/udp/dispatcher.go

@@ -2,196 +2,10 @@ package udp
 
 import (
 	"context"
-	"io"
-	"sync"
-	"time"
-
-	"github.com/v2fly/v2ray-core/v5/common"
 	"github.com/v2fly/v2ray-core/v5/common/buf"
 	"github.com/v2fly/v2ray-core/v5/common/net"
-	"github.com/v2fly/v2ray-core/v5/common/protocol/udp"
-	"github.com/v2fly/v2ray-core/v5/common/session"
-	"github.com/v2fly/v2ray-core/v5/common/signal"
-	"github.com/v2fly/v2ray-core/v5/common/signal/done"
-	"github.com/v2fly/v2ray-core/v5/features/routing"
-	"github.com/v2fly/v2ray-core/v5/transport"
 )
 
-type ResponseCallback func(ctx context.Context, packet *udp.Packet)
-
-type connEntry struct {
-	link   *transport.Link
-	timer  signal.ActivityUpdater
-	cancel context.CancelFunc
-}
-
-type Dispatcher struct {
-	sync.RWMutex
-	conns      map[net.Destination]*connEntry
-	dispatcher routing.Dispatcher
-	callback   ResponseCallback
-}
-
-func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Dispatcher {
-	return &Dispatcher{
-		conns:      make(map[net.Destination]*connEntry),
-		dispatcher: dispatcher,
-		callback:   callback,
-	}
-}
-
-func (v *Dispatcher) RemoveRay(dest net.Destination) {
-	v.Lock()
-	defer v.Unlock()
-	if conn, found := v.conns[dest]; found {
-		common.Close(conn.link.Reader)
-		common.Close(conn.link.Writer)
-		delete(v.conns, dest)
-	}
-}
-
-func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) *connEntry {
-	v.Lock()
-	defer v.Unlock()
-
-	if entry, found := v.conns[dest]; found {
-		return entry
-	}
-
-	newError("establishing new connection for ", dest).WriteToLog()
-
-	ctx, cancel := context.WithCancel(ctx)
-	removeRay := func() {
-		cancel()
-		v.RemoveRay(dest)
-	}
-	timer := signal.CancelAfterInactivity(ctx, removeRay, time.Second*4)
-	link, _ := v.dispatcher.Dispatch(ctx, dest)
-	entry := &connEntry{
-		link:   link,
-		timer:  timer,
-		cancel: removeRay,
-	}
-	v.conns[dest] = entry
-	go handleInput(ctx, entry, dest, v.callback)
-	return entry
-}
-
-func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, payload *buf.Buffer) {
-	// TODO: Add user to destString
-	newError("dispatch request to: ", destination).AtDebug().WriteToLog(session.ExportIDToError(ctx))
-
-	conn := v.getInboundRay(ctx, destination)
-	outputStream := conn.link.Writer
-	if outputStream != nil {
-		if err := outputStream.WriteMultiBuffer(buf.MultiBuffer{payload}); err != nil {
-			newError("failed to write first UDP payload").Base(err).WriteToLog(session.ExportIDToError(ctx))
-			conn.cancel()
-			return
-		}
-	}
-}
-
-func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback) {
-	defer conn.cancel()
-
-	input := conn.link.Reader
-	timer := conn.timer
-
-	for {
-		select {
-		case <-ctx.Done():
-			return
-		default:
-		}
-
-		mb, err := input.ReadMultiBuffer()
-		if err != nil {
-			newError("failed to handle UDP input").Base(err).WriteToLog(session.ExportIDToError(ctx))
-			return
-		}
-		timer.Update()
-		for _, b := range mb {
-			callback(ctx, &udp.Packet{
-				Payload: b,
-				Source:  dest,
-			})
-		}
-	}
-}
-
-type dispatcherConn struct {
-	dispatcher *Dispatcher
-	cache      chan *udp.Packet
-	done       *done.Instance
-}
-
-func DialDispatcher(ctx context.Context, dispatcher routing.Dispatcher) (net.PacketConn, error) {
-	c := &dispatcherConn{
-		cache: make(chan *udp.Packet, 16),
-		done:  done.New(),
-	}
-
-	d := NewDispatcher(dispatcher, c.callback)
-	c.dispatcher = d
-	return c, nil
-}
-
-func (c *dispatcherConn) callback(ctx context.Context, packet *udp.Packet) {
-	select {
-	case <-c.done.Wait():
-		packet.Payload.Release()
-		return
-	case c.cache <- packet:
-	default:
-		packet.Payload.Release()
-		return
-	}
-}
-
-func (c *dispatcherConn) ReadFrom(p []byte) (int, net.Addr, error) {
-	select {
-	case <-c.done.Wait():
-		return 0, nil, io.EOF
-	case packet := <-c.cache:
-		n := copy(p, packet.Payload.Bytes())
-		return n, &net.UDPAddr{
-			IP:   packet.Source.Address.IP(),
-			Port: int(packet.Source.Port),
-		}, nil
-	}
-}
-
-func (c *dispatcherConn) WriteTo(p []byte, addr net.Addr) (int, error) {
-	buffer := buf.New()
-	raw := buffer.Extend(buf.Size)
-	n := copy(raw, p)
-	buffer.Resize(0, int32(n))
-
-	ctx := context.Background()
-	c.dispatcher.Dispatch(ctx, net.DestinationFromAddr(addr), buffer)
-	return n, nil
-}
-
-func (c *dispatcherConn) Close() error {
-	return c.done.Close()
-}
-
-func (c *dispatcherConn) LocalAddr() net.Addr {
-	return &net.UDPAddr{
-		IP:   []byte{0, 0, 0, 0},
-		Port: 0,
-	}
-}
-
-func (c *dispatcherConn) SetDeadline(t time.Time) error {
-	return nil
-}
-
-func (c *dispatcherConn) SetReadDeadline(t time.Time) error {
-	return nil
-}
-
-func (c *dispatcherConn) SetWriteDeadline(t time.Time) error {
-	return nil
+type DispatcherI interface {
+	Dispatch(ctx context.Context, destination net.Destination, payload *buf.Buffer)
 }

+ 197 - 0
transport/internet/udp/dispatcher_split.go

@@ -0,0 +1,197 @@
+package udp
+
+import (
+	"context"
+	"io"
+	"sync"
+	"time"
+
+	"github.com/v2fly/v2ray-core/v5/common"
+	"github.com/v2fly/v2ray-core/v5/common/buf"
+	"github.com/v2fly/v2ray-core/v5/common/net"
+	"github.com/v2fly/v2ray-core/v5/common/protocol/udp"
+	"github.com/v2fly/v2ray-core/v5/common/session"
+	"github.com/v2fly/v2ray-core/v5/common/signal"
+	"github.com/v2fly/v2ray-core/v5/common/signal/done"
+	"github.com/v2fly/v2ray-core/v5/features/routing"
+	"github.com/v2fly/v2ray-core/v5/transport"
+)
+
+type ResponseCallback func(ctx context.Context, packet *udp.Packet)
+
+type connEntry struct {
+	link   *transport.Link
+	timer  signal.ActivityUpdater
+	cancel context.CancelFunc
+}
+
+type Dispatcher struct {
+	sync.RWMutex
+	conns      map[net.Destination]*connEntry
+	dispatcher routing.Dispatcher
+	callback   ResponseCallback
+}
+
+func NewSplitDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Dispatcher {
+	return &Dispatcher{
+		conns:      make(map[net.Destination]*connEntry),
+		dispatcher: dispatcher,
+		callback:   callback,
+	}
+}
+
+func (v *Dispatcher) RemoveRay(dest net.Destination) {
+	v.Lock()
+	defer v.Unlock()
+	if conn, found := v.conns[dest]; found {
+		common.Close(conn.link.Reader)
+		common.Close(conn.link.Writer)
+		delete(v.conns, dest)
+	}
+}
+
+func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) *connEntry {
+	v.Lock()
+	defer v.Unlock()
+
+	if entry, found := v.conns[dest]; found {
+		return entry
+	}
+
+	newError("establishing new connection for ", dest).WriteToLog()
+
+	ctx, cancel := context.WithCancel(ctx)
+	removeRay := func() {
+		cancel()
+		v.RemoveRay(dest)
+	}
+	timer := signal.CancelAfterInactivity(ctx, removeRay, time.Second*4)
+	link, _ := v.dispatcher.Dispatch(ctx, dest)
+	entry := &connEntry{
+		link:   link,
+		timer:  timer,
+		cancel: removeRay,
+	}
+	v.conns[dest] = entry
+	go handleInput(ctx, entry, dest, v.callback)
+	return entry
+}
+
+func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, payload *buf.Buffer) {
+	// TODO: Add user to destString
+	newError("dispatch request to: ", destination).AtDebug().WriteToLog(session.ExportIDToError(ctx))
+
+	conn := v.getInboundRay(ctx, destination)
+	outputStream := conn.link.Writer
+	if outputStream != nil {
+		if err := outputStream.WriteMultiBuffer(buf.MultiBuffer{payload}); err != nil {
+			newError("failed to write first UDP payload").Base(err).WriteToLog(session.ExportIDToError(ctx))
+			conn.cancel()
+			return
+		}
+	}
+}
+
+func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback) {
+	defer conn.cancel()
+
+	input := conn.link.Reader
+	timer := conn.timer
+
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		default:
+		}
+
+		mb, err := input.ReadMultiBuffer()
+		if err != nil {
+			newError("failed to handle UDP input").Base(err).WriteToLog(session.ExportIDToError(ctx))
+			return
+		}
+		timer.Update()
+		for _, b := range mb {
+			callback(ctx, &udp.Packet{
+				Payload: b,
+				Source:  dest,
+			})
+		}
+	}
+}
+
+type dispatcherConn struct {
+	dispatcher *Dispatcher
+	cache      chan *udp.Packet
+	done       *done.Instance
+}
+
+func DialDispatcher(ctx context.Context, dispatcher routing.Dispatcher) (net.PacketConn, error) {
+	c := &dispatcherConn{
+		cache: make(chan *udp.Packet, 16),
+		done:  done.New(),
+	}
+
+	d := NewSplitDispatcher(dispatcher, c.callback)
+	c.dispatcher = d
+	return c, nil
+}
+
+func (c *dispatcherConn) callback(ctx context.Context, packet *udp.Packet) {
+	select {
+	case <-c.done.Wait():
+		packet.Payload.Release()
+		return
+	case c.cache <- packet:
+	default:
+		packet.Payload.Release()
+		return
+	}
+}
+
+func (c *dispatcherConn) ReadFrom(p []byte) (int, net.Addr, error) {
+	select {
+	case <-c.done.Wait():
+		return 0, nil, io.EOF
+	case packet := <-c.cache:
+		n := copy(p, packet.Payload.Bytes())
+		return n, &net.UDPAddr{
+			IP:   packet.Source.Address.IP(),
+			Port: int(packet.Source.Port),
+		}, nil
+	}
+}
+
+func (c *dispatcherConn) WriteTo(p []byte, addr net.Addr) (int, error) {
+	buffer := buf.New()
+	raw := buffer.Extend(buf.Size)
+	n := copy(raw, p)
+	buffer.Resize(0, int32(n))
+
+	ctx := context.Background()
+	c.dispatcher.Dispatch(ctx, net.DestinationFromAddr(addr), buffer)
+	return n, nil
+}
+
+func (c *dispatcherConn) Close() error {
+	return c.done.Close()
+}
+
+func (c *dispatcherConn) LocalAddr() net.Addr {
+	return &net.UDPAddr{
+		IP:   []byte{0, 0, 0, 0},
+		Port: 0,
+	}
+}
+
+func (c *dispatcherConn) SetDeadline(t time.Time) error {
+	return nil
+}
+
+func (c *dispatcherConn) SetReadDeadline(t time.Time) error {
+	return nil
+}
+
+func (c *dispatcherConn) SetWriteDeadline(t time.Time) error {
+	return nil
+}

+ 1 - 1
transport/internet/udp/dispatcher_test.go → transport/internet/udp/dispatcher_split_test.go

@@ -65,7 +65,7 @@ func TestSameDestinationDispatching(t *testing.T) {
 	b.WriteString("abcd")
 
 	var msgCount uint32
-	dispatcher := NewDispatcher(td, func(ctx context.Context, packet *udp.Packet) {
+	dispatcher := NewSplitDispatcher(td, func(ctx context.Context, packet *udp.Packet) {
 		atomic.AddUint32(&msgCount, 1)
 	})