Browse Source

Feat: add PacketAddr support to Trojan client

秋のかえで 3 years ago
parent
commit
01a2686568
2 changed files with 81 additions and 1 deletions
  1. 48 1
      proxy/trojan/client.go
  2. 33 0
      proxy/trojan/protocol.go

+ 48 - 1
proxy/trojan/client.go

@@ -7,6 +7,7 @@ import (
 	"github.com/v2fly/v2ray-core/v5/common"
 	"github.com/v2fly/v2ray-core/v5/common"
 	"github.com/v2fly/v2ray-core/v5/common/buf"
 	"github.com/v2fly/v2ray-core/v5/common/buf"
 	"github.com/v2fly/v2ray-core/v5/common/net"
 	"github.com/v2fly/v2ray-core/v5/common/net"
+	"github.com/v2fly/v2ray-core/v5/common/net/packetaddr"
 	"github.com/v2fly/v2ray-core/v5/common/protocol"
 	"github.com/v2fly/v2ray-core/v5/common/protocol"
 	"github.com/v2fly/v2ray-core/v5/common/retry"
 	"github.com/v2fly/v2ray-core/v5/common/retry"
 	"github.com/v2fly/v2ray-core/v5/common/session"
 	"github.com/v2fly/v2ray-core/v5/common/session"
@@ -16,6 +17,7 @@ import (
 	"github.com/v2fly/v2ray-core/v5/proxy"
 	"github.com/v2fly/v2ray-core/v5/proxy"
 	"github.com/v2fly/v2ray-core/v5/transport"
 	"github.com/v2fly/v2ray-core/v5/transport"
 	"github.com/v2fly/v2ray-core/v5/transport/internet"
 	"github.com/v2fly/v2ray-core/v5/transport/internet"
+	"github.com/v2fly/v2ray-core/v5/transport/internet/udp"
 )
 )
 
 
 // Client is an inbound handler for trojan protocol
 // Client is an inbound handler for trojan protocol
@@ -85,6 +87,51 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
 	ctx, cancel := context.WithCancel(ctx)
 	ctx, cancel := context.WithCancel(ctx)
 	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
 	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
 
 
+	if packetConn, err := packetaddr.ToPacketAddrConn(link, destination); err == nil {
+		postRequest := func() error {
+			defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
+
+			var buffer [2048]byte
+			_, addr, err := packetConn.ReadFrom(buffer[:])
+			if err != nil {
+				return newError("failed to read a packet").Base(err)
+			}
+			dest := net.DestinationFromAddr(addr)
+
+			bufferWriter := buf.NewBufferedWriter(buf.NewWriter(conn))
+			connWriter := &ConnWriter{Writer: bufferWriter, Target: dest, Account: account}
+			packetWriter := &PacketWriter{Writer: connWriter, Target: dest}
+
+			// write some request payload to buffer
+			if _, err := packetWriter.WriteTo(buffer[:], addr); err != nil {
+				return newError("failed to write a request payload").Base(err)
+			}
+
+			// Flush; bufferWriter.WriteMultiBuffer now is bufferWriter.writer.WriteMultiBuffer
+			if err = bufferWriter.SetBuffered(false); err != nil {
+				return newError("failed to flush payload").Base(err).AtWarning()
+			}
+
+			return udp.CopyPacketConn(packetWriter, packetConn, udp.UpdateActivity(timer))
+		}
+
+		getResponse := func() error {
+			defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
+
+			packetReader := &PacketReader{Reader: conn}
+			splitReader := &PacketSplitReader{Reader: packetReader}
+
+			return udp.CopyPacketConn(packetConn, splitReader, udp.UpdateActivity(timer))
+		}
+
+		responseDoneAndCloseWriter := task.OnSuccess(getResponse, task.Close(link.Writer))
+		if err := task.Run(ctx, postRequest, responseDoneAndCloseWriter); err != nil {
+			return newError("connection ends").Base(err)
+		}
+
+		return nil
+	}
+
 	postRequest := func() error {
 	postRequest := func() error {
 		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
 		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
 
 
@@ -100,7 +147,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
 
 
 		// write some request payload to buffer
 		// write some request payload to buffer
 		if err = buf.CopyOnceTimeout(link.Reader, bodyWriter, proxy.FirstPayloadTimeout); err != nil && err != buf.ErrNotTimeoutReader && err != buf.ErrReadTimeout {
 		if err = buf.CopyOnceTimeout(link.Reader, bodyWriter, proxy.FirstPayloadTimeout); err != nil && err != buf.ErrNotTimeoutReader && err != buf.ErrReadTimeout {
-			return newError("failed to write A request payload").Base(err).AtWarning()
+			return newError("failed to write a request payload").Base(err).AtWarning()
 		}
 		}
 
 
 		// Flush; bufferWriter.WriteMultiBuffer now is bufferWriter.writer.WriteMultiBuffer
 		// Flush; bufferWriter.WriteMultiBuffer now is bufferWriter.writer.WriteMultiBuffer

+ 33 - 0
proxy/trojan/protocol.go

@@ -3,6 +3,7 @@ package trojan
 import (
 import (
 	"encoding/binary"
 	"encoding/binary"
 	"io"
 	"io"
+	gonet "net"
 
 
 	"github.com/v2fly/v2ray-core/v5/common/buf"
 	"github.com/v2fly/v2ray-core/v5/common/buf"
 	"github.com/v2fly/v2ray-core/v5/common/net"
 	"github.com/v2fly/v2ray-core/v5/common/net"
@@ -128,6 +129,12 @@ func (w *PacketWriter) WriteMultiBufferWithMetadata(mb buf.MultiBuffer, dest net
 	return nil
 	return nil
 }
 }
 
 
+func (w *PacketWriter) WriteTo(payload []byte, addr gonet.Addr) (int, error) {
+	dest := net.DestinationFromAddr(addr)
+
+	return w.writePacket(payload, dest)
+}
+
 func (w *PacketWriter) writePacket(payload []byte, dest net.Destination) (int, error) { // nolint: unparam
 func (w *PacketWriter) writePacket(payload []byte, dest net.Destination) (int, error) { // nolint: unparam
 	buffer := buf.StackNew()
 	buffer := buf.StackNew()
 	defer buffer.Release()
 	defer buffer.Release()
@@ -279,3 +286,29 @@ func (r *PacketReader) ReadMultiBufferWithMetadata() (*PacketPayload, error) {
 
 
 	return &PacketPayload{Target: dest, Buffer: mb}, nil
 	return &PacketPayload{Target: dest, Buffer: mb}, nil
 }
 }
+
+type PacketSplitReader struct {
+	Reader  *PacketReader
+	Payload *PacketPayload
+}
+
+func (r *PacketSplitReader) ReadFrom(p []byte) (int, gonet.Addr, error) {
+	var err error
+
+	if r.Payload == nil || r.Payload.Buffer.IsEmpty() {
+		r.Payload, err = r.Reader.ReadMultiBufferWithMetadata()
+		if err != nil {
+			return 0, nil, err
+		}
+	}
+
+	addr := &gonet.UDPAddr{
+		IP:   r.Payload.Target.Address.IP(),
+		Port: int(r.Payload.Target.Port),
+	}
+
+	mb, nBytes := buf.SplitBytes(r.Payload.Buffer, p)
+	r.Payload.Buffer = mb
+
+	return nBytes, addr, nil
+}