Procházet zdrojové kódy

fix udp handling in dokodemo and mux

Darien Raymond před 8 roky
rodič
revize
d5f931ae8b

+ 5 - 4
app/proxyman/mux/mux.go

@@ -133,11 +133,12 @@ func (m *Client) monitor() {
 
 func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
 	dest, _ := proxy.TargetFromContext(ctx)
-	writer := &Writer{
-		dest:   dest,
-		id:     s.ID,
-		writer: output,
+	transferType := protocol.TransferTypeStream
+	if dest.Network == net.Network_UDP {
+		transferType = protocol.TransferTypePacket
 	}
+	s.transferType = transferType
+	writer := NewWriter(s.ID, dest, output, transferType)
 	defer writer.Close()
 	defer s.CloseUplink()
 

+ 7 - 2
proxy/dokodemo/dokodemo.go

@@ -84,9 +84,14 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
 	})
 
 	responseDone := signal.ExecuteAsync(func() error {
-		v2writer := buf.NewWriter(conn)
+		var writer buf.Writer
+		if network == net.Network_TCP {
+			writer = buf.NewWriter(conn)
+		} else {
+			writer = buf.NewSequentialWriter(conn)
+		}
 
-		if err := buf.Copy(inboundRay.InboundOutput(), v2writer, buf.UpdateActivity(timer)); err != nil {
+		if err := buf.Copy(inboundRay.InboundOutput(), writer, buf.UpdateActivity(timer)); err != nil {
 			return newError("failed to transport response").Base(err)
 		}
 		return nil

+ 16 - 5
testing/scenarios/vmess_test.go

@@ -1198,15 +1198,26 @@ func TestVMessGCMMuxUDP(t *testing.T) {
 				})
 				assert.Error(err).IsNil()
 
+				conn.SetDeadline(time.Now().Add(time.Second * 10))
+
 				payload := make([]byte, 1024)
 				rand.Read(payload)
 
-				nBytes, err := conn.Write(payload)
-				assert.Error(err).IsNil()
-				assert.Int(nBytes).Equals(len(payload))
+				for j := 0; j < 10; j++ {
+					nBytes, _, err := conn.WriteMsgUDP(payload, nil, nil)
+					assert.Error(err).IsNil()
+					assert.Int(nBytes).Equals(len(payload))
+				}
+
+				response := make([]byte, 1024)
+				oob := make([]byte, 16)
+				for j := 0; j < 10; j++ {
+					nBytes, _, _, _, err := conn.ReadMsgUDP(response, oob)
+					assert.Error(err).IsNil()
+					assert.Int(nBytes).Equals(1024)
+					assert.Bytes(response).Equals(xor(payload))
+				}
 
-				response := readFrom(conn, time.Second*5, 1024)
-				assert.Bytes(response).Equals(xor(payload))
 				assert.Error(conn.Close()).IsNil()
 				wg.Done()
 			}()