Переглянути джерело

udp dispatcher takes context with dispatching requests. fixes #1182.

Darien Raymond 7 роки тому
батько
коміт
e3cc852c57

+ 8 - 8
app/dns/udpns.go

@@ -48,17 +48,17 @@ type ClassicNameServer struct {
 
 func NewClassicNameServer(address net.Destination, dispatcher core.Dispatcher, clientIP net.IP) *ClassicNameServer {
 	s := &ClassicNameServer{
-		address:   address,
-		ips:       make(map[string][]IPRecord),
-		requests:  make(map[uint16]pendingRequest),
-		udpServer: udp.NewDispatcher(dispatcher),
-		clientIP:  clientIP,
-		pub:       pubsub.NewService(),
+		address:  address,
+		ips:      make(map[string][]IPRecord),
+		requests: make(map[uint16]pendingRequest),
+		clientIP: clientIP,
+		pub:      pubsub.NewService(),
 	}
 	s.cleanup = &task.Periodic{
 		Interval: time.Minute,
 		Execute:  s.Cleanup,
 	}
+	s.udpServer = udp.NewDispatcher(dispatcher, s.HandleResponse)
 	common.Must(s.cleanup.Start())
 	return s
 }
@@ -98,7 +98,7 @@ func (s *ClassicNameServer) Cleanup() error {
 	return nil
 }
 
-func (s *ClassicNameServer) HandleResponse(payload *buf.Buffer) {
+func (s *ClassicNameServer) HandleResponse(ctx context.Context, payload *buf.Buffer) {
 	msg := new(dns.Msg)
 	err := msg.Unpack(payload.Bytes())
 	if err == dns.ErrTruncated {
@@ -267,7 +267,7 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string) {
 	for _, msg := range msgs {
 		b, err := msgToBuffer(msg)
 		common.Must(err)
-		s.udpServer.Dispatch(ctx, s.address, b, s.HandleResponse)
+		s.udpServer.Dispatch(context.Background(), s.address, b)
 	}
 }
 

+ 13 - 0
common/protocol/context.go

@@ -8,6 +8,7 @@ type key int
 
 const (
 	userKey key = iota
+	requestKey
 )
 
 // ContextWithUser returns a context combined with a User.
@@ -23,3 +24,15 @@ func UserFromContext(ctx context.Context) *User {
 	}
 	return v.(*User)
 }
+
+func ContextWithRequestHeader(ctx context.Context, request *RequestHeader) context.Context {
+	return context.WithValue(ctx, requestKey, request)
+}
+
+func RequestHeaderFromContext(ctx context.Context) *RequestHeader {
+	request := ctx.Value(requestKey)
+	if request == nil {
+		return nil
+	}
+	return request.(*RequestHeader)
+}

+ 18 - 12
proxy/shadowsocks/server.go

@@ -74,7 +74,22 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
 }
 
 func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection, dispatcher core.Dispatcher) error {
-	udpServer := udp.NewDispatcher(dispatcher)
+	udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, payload *buf.Buffer) {
+		request := protocol.RequestHeaderFromContext(ctx)
+		if request == nil {
+			return
+		}
+
+		data, err := EncodeUDPPacket(request, payload.Bytes())
+		payload.Release()
+		if err != nil {
+			newError("failed to encode UDP packet").Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx))
+			return
+		}
+		defer data.Release()
+
+		conn.Write(data.Bytes())
+	})
 
 	reader := buf.NewReader(conn)
 	for {
@@ -123,17 +138,8 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
 			newError("tunnelling request to ", dest).WriteToLog(session.ExportIDToError(ctx))
 
 			ctx = protocol.ContextWithUser(ctx, request.User)
-			udpServer.Dispatch(ctx, dest, data, func(payload *buf.Buffer) {
-				data, err := EncodeUDPPacket(request, payload.Bytes())
-				payload.Release()
-				if err != nil {
-					newError("failed to encode UDP packet").Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx))
-					return
-				}
-				defer data.Release()
-
-				conn.Write(data.Bytes())
-			})
+			ctx = protocol.ContextWithRequestHeader(ctx, request)
+			udpServer.Dispatch(ctx, dest, data)
 		}
 	}
 

+ 19 - 14
proxy/socks/server.go

