소스 검색

fix domain length issue for all proxies

Darien Raymond 8 년 전
부모
커밋
26818a2602

+ 3 - 4
app/proxyman/mux/frame.go

@@ -81,11 +81,10 @@ func (f FrameMetadata) AsSupplier() buf.Supplier {
 				length += 17
 			case net.AddressFamilyDomain:
 				domain := addr.Domain()
-				nDomain := len(domain)
-				if nDomain > 256 {
-					nDomain = 256
-					domain = domain[:256]
+				if protocol.IsDomainTooLong(domain) {
+					return 0, newError("domain name too long: ", domain)
 				}
+				nDomain := len(domain)
 				b = append(b, byte(protocol.AddressTypeDomain), byte(nDomain))
 				b = append(b, domain...)
 				length += nDomain + 2

+ 4 - 0
common/protocol/headers.go

@@ -97,3 +97,7 @@ func (sc *SecurityConfig) AsSecurity() Security {
 	}
 	return NormSecurity(Security(sc.Type))
 }
+
+func IsDomainTooLong(domain string) bool {
+	return len(domain) > 256
+}

+ 9 - 4
proxy/shadowsocks/protocol.go

@@ -5,6 +5,7 @@ import (
 	"crypto/rand"
 	"io"
 
+	"v2ray.com/core/common"
 	"v2ray.com/core/common/bitmask"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/crypto"
@@ -160,19 +161,23 @@ func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Wri
 		header.AppendBytes(AddrTypeIPv6)
 		header.Append([]byte(request.Address.IP()))
 	case net.AddressFamilyDomain:
-		header.AppendBytes(AddrTypeDomain, byte(len(request.Address.Domain())))
-		header.Append([]byte(request.Address.Domain()))
+		domain := request.Address.Domain()
+		if protocol.IsDomainTooLong(domain) {
+			return nil, newError("domain name too long: ", domain)
+		}
+		header.AppendBytes(AddrTypeDomain, byte(len(domain)))
+		common.Must(header.AppendSupplier(serial.WriteString(domain)))
 	default:
 		return nil, newError("unsupported address type: ", request.Address.Family())
 	}
 
-	header.AppendSupplier(serial.WriteUint16(uint16(request.Port)))
+	common.Must(header.AppendSupplier(serial.WriteUint16(uint16(request.Port))))
 
 	if request.Option.Has(RequestOptionOneTimeAuth) {
 		header.SetByte(0, header.Byte(0)|0x10)
 
 		authenticator := NewAuthenticator(HeaderKeyGenerator(account.Key, iv))
-		header.AppendSupplier(authenticator.Authenticate(header.Bytes()))
+		common.Must(header.AppendSupplier(authenticator.Authenticate(header.Bytes())))
 	}
 
 	_, err = writer.Write(header.Bytes())

+ 5 - 5
proxy/socks/protocol.go

@@ -3,6 +3,7 @@ package socks
 import (
 	"io"
 
+	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
@@ -253,14 +254,13 @@ func appendAddress(buffer *buf.Buffer, address net.Address, port net.Port) error
 		buffer.AppendBytes(0x04)
 		buffer.Append(address.IP())
 	case net.AddressFamilyDomain:
-		n := byte(len(address.Domain()))
-		if int(n) != len(address.Domain()) {
-			return newError("Super long domain is not supported in Socks protocol. ", address.Domain())
+		if protocol.IsDomainTooLong(address.Domain()) {
+			return newError("Super long domain is not supported in Socks protocol: ", address.Domain())
 		}
 		buffer.AppendBytes(0x03, byte(len(address.Domain())))
-		buffer.AppendSupplier(serial.WriteString(address.Domain()))
+		common.Must(buffer.AppendSupplier(serial.WriteString(address.Domain())))
 	}
-	buffer.AppendSupplier(serial.WriteUint16(port.Value()))
+	common.Must(buffer.AppendSupplier(serial.WriteUint16(port.Value())))
 	return nil
 }
 

+ 10 - 4
proxy/vmess/encoding/client.go

@@ -60,12 +60,12 @@ func NewClientSession(idHash protocol.IDHash) *ClientSession {
 	return session
 }
 
-func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) {
+func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) error {
 	timestamp := protocol.NewTimestampGenerator(protocol.NowTime(), 30)()
 	account, err := header.User.GetTypedAccount()
 	if err != nil {
 		log.Trace(newError("failed to get user account: ", err).AtError())
-		return
+		return nil
 	}
 	idHash := c.idHash(account.(*vmess.InternalAccount).AnyValidID().Bytes())
 	common.Must2(idHash.Write(timestamp.Bytes(nil)))
@@ -95,8 +95,13 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ
 			buffer = append(buffer, byte(protocol.AddressTypeIPv6))
 			buffer = append(buffer, header.Address.IP()...)
 		case net.AddressFamilyDomain:
-			buffer = append(buffer, byte(protocol.AddressTypeDomain), byte(len(header.Address.Domain())))
-			buffer = append(buffer, header.Address.Domain()...)
+			domain := header.Address.Domain()
+			if protocol.IsDomainTooLong(domain) {
+				return newError("long domain not supported: ", domain)
+			}
+			nDomain := len(domain)
+			buffer = append(buffer, byte(protocol.AddressTypeDomain), byte(nDomain))
+			buffer = append(buffer, domain...)
 		}
 	}
 
@@ -117,6 +122,7 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ
 	aesStream := crypto.NewAesEncryptionStream(account.(*vmess.InternalAccount).ID.CmdKey(), iv)
 	aesStream.XORKeyStream(buffer, buffer)
 	common.Must2(writer.Write(buffer))
+	return nil
 }
 
 func (c *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, writer io.Writer) buf.Writer {

+ 2 - 1
proxy/vmess/encoding/encoding_test.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	"testing"
 
+	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
@@ -38,7 +39,7 @@ func TestRequestSerialization(t *testing.T) {
 
 	buffer := buf.New()
 	client := NewClientSession(protocol.DefaultIDHash)
-	client.EncodeRequestHeader(expectedRequest, buffer)
+	common.Must(client.EncodeRequestHeader(expectedRequest, buffer))
 
 	buffer2 := buf.New()
 	buffer2.Append(buffer.Bytes())

+ 3 - 1
proxy/vmess/outbound/outbound.go

@@ -108,7 +108,9 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
 
 	requestDone := signal.ExecuteAsync(func() error {
 		writer := buf.NewBufferedWriter(conn)
-		session.EncodeRequestHeader(request, writer)
+		if err := session.EncodeRequestHeader(request, writer); err != nil {
+			return newError("failed to encode request").Base(err).AtWarning()
+		}
 
 		bodyWriter := session.EncodeRequestBody(request, writer)
 		firstPayload, err := input.ReadTimeout(time.Millisecond * 500)