Browse Source

move parseHost to http protocol

Darien Raymond 7 years ago
parent
commit
7e37d141e2

+ 22 - 0
common/protocol/http/headers.go

@@ -2,6 +2,7 @@ package http
 
 import (
 	"net/http"
+	"strconv"
 	"strings"
 
 	"v2ray.com/core/common/net"
@@ -42,3 +43,24 @@ func RemoveHopByHopHeaders(header http.Header) {
 		header.Del(strings.TrimSpace(h))
 	}
 }
+
+// ParseHost splits host and port from a raw string. Default port is used when raw string doesn't contain port.
+func ParseHost(rawHost string, defaultPort net.Port) (net.Destination, error) {
+	port := defaultPort
+	host, rawPort, err := net.SplitHostPort(rawHost)
+	if err != nil {
+		if addrError, ok := err.(*net.AddrError); ok && strings.Contains(addrError.Err, "missing port") {
+			host = rawHost
+		} else {
+			return net.Destination{}, err
+		}
+	} else if len(rawPort) > 0 {
+		intPort, err := strconv.Atoi(rawPort)
+		if err != nil {
+			return net.Destination{}, err
+		}
+		port = net.Port(intPort)
+	}
+
+	return net.TCPDestination(net.ParseAddress(host), port), nil
+}

+ 40 - 0
common/protocol/http/headers_test.go

@@ -6,6 +6,8 @@ import (
 	"strings"
 	"testing"
 
+	"v2ray.com/core/common/net"
+
 	. "v2ray.com/core/common/protocol/http"
 	. "v2ray.com/ext/assert"
 )
@@ -53,3 +55,41 @@ Accept-Language: de,en;q=0.7,en-us;q=0.3
 	assert(req.Header.Get("Proxy-Connection"), IsEmpty)
 	assert(req.Header.Get("Proxy-Authenticate"), IsEmpty)
 }
+
+func TestParseHost(t *testing.T) {
+	testCases := []struct {
+		RawHost     string
+		DefaultPort net.Port
+		Destination net.Destination
+		Error       bool
+	}{
+		{
+			RawHost:     "v2ray.com:80",
+			DefaultPort: 443,
+			Destination: net.TCPDestination(net.DomainAddress("v2ray.com"), 80),
+		},
+		{
+			RawHost:     "tls.v2ray.com",
+			DefaultPort: 443,
+			Destination: net.TCPDestination(net.DomainAddress("tls.v2ray.com"), 443),
+		},
+		{
+			RawHost:     "[2401:1bc0:51f0:ec08::1]:80",
+			DefaultPort: 443,
+			Destination: net.TCPDestination(net.ParseAddress("[2401:1bc0:51f0:ec08::1]"), 80),
+		},
+	}
+
+	for _, testCase := range testCases {
+		dest, err := ParseHost(testCase.RawHost, testCase.DefaultPort)
+		if testCase.Error {
+			if err == nil {
+				t.Error("for test case: ", testCase.RawHost, " expected error, but actually nil")
+			}
+		} else {
+			if dest != testCase.Destination {
+				t.Error("for test case: ", testCase.RawHost, " expected host: ", testCase.Destination.String(), " but got ", dest.String())
+			}
+		}
+	}
+}

+ 7 - 3
common/protocol/http/sniff.go

@@ -6,6 +6,7 @@ import (
 	"strings"
 
 	"v2ray.com/core/common"
+	"v2ray.com/core/common/net"
 )
 
 type version byte
@@ -75,10 +76,13 @@ func SniffHTTP(b []byte) (*SniffHeader, error) {
 			continue
 		}
 		key := strings.ToLower(string(parts[0]))
-		value := strings.ToLower(string(bytes.Trim(parts[1], " ")))
 		if key == "host" {
-			domain := strings.Split(value, ":")
-			sh.host = strings.TrimSpace(domain[0])
+			rawHost := strings.ToLower(string(bytes.TrimSpace(parts[1])))
+			dest, err := ParseHost(rawHost, net.Port(80))
+			if err != nil {
+				return nil, err
+			}
+			sh.host = dest.Address.String()
 		}
 	}
 

+ 1 - 22
proxy/http/server.go

@@ -6,7 +6,6 @@ import (
 	"encoding/base64"
 	"io"
 	"net/http"
-	"strconv"
 	"strings"
 	"time"
 
@@ -56,26 +55,6 @@ func (*Server) Network() []net.Network {
 	return []net.Network{net.Network_TCP}
 }
 
-func parseHost(rawHost string, defaultPort net.Port) (net.Destination, error) {
-	port := defaultPort
-	host, rawPort, err := net.SplitHostPort(rawHost)
-	if err != nil {
-		if addrError, ok := err.(*net.AddrError); ok && strings.Contains(addrError.Err, "missing port") {
-			host = rawHost
-		} else {
-			return net.Destination{}, err
-		}
-	} else if len(rawPort) > 0 {
-		intPort, err := strconv.Atoi(rawPort)
-		if err != nil {
-			return net.Destination{}, err
-		}
-		port = net.Port(intPort)
-	}
-
-	return net.TCPDestination(net.ParseAddress(host), port), nil
-}
-
 func isTimeout(err error) bool {
 	nerr, ok := errors.Cause(err).(net.Error)
 	return ok && nerr.Timeout()
@@ -139,7 +118,7 @@ Start:
 	if len(host) == 0 {
 		host = request.URL.Host
 	}
-	dest, err := parseHost(host, defaultPort)
+	dest, err := http_proto.ParseHost(host, defaultPort)
 	if err != nil {
 		return newError("malformed proxy host: ", host).AtWarning().Base(err)
 	}