Browse Source

use ListenPacket in Dial UDP connection

Darien Raymond 6 years ago
parent
commit
a1b552f948
1 changed files with 79 additions and 15 deletions
  1. 79 15
      transport/internet/system_dialer.go

+ 79 - 15
transport/internet/system_dialer.go

@@ -21,10 +21,51 @@ type DefaultSystemDialer struct {
 	controllers []controller
 }
 
+func resolveSrcAddr(network net.Network, src net.Address) net.Addr {
+	if src == nil || src == net.AnyIP {
+		return nil
+	}
+
+	if network == net.Network_TCP {
+		return &net.TCPAddr{
+			IP:   src.IP(),
+			Port: 0,
+		}
+	}
+
+	return &net.UDPAddr{
+		IP:   src.IP(),
+		Port: 0,
+	}
+}
+
 func (d *DefaultSystemDialer) Dial(ctx context.Context, src net.Address, dest net.Destination, sockopt *SocketConfig) (net.Conn, error) {
+	if dest.Network == net.Network_UDP {
+		srcAddr := resolveSrcAddr(net.Network_UDP, src)
+		if srcAddr == nil {
+			srcAddr = &net.UDPAddr{
+				IP:   []byte{0, 0, 0, 0},
+				Port: 0,
+			}
+		}
+		packetConn, err := ListenSystemPacket(ctx, srcAddr, sockopt)
+		if err != nil {
+			return nil, err
+		}
+		destAddr, err := net.ResolveUDPAddr("udp", dest.NetAddr())
+		if err != nil {
+			return nil, err
+		}
+		return &packetConnWrapper{
+			conn: packetConn,
+			dest: destAddr,
+		}, nil
+	}
+
 	dialer := &net.Dialer{
 		Timeout:   time.Second * 60,
 		DualStack: true,
+		LocalAddr: resolveSrcAddr(dest.Network, src),
 	}
 
 	if sockopt != nil || len(d.controllers) > 0 {
@@ -50,24 +91,47 @@ func (d *DefaultSystemDialer) Dial(ctx context.Context, src net.Address, dest ne
 		}
 	}
 
-	if src != nil && src != net.AnyIP {
-		var addr net.Addr
-		if dest.Network == net.Network_TCP {
-			addr = &net.TCPAddr{
-				IP:   src.IP(),
-				Port: 0,
-			}
-		} else {
-			addr = &net.UDPAddr{
-				IP:   src.IP(),
-				Port: 0,
-			}
-		}
-		dialer.LocalAddr = addr
-	}
 	return dialer.DialContext(ctx, dest.Network.SystemString(), dest.NetAddr())
 }
 
+type packetConnWrapper struct {
+	conn net.PacketConn
+	dest net.Addr
+}
+
+func (c *packetConnWrapper) Close() error {
+	return c.conn.Close()
+}
+
+func (c *packetConnWrapper) LocalAddr() net.Addr {
+	return c.conn.LocalAddr()
+}
+
+func (c *packetConnWrapper) RemoteAddr() net.Addr {
+	return c.dest
+}
+
+func (c *packetConnWrapper) Write(p []byte) (int, error) {
+	return c.conn.WriteTo(p, c.dest)
+}
+
+func (c *packetConnWrapper) Read(p []byte) (int, error) {
+	n, _, err := c.conn.ReadFrom(p)
+	return n, err
+}
+
+func (c *packetConnWrapper) SetDeadline(t time.Time) error {
+	return c.conn.SetDeadline(t)
+}
+
+func (c *packetConnWrapper) SetReadDeadline(t time.Time) error {
+	return c.conn.SetReadDeadline(t)
+}
+
+func (c *packetConnWrapper) SetWriteDeadline(t time.Time) error {
+	return c.conn.SetWriteDeadline(t)
+}
+
 type SystemDialerAdapter interface {
 	Dial(network string, address string) (net.Conn, error)
 }