Browse Source

Set correct remote addr for h2 request. Fixes #1068.

Darien Raymond 7 years ago
parent
commit
f1f4a796cf
3 changed files with 96 additions and 1 deletions
  1. 32 0
      common/net/destination.go
  2. 51 0
      common/net/destination_test.go
  3. 13 1
      transport/internet/http/hub.go

+ 32 - 0
common/net/destination.go

@@ -2,6 +2,7 @@ package net
 
 import (
 	"net"
+	"strings"
 )
 
 // Destination represents a network destination including address and protocol (tcp / udp).
@@ -26,6 +27,37 @@ func DestinationFromAddr(addr net.Addr) Destination {
 	}
 }
 
+// ParseDestination converts a destination from its string presentation.
+func ParseDestination(dest string) (Destination, error) {
+	d := Destination{
+		Address: AnyIP,
+		Port:    Port(0),
+	}
+	if strings.HasPrefix(dest, "tcp:") {
+		d.Network = Network_TCP
+		dest = dest[4:]
+	} else if strings.HasPrefix(dest, "udp:") {
+		d.Network = Network_UDP
+		dest = dest[4:]
+	}
+
+	hstr, pstr, err := SplitHostPort(dest)
+	if err != nil {
+		return d, err
+	}
+	if len(hstr) > 0 {
+		d.Address = ParseAddress(hstr)
+	}
+	if len(pstr) > 0 {
+		port, err := PortFromString(pstr)
+		if err != nil {
+			return d, err
+		}
+		d.Port = port
+	}
+	return d, nil
+}
+
 // TCPDestination creates a TCP destination with given address
 func TCPDestination(address Address, port Port) Destination {
 	return Destination{

+ 51 - 0
common/net/destination_test.go

@@ -25,3 +25,54 @@ func TestUDPDestination(t *testing.T) {
 	assert(dest, IsUDP)
 	assert(dest.String(), Equals, "udp:[2001:4860:4860::8888]:53")
 }
+
+func TestDestinationParse(t *testing.T) {
+	assert := With(t)
+
+	cases := []struct {
+		Input  string
+		Output Destination
+		Error  bool
+	}{
+		{
+			Input:  "tcp:127.0.0.1:80",
+			Output: TCPDestination(LocalHostIP, Port(80)),
+		},
+		{
+			Input:  "udp:8.8.8.8:53",
+			Output: UDPDestination(IPAddress([]byte{8, 8, 8, 8}), Port(53)),
+		},
+		{
+			Input: "8.8.8.8:53",
+			Output: Destination{
+				Address: IPAddress([]byte{8, 8, 8, 8}),
+				Port:    Port(53),
+			},
+		},
+		{
+			Input: ":53",
+			Output: Destination{
+				Address: AnyIP,
+				Port:    Port(53),
+			},
+		},
+		{
+			Input: "8.8.8.8",
+			Error: true,
+		},
+		{
+			Input: "8.8.8.8:http",
+			Error: true,
+		},
+	}
+
+	for _, testcase := range cases {
+		d, err := ParseDestination(testcase.Input)
+		if !testcase.Error {
+			assert(err, IsNil)
+			assert(d, Equals, testcase.Output)
+		} else {
+			assert(err, IsNotNil)
+		}
+	}
+}

+ 13 - 1
transport/internet/http/hub.go

@@ -63,13 +63,25 @@ func (l *Listener) ServeHTTP(writer http.ResponseWriter, request *http.Request)
 	if f, ok := writer.(http.Flusher); ok {
 		f.Flush()
 	}
+
+	remoteAddr := l.Addr()
+	dest, err := net.ParseDestination(request.RemoteAddr)
+	if err != nil {
+		newError("failed to parse request remote addr: ", request.RemoteAddr).Base(err).WriteToLog()
+	} else {
+		remoteAddr = &net.TCPAddr{
+			IP:   dest.Address.IP(),
+			Port: int(dest.Port),
+		}
+	}
+
 	done := signal.NewDone()
 	conn := net.NewConnection(
 		net.ConnectionOutput(request.Body),
 		net.ConnectionInput(flushWriter{w: writer, d: done}),
 		net.ConnectionOnClose(common.NewChainedClosable(done, request.Body)),
 		net.ConnectionLocalAddr(l.Addr()),
-		net.ConnectionRemoteAddr(l.Addr()),
+		net.ConnectionRemoteAddr(remoteAddr),
 	)
 	l.handler(conn)
 	<-done.Wait()