|  | @@ -6,6 +6,7 @@ import (
 | 
											
												
													
														|  |  	"github.com/v2ray/v2ray-core/common/alloc"
 |  |  	"github.com/v2ray/v2ray-core/common/alloc"
 | 
											
												
													
														|  |  	"github.com/v2ray/v2ray-core/common/log"
 |  |  	"github.com/v2ray/v2ray-core/common/log"
 | 
											
												
													
														|  |  	v2net "github.com/v2ray/v2ray-core/common/net"
 |  |  	v2net "github.com/v2ray/v2ray-core/common/net"
 | 
											
												
													
														|  | 
 |  | +	"github.com/v2ray/v2ray-core/common/serial"
 | 
											
												
													
														|  |  	"github.com/v2ray/v2ray-core/transport"
 |  |  	"github.com/v2ray/v2ray-core/transport"
 | 
											
												
													
														|  |  )
 |  |  )
 | 
											
												
													
														|  |  
 |  |  
 | 
											
										
											
												
													
														|  | @@ -21,7 +22,7 @@ type Request struct {
 | 
											
												
													
														|  |  	OTA     bool
 |  |  	OTA     bool
 | 
											
												
													
														|  |  }
 |  |  }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -func ReadRequest(reader io.Reader) (*Request, error) {
 |  | 
 | 
											
												
													
														|  | 
 |  | +func ReadRequest(reader io.Reader, auth *Authenticator) (*Request, error) {
 | 
											
												
													
														|  |  	buffer := alloc.NewSmallBuffer()
 |  |  	buffer := alloc.NewSmallBuffer()
 | 
											
												
													
														|  |  	defer buffer.Release()
 |  |  	defer buffer.Release()
 | 
											
												
													
														|  |  
 |  |  
 | 
											
										
											
												
													
														|  | @@ -30,6 +31,7 @@ func ReadRequest(reader io.Reader) (*Request, error) {
 | 
											
												
													
														|  |  		log.Error("Shadowsocks: Failed to read address type: ", err)
 |  |  		log.Error("Shadowsocks: Failed to read address type: ", err)
 | 
											
												
													
														|  |  		return nil, transport.CorruptedPacket
 |  |  		return nil, transport.CorruptedPacket
 | 
											
												
													
														|  |  	}
 |  |  	}
 | 
											
												
													
														|  | 
 |  | +	lenBuffer := 1
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  	request := new(Request)
 |  |  	request := new(Request)
 | 
											
												
													
														|  |  
 |  |  
 | 
											
										
											
												
													
														|  | @@ -39,43 +41,64 @@ func ReadRequest(reader io.Reader) (*Request, error) {
 | 
											
												
													
														|  |  	}
 |  |  	}
 | 
											
												
													
														|  |  	switch addrType {
 |  |  	switch addrType {
 | 
											
												
													
														|  |  	case AddrTypeIPv4:
 |  |  	case AddrTypeIPv4:
 | 
											
												
													
														|  | -		_, err := io.ReadFull(reader, buffer.Value[:4])
 |  | 
 | 
											
												
													
														|  | 
 |  | +		_, err := io.ReadFull(reader, buffer.Value[lenBuffer:lenBuffer+4])
 | 
											
												
													
														|  |  		if err != nil {
 |  |  		if err != nil {
 | 
											
												
													
														|  |  			log.Error("Shadowsocks: Failed to read IPv4 address: ", err)
 |  |  			log.Error("Shadowsocks: Failed to read IPv4 address: ", err)
 | 
											
												
													
														|  |  			return nil, transport.CorruptedPacket
 |  |  			return nil, transport.CorruptedPacket
 | 
											
												
													
														|  |  		}
 |  |  		}
 | 
											
												
													
														|  | -		request.Address = v2net.IPAddress(buffer.Value[:4])
 |  | 
 | 
											
												
													
														|  | 
 |  | +		request.Address = v2net.IPAddress(buffer.Value[lenBuffer : lenBuffer+4])
 | 
											
												
													
														|  | 
 |  | +		lenBuffer += 4
 | 
											
												
													
														|  |  	case AddrTypeIPv6:
 |  |  	case AddrTypeIPv6:
 | 
											
												
													
														|  | -		_, err := io.ReadFull(reader, buffer.Value[:16])
 |  | 
 | 
											
												
													
														|  | 
 |  | +		_, err := io.ReadFull(reader, buffer.Value[lenBuffer:lenBuffer+16])
 | 
											
												
													
														|  |  		if err != nil {
 |  |  		if err != nil {
 | 
											
												
													
														|  |  			log.Error("Shadowsocks: Failed to read IPv6 address: ", err)
 |  |  			log.Error("Shadowsocks: Failed to read IPv6 address: ", err)
 | 
											
												
													
														|  |  			return nil, transport.CorruptedPacket
 |  |  			return nil, transport.CorruptedPacket
 | 
											
												
													
														|  |  		}
 |  |  		}
 | 
											
												
													
														|  | -		request.Address = v2net.IPAddress(buffer.Value[:16])
 |  | 
 | 
											
												
													
														|  | 
 |  | +		request.Address = v2net.IPAddress(buffer.Value[lenBuffer : lenBuffer+16])
 | 
											
												
													
														|  | 
 |  | +		lenBuffer += 16
 | 
											
												
													
														|  |  	case AddrTypeDomain:
 |  |  	case AddrTypeDomain:
 | 
											
												
													
														|  | -		_, err := io.ReadFull(reader, buffer.Value[:1])
 |  | 
 | 
											
												
													
														|  | 
 |  | +		_, err := io.ReadFull(reader, buffer.Value[lenBuffer:lenBuffer+1])
 | 
											
												
													
														|  |  		if err != nil {
 |  |  		if err != nil {
 | 
											
												
													
														|  |  			log.Error("Shadowsocks: Failed to read domain lenth: ", err)
 |  |  			log.Error("Shadowsocks: Failed to read domain lenth: ", err)
 | 
											
												
													
														|  |  			return nil, transport.CorruptedPacket
 |  |  			return nil, transport.CorruptedPacket
 | 
											
												
													
														|  |  		}
 |  |  		}
 | 
											
												
													
														|  | -		domainLength := int(buffer.Value[0])
 |  | 
 | 
											
												
													
														|  | -		_, err = io.ReadFull(reader, buffer.Value[:domainLength])
 |  | 
 | 
											
												
													
														|  | 
 |  | +		domainLength := int(buffer.Value[lenBuffer])
 | 
											
												
													
														|  | 
 |  | +		lenBuffer++
 | 
											
												
													
														|  | 
 |  | +		_, err = io.ReadFull(reader, buffer.Value[lenBuffer:lenBuffer+domainLength])
 | 
											
												
													
														|  |  		if err != nil {
 |  |  		if err != nil {
 | 
											
												
													
														|  |  			log.Error("Shadowsocks: Failed to read domain: ", err)
 |  |  			log.Error("Shadowsocks: Failed to read domain: ", err)
 | 
											
												
													
														|  |  			return nil, transport.CorruptedPacket
 |  |  			return nil, transport.CorruptedPacket
 | 
											
												
													
														|  |  		}
 |  |  		}
 | 
											
												
													
														|  | -		request.Address = v2net.DomainAddress(string(buffer.Value[:domainLength]))
 |  | 
 | 
											
												
													
														|  | 
 |  | +		request.Address = v2net.DomainAddress(string(buffer.Value[lenBuffer : lenBuffer+domainLength]))
 | 
											
												
													
														|  | 
 |  | +		lenBuffer += domainLength
 | 
											
												
													
														|  |  	default:
 |  |  	default:
 | 
											
												
													
														|  |  		log.Error("Shadowsocks: Unknown address type: ", addrType)
 |  |  		log.Error("Shadowsocks: Unknown address type: ", addrType)
 | 
											
												
													
														|  |  		return nil, transport.CorruptedPacket
 |  |  		return nil, transport.CorruptedPacket
 | 
											
												
													
														|  |  	}
 |  |  	}
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -	_, err = io.ReadFull(reader, buffer.Value[:2])
 |  | 
 | 
											
												
													
														|  | 
 |  | +	_, err = io.ReadFull(reader, buffer.Value[lenBuffer:lenBuffer+2])
 | 
											
												
													
														|  |  	if err != nil {
 |  |  	if err != nil {
 | 
											
												
													
														|  |  		log.Error("Shadowsocks: Failed to read port: ", err)
 |  |  		log.Error("Shadowsocks: Failed to read port: ", err)
 | 
											
												
													
														|  |  		return nil, transport.CorruptedPacket
 |  |  		return nil, transport.CorruptedPacket
 | 
											
												
													
														|  |  	}
 |  |  	}
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -	request.Port = v2net.PortFromBytes(buffer.Value[:2])
 |  | 
 | 
											
												
													
														|  | 
 |  | +	request.Port = v2net.PortFromBytes(buffer.Value[lenBuffer : lenBuffer+2])
 | 
											
												
													
														|  | 
 |  | +	lenBuffer += 2
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	if request.OTA {
 | 
											
												
													
														|  | 
 |  | +		authBytes := buffer.Value[lenBuffer : lenBuffer+auth.AuthSize()]
 | 
											
												
													
														|  | 
 |  | +		_, err = io.ReadFull(reader, authBytes)
 | 
											
												
													
														|  | 
 |  | +		if err != nil {
 | 
											
												
													
														|  | 
 |  | +			log.Error("Shadowsocks: Failed to read OTA: ", err)
 | 
											
												
													
														|  | 
 |  | +			return nil, transport.CorruptedPacket
 | 
											
												
													
														|  | 
 |  | +		}
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +		actualAuth := auth.Authenticate(nil, buffer.Value[0:lenBuffer])
 | 
											
												
													
														|  | 
 |  | +		if !serial.BytesLiteral(actualAuth).Equals(serial.BytesLiteral(authBytes)) {
 | 
											
												
													
														|  | 
 |  | +			log.Error("Shadowsocks: Invalid OTA: ", actualAuth)
 | 
											
												
													
														|  | 
 |  | +			return nil, transport.CorruptedPacket
 | 
											
												
													
														|  | 
 |  | +		}
 | 
											
												
													
														|  | 
 |  | +	}
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  |  	return request, nil
 |  |  	return request, nil
 | 
											
												
													
														|  |  }
 |  |  }
 |