@@ -172,7 +172,23 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
 }
 
 func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection, dispatcher core.Dispatcher) error {
-	udpServer := udp.NewDispatcher(dispatcher)
+	udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, payload *buf.Buffer) {
+		newError("writing back UDP response with ", payload.Len(), " bytes").AtDebug().WriteToLog(session.ExportIDToError(ctx))
+
+		request := protocol.RequestHeaderFromContext(ctx)
+		if request == nil {
+			return
+		}
+		udpMessage, err := EncodeUDPPacket(request, payload.Bytes())
+		payload.Release()
+
+		defer udpMessage.Release()
+		if err != nil {
+			newError("failed to write UDP response").AtWarning().Base(err).WriteToLog(session.ExportIDToError(ctx))
+		}
+
+		conn.Write(udpMessage.Bytes()) // nolint: errcheck
+	})
 
 	if source, ok := proxy.SourceFromContext(ctx); ok {
 		newError("client UDP connection from ", source).WriteToLog(session.ExportIDToError(ctx))
@@ -209,19 +225,8 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection,
 				})
 			}
 
-			udpServer.Dispatch(ctx, request.Destination(), payload, func(payload *buf.Buffer) {
-				newError("writing back UDP response with ", payload.Len(), " bytes").AtDebug().WriteToLog(session.ExportIDToError(ctx))
-
-				udpMessage, err := EncodeUDPPacket(request, payload.Bytes())
-				payload.Release()
-
-				defer udpMessage.Release()
-				if err != nil {
-					newError("failed to write UDP response").AtWarning().Base(err).WriteToLog(session.ExportIDToError(ctx))
-				}
-
-				conn.Write(udpMessage.Bytes()) // nolint: errcheck
-			})
+			ctx = protocol.ContextWithRequestHeader(ctx, request)
+			udpServer.Dispatch(ctx, request.Destination(), payload)
 		}
 	}
 }

+ 104 - 0
testing/scenarios/socks_test.go

