Browse Source

fix(tproxy): the problem that cannot find ipv6 destination in redirect mode of tproxy (#815)

* fix(tproxy): the problem that cannot find ipv6 destination in redirect mode of tproxy

* chore(lint): reformat the code

* chore(lint): Codacy Static Code Analysis
mzz 4 years ago
parent
commit
5f851c05b1
1 changed files with 29 additions and 6 deletions
  1. 29 6
      transport/internet/tcp/sockopt_linux.go

+ 29 - 6
transport/internet/tcp/sockopt_linux.go

@@ -23,14 +23,37 @@ func GetOriginalDestination(conn internet.Connection) (net.Destination, error) {
 	}
 	}
 	var dest net.Destination
 	var dest net.Destination
 	err = rawConn.Control(func(fd uintptr) {
 	err = rawConn.Control(func(fd uintptr) {
-		addr, err := syscall.GetsockoptIPv6Mreq(int(fd), syscall.IPPROTO_IP, SO_ORIGINAL_DST)
-		if err != nil {
-			newError("failed to call getsockopt").Base(err).WriteToLog()
+		var remoteIP net.IP
+		switch addr := conn.RemoteAddr().(type) {
+		case *net.TCPAddr:
+			remoteIP = addr.IP
+		case *net.UDPAddr:
+			remoteIP = addr.IP
+		default:
+			newError("failed to call getsockopt").WriteToLog()
 			return
 			return
 		}
 		}
-		ip := net.IPAddress(addr.Multiaddr[4:8])
-		port := uint16(addr.Multiaddr[2])<<8 + uint16(addr.Multiaddr[3])
-		dest = net.TCPDestination(ip, net.Port(port))
+		if remoteIP.To4() != nil {
+			// ipv4
+			addr, err := syscall.GetsockoptIPv6Mreq(int(fd), syscall.IPPROTO_IP, SO_ORIGINAL_DST)
+			if err != nil {
+				newError("failed to call getsockopt").Base(err).WriteToLog()
+				return
+			}
+			ip := net.IPAddress(addr.Multiaddr[4:8])
+			port := uint16(addr.Multiaddr[2])<<8 + uint16(addr.Multiaddr[3])
+			dest = net.TCPDestination(ip, net.Port(port))
+		} else {
+			// ipv6
+			addr, err := syscall.GetsockoptIPv6MTUInfo(int(fd), syscall.IPPROTO_IPV6, SO_ORIGINAL_DST)
+			if err != nil {
+				newError("failed to call getsockopt").Base(err).WriteToLog()
+				return
+			}
+			ip := net.IPAddress(addr.Addr.Addr[:])
+			port := net.PortFromBytes([]byte{byte(addr.Addr.Port), byte(addr.Addr.Port >> 8)})
+			dest = net.TCPDestination(ip, port)
+		}
 	})
 	})
 	if err != nil {
 	if err != nil {
 		return net.Destination{}, newError("failed to control connection").Base(err)
 		return net.Destination{}, newError("failed to control connection").Base(err)