Selaa lähdekoodia

fix getting sys fd

Darien Raymond 8 vuotta sitten
vanhempi
commit
40f890e638

+ 1 - 1
transport/internet/internal/sysfd.go

@@ -6,7 +6,7 @@ import (
 )
 
 var (
-	errInvalidConn = newError("Invalid Connection.")
+	errInvalidConn = newError("not a net.Conn")
 )
 
 // GetSysFd returns the underlying fd of a connection.

+ 9 - 12
transport/internet/tcp/sockopt_linux.go

@@ -3,33 +3,30 @@
 package tcp
 
 import (
+	"net"
 	"syscall"
 
 	"v2ray.com/core/app/log"
-	"v2ray.com/core/common/net"
+	v2net "v2ray.com/core/common/net"
 	"v2ray.com/core/transport/internet"
+	"v2ray.com/core/transport/internet/internal"
 )
 
 const SO_ORIGINAL_DST = 80
 
-func GetOriginalDestination(conn internet.Connection) net.Destination {
-	tcpConn, ok := conn.(internet.SysFd)
-	if !ok {
-		log.Trace(newError("failed to get sys fd"))
-		return net.Destination{}
-	}
-	fd, err := tcpConn.SysFd()
+func GetOriginalDestination(conn internet.Connection) v2net.Destination {
+	fd, err := internal.GetSysFd(conn.(net.Conn))
 	if err != nil {
 		log.Trace(newError("failed to get original destination").Base(err))
-		return net.Destination{}
+		return v2net.Destination{}
 	}
 
 	addr, err := syscall.GetsockoptIPv6Mreq(fd, syscall.IPPROTO_IP, SO_ORIGINAL_DST)
 	if err != nil {
 		log.Trace(newError("failed to call getsockopt").Base(err))
-		return net.Destination{}
+		return v2net.Destination{}
 	}
-	ip := net.IPAddress(addr.Multiaddr[4:8])
+	ip := v2net.IPAddress(addr.Multiaddr[4:8])
 	port := uint16(addr.Multiaddr[2])<<8 + uint16(addr.Multiaddr[3])
-	return net.TCPDestination(ip, net.Port(port))
+	return v2net.TCPDestination(ip, v2net.Port(port))
 }

+ 28 - 0
transport/internet/tcp/sockopt_linux_test.go

@@ -0,0 +1,28 @@
+// +build linux
+
+package tcp_test
+
+import (
+	"context"
+	"testing"
+
+	"v2ray.com/core/testing/assert"
+	"v2ray.com/core/testing/servers/tcp"
+)
+
+func TestGetOriginalDestination(t *testing.T) {
+	assert := assert.On(t)
+
+	tcpServer := tcp.Server{
+		MsgProcessor: xor,
+	}
+	dest, err := tcpServer.Start()
+	assert.Error(err).IsNil()
+	defer tcpServer.Close()
+
+	conn, err := Dial(context.Background(), dest)
+	assert.Error(err).IsNil()
+
+	_, err := GetOriginalDestination(conn)
+	assert.String(err.Error()).Contains("failed to call getsockopt")
+}