@@ -3,6 +3,8 @@ package scenarios
 import (
 	"testing"
 
+	"v2ray.com/core/app/router"
+
 	xproxy "golang.org/x/net/proxy"
 	socks4 "h12.me/socks"
 	"v2ray.com/core"
@@ -10,6 +12,7 @@ import (
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/common/serial"
+	"v2ray.com/core/proxy/blackhole"
 	"v2ray.com/core/proxy/dokodemo"
 	"v2ray.com/core/proxy/freedom"
 	"v2ray.com/core/proxy/socks"
@@ -212,6 +215,107 @@ func TestSocksBridageUDP(t *testing.T) {
 	CloseAllServers(servers)
 }
 
+func TestSocksBridageUDPWithRouting(t *testing.T) {
+	assert := With(t)
+
+	udpServer := udp.Server{
+		MsgProcessor: xor,
+	}
+	dest, err := udpServer.Start()
+	assert(err, IsNil)
+	defer udpServer.Close()
+
+	serverPort := tcp.PickPort()
+	serverConfig := &core.Config{
+		App: []*serial.TypedMessage{
+			serial.ToTypedMessage(&router.Config{
+				Rule: []*router.RoutingRule{
+					{
+						Tag:        "out",
+						InboundTag: []string{"socks"},
+					},
+				},
+			}),
+		},
+		Inbound: []*core.InboundHandlerConfig{
+			{
+				Tag: "socks",
+				ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{
+					PortRange: net.SinglePortRange(serverPort),
+					Listen:    net.NewIPOrDomain(net.LocalHostIP),
+				}),
+				ProxySettings: serial.ToTypedMessage(&socks.ServerConfig{
+					AuthType:   socks.AuthType_NO_AUTH,
+					Address:    net.NewIPOrDomain(net.LocalHostIP),
+					UdpEnabled: true,
+				}),
+			},
+		},
+		Outbound: []*core.OutboundHandlerConfig{
+			{
+				ProxySettings: serial.ToTypedMessage(&blackhole.Config{}),
+			},
+			{
+				Tag:           "out",
+				ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
+			},
+		},
+	}
+
+	clientPort := tcp.PickPort()
+	clientConfig := &core.Config{
+		Inbound: []*core.InboundHandlerConfig{
+			{
+				ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{
+					PortRange: net.SinglePortRange(clientPort),
+					Listen:    net.NewIPOrDomain(net.LocalHostIP),
+				}),
+				ProxySettings: serial.ToTypedMessage(&dokodemo.Config{
+					Address: net.NewIPOrDomain(dest.Address),
+					Port:    uint32(dest.Port),
+					NetworkList: &net.NetworkList{
+						Network: []net.Network{net.Network_TCP, net.Network_UDP},
+					},
+				}),
+			},
+		},
+		Outbound: []*core.OutboundHandlerConfig{
+			{
+				ProxySettings: serial.ToTypedMessage(&socks.ClientConfig{
+					Server: []*protocol.ServerEndpoint{
+						{
+							Address: net.NewIPOrDomain(net.LocalHostIP),
+							Port:    uint32(serverPort),
+						},
+					},
+				}),
+			},
+		},
+	}
+
+	servers, err := InitializeServerConfigs(serverConfig, clientConfig)
+	assert(err, IsNil)
+
+	conn, err := net.DialUDP("udp", nil, &net.UDPAddr{
+		IP:   []byte{127, 0, 0, 1},
+		Port: int(clientPort),
+	})
+	assert(err, IsNil)
+
+	payload := "dokodemo request."
+	nBytes, err := conn.Write([]byte(payload))
+	assert(err, IsNil)
+	assert(nBytes, Equals, len(payload))
+
+	response := make([]byte, 1024)
+	nBytes, err = conn.Read(response)
+	assert(err, IsNil)
+	assert(response[:nBytes], Equals, xor([]byte(payload)))
+	assert(conn.Close(), IsNil)
+
+	CloseAllServers(servers)
+}
+
 func TestSocksConformance(t *testing.T) {
 	assert := With(t)
 

+ 12 - 9
transport/internet/udp/dispatcher.go

@@ -13,7 +13,7 @@ import (
 	"v2ray.com/core/common/signal"
 )
 
-type ResponseCallback func(payload *buf.Buffer)
+type ResponseCallback func(ctx context.Context, payload *buf.Buffer)
 
 type connEntry struct {
 	link   *core.Link
@@ -25,12 +25,14 @@ type Dispatcher struct {
 	sync.RWMutex
 	conns      map[net.Destination]*connEntry
 	dispatcher core.Dispatcher
+	callback   ResponseCallback
 }
 
-func NewDispatcher(dispatcher core.Dispatcher) *Dispatcher {
+func NewDispatcher(dispatcher core.Dispatcher, callback ResponseCallback) *Dispatcher {
 	return &Dispatcher{
 		conns:      make(map[net.Destination]*connEntry),
 		dispatcher: dispatcher,
+		callback:   callback,
 	}
 }
 
@@ -44,7 +46,7 @@ func (v *Dispatcher) RemoveRay(dest net.Destination) {
 	}
 }
 
-func (v *Dispatcher) getInboundRay(dest net.Destination, callback ResponseCallback) *connEntry {
+func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) *connEntry {
 	v.Lock()
 	defer v.Unlock()
 
@@ -54,7 +56,7 @@ func (v *Dispatcher) getInboundRay(dest net.Destination, callback ResponseCallba
 
 	newError("establishing new connection for ", dest).WriteToLog()
 
-	ctx, cancel := context.WithCancel(context.Background())
+	ctx, cancel := context.WithCancel(ctx)
 	removeRay := func() {
 		cancel()
 		v.RemoveRay(dest)
@@ -67,15 +69,15 @@ func (v *Dispatcher) getInboundRay(dest net.Destination, callback ResponseCallba
 		cancel: removeRay,
 	}
 	v.conns[dest] = entry
-	go handleInput(ctx, entry, callback)
+	go handleInput(ctx, entry, v.callback)
 	return entry
 }
 
-func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, payload *buf.Buffer, callback ResponseCallback) {
+func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, payload *buf.Buffer) {
 	// TODO: Add user to destString
 	newError("dispatch request to: ", destination).AtDebug().WriteToLog(session.ExportIDToError(ctx))
 
-	conn := v.getInboundRay(destination, callback)
+	conn := v.getInboundRay(ctx, destination)
 	outputStream := conn.link.Writer
 	if outputStream != nil {
 		if err := outputStream.WriteMultiBuffer(buf.NewMultiBufferValue(payload)); err != nil {
@@ -87,6 +89,8 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination,
 }
 
 func handleInput(ctx context.Context, conn *connEntry, callback ResponseCallback) {
+	defer conn.cancel()
+
 	input := conn.link.Reader
 	timer := conn.timer
 
@@ -100,12 +104,11 @@ func handleInput(ctx context.Context, conn *connEntry, callback ResponseCallback
 		mb, err := input.ReadMultiBuffer()
 		if err != nil {
 			newError("failed to handle UDP input").Base(err).WriteToLog(session.ExportIDToError(ctx))
-			conn.cancel()
 			return
 		}
 		timer.Update()
 		for _, b := range mb {
-			callback(b)
+			callback(ctx, b)
 		}
 	}
 }