|  | @@ -9,6 +9,8 @@ import (
 | 
	
		
			
				|  |  |  	"sync"
 | 
	
		
			
				|  |  |  	"time"
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +	"v2ray.com/core/common/dice"
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  	"golang.org/x/crypto/chacha20poly1305"
 | 
	
		
			
				|  |  |  	"v2ray.com/core/common"
 | 
	
		
			
				|  |  |  	"v2ray.com/core/common/bitmask"
 | 
	
	
		
			
				|  | @@ -103,6 +105,44 @@ func NewServerSession(validator protocol.UserValidator, sessionHistory *SessionH
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +func readAddress(buffer *buf.Buffer, reader io.Reader) (net.Address, net.Port, error) {
 | 
	
		
			
				|  |  | +	var address net.Address
 | 
	
		
			
				|  |  | +	var port net.Port
 | 
	
		
			
				|  |  | +	if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 3)); err != nil {
 | 
	
		
			
				|  |  | +		return address, port, newError("failed to read port and address type").Base(err)
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +	port = net.PortFromBytes(buffer.BytesRange(-3, -1))
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	addressType := protocol.AddressType(buffer.Byte(buffer.Len() - 1))
 | 
	
		
			
				|  |  | +	switch addressType {
 | 
	
		
			
				|  |  | +	case protocol.AddressTypeIPv4:
 | 
	
		
			
				|  |  | +		if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 4)); err != nil {
 | 
	
		
			
				|  |  | +			return address, port, newError("failed to read IPv4 address").Base(err)
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +		address = net.IPAddress(buffer.BytesFrom(-4))
 | 
	
		
			
				|  |  | +	case protocol.AddressTypeIPv6:
 | 
	
		
			
				|  |  | +		if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 16)); err != nil {
 | 
	
		
			
				|  |  | +			return address, port, newError("failed to read IPv6 address").Base(err)
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +		address = net.IPAddress(buffer.BytesFrom(-16))
 | 
	
		
			
				|  |  | +	case protocol.AddressTypeDomain:
 | 
	
		
			
				|  |  | +		if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 1)); err != nil {
 | 
	
		
			
				|  |  | +			return address, port, newError("failed to read domain address").Base(err)
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +		domainLength := int(buffer.Byte(buffer.Len() - 1))
 | 
	
		
			
				|  |  | +		if domainLength == 0 {
 | 
	
		
			
				|  |  | +			return address, port, newError("zero length domain")
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +		if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, domainLength)); err != nil {
 | 
	
		
			
				|  |  | +			return address, port, newError("failed to read domain address").Base(err)
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +		address = net.DomainAddress(string(buffer.BytesFrom(-domainLength)))
 | 
	
		
			
				|  |  | +	default:
 | 
	
		
			
				|  |  | +		return address, port, newError("invalid address type", addressType)
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +	return address, port, nil
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.RequestHeader, error) {
 | 
	
		
			
				|  |  |  	buffer := buf.New()
 | 
	
		
			
				|  |  |  	defer buffer.Release()
 | 
	
	
		
			
				|  | @@ -128,7 +168,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
 | 
	
		
			
				|  |  |  	aesStream := crypto.NewAesDecryptionStream(vmessAccount.ID.CmdKey(), iv)
 | 
	
		
			
				|  |  |  	decryptor := crypto.NewCryptionReader(aesStream, reader)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -	if err := buffer.Reset(buf.ReadFullFrom(decryptor, 41)); err != nil {
 | 
	
		
			
				|  |  | +	if err := buffer.Reset(buf.ReadFullFrom(decryptor, 38)); err != nil {
 | 
	
		
			
				|  |  |  		return nil, newError("failed to read request header").Base(err)
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -137,10 +177,6 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
 | 
	
		
			
				|  |  |  		Version: buffer.Byte(0),
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -	if request.Version != Version {
 | 
	
		
			
				|  |  | -		return nil, newError("invalid protocol version ", request.Version)
 | 
	
		
			
				|  |  | -	}
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |  	s.requestBodyIV = append([]byte(nil), buffer.BytesRange(1, 17)...)   // 16 bytes
 | 
	
		
			
				|  |  |  	s.requestBodyKey = append([]byte(nil), buffer.BytesRange(17, 33)...) // 16 bytes
 | 
	
		
			
				|  |  |  	var sid sessionId
 | 
	
	
		
			
				|  | @@ -159,33 +195,28 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
 | 
	
		
			
				|  |  |  	// 1 bytes reserved
 | 
	
		
			
				|  |  |  	request.Command = protocol.RequestCommand(buffer.Byte(37))
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -	if request.Command != protocol.RequestCommandMux {
 | 
	
		
			
				|  |  | -		request.Port = net.PortFromBytes(buffer.BytesRange(38, 40))
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -		switch protocol.AddressType(buffer.Byte(40)) {
 | 
	
		
			
				|  |  | -		case protocol.AddressTypeIPv4:
 | 
	
		
			
				|  |  | -			if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, 4)); err != nil {
 | 
	
		
			
				|  |  | -				return nil, newError("failed to read IPv4 address").Base(err)
 | 
	
		
			
				|  |  | -			}
 | 
	
		
			
				|  |  | -			request.Address = net.IPAddress(buffer.BytesFrom(-4))
 | 
	
		
			
				|  |  | -		case protocol.AddressTypeIPv6:
 | 
	
		
			
				|  |  | -			if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, 16)); err != nil {
 | 
	
		
			
				|  |  | -				return nil, newError("failed to read IPv6 address").Base(err)
 | 
	
		
			
				|  |  | -			}
 | 
	
		
			
				|  |  | -			request.Address = net.IPAddress(buffer.BytesFrom(-16))
 | 
	
		
			
				|  |  | -		case protocol.AddressTypeDomain:
 | 
	
		
			
				|  |  | -			if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, 1)); err != nil {
 | 
	
		
			
				|  |  | -				return nil, newError("failed to read domain address").Base(err)
 | 
	
		
			
				|  |  | -			}
 | 
	
		
			
				|  |  | -			domainLength := int(buffer.Byte(buffer.Len() - 1))
 | 
	
		
			
				|  |  | -			if domainLength == 0 {
 | 
	
		
			
				|  |  | -				return nil, newError("zero length domain").Base(err)
 | 
	
		
			
				|  |  | -			}
 | 
	
		
			
				|  |  | -			if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, domainLength)); err != nil {
 | 
	
		
			
				|  |  | -				return nil, newError("failed to read domain address").Base(err)
 | 
	
		
			
				|  |  | -			}
 | 
	
		
			
				|  |  | -			request.Address = net.DomainAddress(string(buffer.BytesFrom(-domainLength)))
 | 
	
		
			
				|  |  | +	invalidRequest := false
 | 
	
		
			
				|  |  | +	switch request.Command {
 | 
	
		
			
				|  |  | +	case protocol.RequestCommandMux:
 | 
	
		
			
				|  |  | +		request.Address = net.DomainAddress("v1.mux.cool")
 | 
	
		
			
				|  |  | +		request.Port = 0
 | 
	
		
			
				|  |  | +	case protocol.RequestCommandTCP, protocol.RequestCommandUDP:
 | 
	
		
			
				|  |  | +		if addr, port, err := readAddress(buffer, decryptor); err == nil {
 | 
	
		
			
				|  |  | +			request.Address = addr
 | 
	
		
			
				|  |  | +			request.Port = port
 | 
	
		
			
				|  |  | +		} else {
 | 
	
		
			
				|  |  | +			invalidRequest = true
 | 
	
		
			
				|  |  | +			newError("failed to read address").Base(err).WriteToLog()
 | 
	
		
			
				|  |  |  		}
 | 
	
		
			
				|  |  | +	default:
 | 
	
		
			
				|  |  | +		invalidRequest = true
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	if invalidRequest {
 | 
	
		
			
				|  |  | +		randomLen := dice.Roll(32)
 | 
	
		
			
				|  |  | +		// Read random number of bytes for prevent detection.
 | 
	
		
			
				|  |  | +		buffer.AppendSupplier(buf.ReadFullFrom(decryptor, randomLen))
 | 
	
		
			
				|  |  | +		return nil, newError("invalid request")
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	if padingLen > 0 {
 |