|  | @@ -115,14 +115,14 @@ func NewServerSession(validator protocol.UserValidator, sessionHistory *SessionH
 | 
											
												
													
														|  |  }
 |  |  }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.RequestHeader, error) {
 |  |  func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.RequestHeader, error) {
 | 
											
												
													
														|  | -	buffer := make([]byte, 512)
 |  | 
 | 
											
												
													
														|  | 
 |  | +	buffer := buf.New()
 | 
											
												
													
														|  | 
 |  | +	defer buffer.Release()
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -	_, err := io.ReadFull(reader, buffer[:protocol.IDBytesLen])
 |  | 
 | 
											
												
													
														|  | -	if err != nil {
 |  | 
 | 
											
												
													
														|  | 
 |  | +	if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, protocol.IDBytesLen)); err != nil {
 | 
											
												
													
														|  |  		return nil, newError("failed to read request header").Base(err)
 |  |  		return nil, newError("failed to read request header").Base(err)
 | 
											
												
													
														|  |  	}
 |  |  	}
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -	user, timestamp, valid := s.userValidator.Get(buffer[:protocol.IDBytesLen])
 |  | 
 | 
											
												
													
														|  | 
 |  | +	user, timestamp, valid := s.userValidator.Get(buffer.Bytes())
 | 
											
												
													
														|  |  	if !valid {
 |  |  	if !valid {
 | 
											
												
													
														|  |  		return nil, newError("invalid user")
 |  |  		return nil, newError("invalid user")
 | 
											
												
													
														|  |  	}
 |  |  	}
 | 
											
										
											
												
													
														|  | @@ -139,23 +139,21 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
 | 
											
												
													
														|  |  	aesStream := crypto.NewAesDecryptionStream(vmessAccount.ID.CmdKey(), iv)
 |  |  	aesStream := crypto.NewAesDecryptionStream(vmessAccount.ID.CmdKey(), iv)
 | 
											
												
													
														|  |  	decryptor := crypto.NewCryptionReader(aesStream, reader)
 |  |  	decryptor := crypto.NewCryptionReader(aesStream, reader)
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -	nBytes, err := io.ReadFull(decryptor, buffer[:41])
 |  | 
 | 
											
												
													
														|  | -	if err != nil {
 |  | 
 | 
											
												
													
														|  | 
 |  | +	if err := buffer.Reset(buf.ReadFullFrom(decryptor, 41)); err != nil {
 | 
											
												
													
														|  |  		return nil, newError("failed to read request header").Base(err)
 |  |  		return nil, newError("failed to read request header").Base(err)
 | 
											
												
													
														|  |  	}
 |  |  	}
 | 
											
												
													
														|  | -	bufferLen := nBytes
 |  | 
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  	request := &protocol.RequestHeader{
 |  |  	request := &protocol.RequestHeader{
 | 
											
												
													
														|  |  		User:    user,
 |  |  		User:    user,
 | 
											
												
													
														|  | -		Version: buffer[0],
 |  | 
 | 
											
												
													
														|  | 
 |  | +		Version: buffer.Byte(0),
 | 
											
												
													
														|  |  	}
 |  |  	}
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  	if request.Version != Version {
 |  |  	if request.Version != Version {
 | 
											
												
													
														|  |  		return nil, newError("invalid protocol version ", request.Version)
 |  |  		return nil, newError("invalid protocol version ", request.Version)
 | 
											
												
													
														|  |  	}
 |  |  	}
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -	s.requestBodyIV = append([]byte(nil), buffer[1:17]...)   // 16 bytes
 |  | 
 | 
											
												
													
														|  | -	s.requestBodyKey = append([]byte(nil), buffer[17:33]...) // 16 bytes
 |  | 
 | 
											
												
													
														|  | 
 |  | +	s.requestBodyIV = append([]byte(nil), buffer.BytesRange(1, 17)...)   // 16 bytes
 | 
											
												
													
														|  | 
 |  | +	s.requestBodyKey = append([]byte(nil), buffer.BytesRange(17, 33)...) // 16 bytes
 | 
											
												
													
														|  |  	var sid sessionId
 |  |  	var sid sessionId
 | 
											
												
													
														|  |  	copy(sid.user[:], vmessAccount.ID.Bytes())
 |  |  	copy(sid.user[:], vmessAccount.ID.Bytes())
 | 
											
												
													
														|  |  	copy(sid.key[:], s.requestBodyKey)
 |  |  	copy(sid.key[:], s.requestBodyKey)
 | 
											
										
											
												
													
														|  | @@ -165,66 +163,56 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
 | 
											
												
													
														|  |  	}
 |  |  	}
 | 
											
												
													
														|  |  	s.sessionHistory.add(sid)
 |  |  	s.sessionHistory.add(sid)
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -	s.responseHeader = buffer[33]             // 1 byte
 |  | 
 | 
											
												
													
														|  | -	request.Option = bitmask.Byte(buffer[34]) // 1 byte
 |  | 
 | 
											
												
													
														|  | -	padingLen := int(buffer[35] >> 4)
 |  | 
 | 
											
												
													
														|  | -	request.Security = protocol.NormSecurity(protocol.Security(buffer[35] & 0x0F))
 |  | 
 | 
											
												
													
														|  | 
 |  | +	s.responseHeader = buffer.Byte(33)             // 1 byte
 | 
											
												
													
														|  | 
 |  | +	request.Option = bitmask.Byte(buffer.Byte(34)) // 1 byte
 | 
											
												
													
														|  | 
 |  | +	padingLen := int(buffer.Byte(35) >> 4)
 | 
											
												
													
														|  | 
 |  | +	request.Security = protocol.NormSecurity(protocol.Security(buffer.Byte(35) & 0x0F))
 | 
											
												
													
														|  |  	// 1 bytes reserved
 |  |  	// 1 bytes reserved
 | 
											
												
													
														|  | -	request.Command = protocol.RequestCommand(buffer[37])
 |  | 
 | 
											
												
													
														|  | 
 |  | +	request.Command = protocol.RequestCommand(buffer.Byte(37))
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  	if request.Command != protocol.RequestCommandMux {
 |  |  	if request.Command != protocol.RequestCommandMux {
 | 
											
												
													
														|  | -		request.Port = net.PortFromBytes(buffer[38:40])
 |  | 
 | 
											
												
													
														|  | 
 |  | +		request.Port = net.PortFromBytes(buffer.BytesRange(38, 40))
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -		switch protocol.AddressType(buffer[40]) {
 |  | 
 | 
											
												
													
														|  | 
 |  | +		switch protocol.AddressType(buffer.Byte(40)) {
 | 
											
												
													
														|  |  		case protocol.AddressTypeIPv4:
 |  |  		case protocol.AddressTypeIPv4:
 | 
											
												
													
														|  | -			_, err = io.ReadFull(decryptor, buffer[41:45]) // 4 bytes
 |  | 
 | 
											
												
													
														|  | -			bufferLen += 4
 |  | 
 | 
											
												
													
														|  | -			if err != nil {
 |  | 
 | 
											
												
													
														|  | 
 |  | +			if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, 4)); err != nil {
 | 
											
												
													
														|  |  				return nil, newError("failed to read IPv4 address").Base(err)
 |  |  				return nil, newError("failed to read IPv4 address").Base(err)
 | 
											
												
													
														|  |  			}
 |  |  			}
 | 
											
												
													
														|  | -			request.Address = net.IPAddress(buffer[41:45])
 |  | 
 | 
											
												
													
														|  | 
 |  | +			request.Address = net.IPAddress(buffer.BytesFrom(-4))
 | 
											
												
													
														|  |  		case protocol.AddressTypeIPv6:
 |  |  		case protocol.AddressTypeIPv6:
 | 
											
												
													
														|  | -			_, err = io.ReadFull(decryptor, buffer[41:57]) // 16 bytes
 |  | 
 | 
											
												
													
														|  | -			bufferLen += 16
 |  | 
 | 
											
												
													
														|  | -			if err != nil {
 |  | 
 | 
											
												
													
														|  | 
 |  | +			if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, 16)); err != nil {
 | 
											
												
													
														|  |  				return nil, newError("failed to read IPv6 address").Base(err)
 |  |  				return nil, newError("failed to read IPv6 address").Base(err)
 | 
											
												
													
														|  |  			}
 |  |  			}
 | 
											
												
													
														|  | -			request.Address = net.IPAddress(buffer[41:57])
 |  | 
 | 
											
												
													
														|  | 
 |  | +			request.Address = net.IPAddress(buffer.BytesFrom(-16))
 | 
											
												
													
														|  |  		case protocol.AddressTypeDomain:
 |  |  		case protocol.AddressTypeDomain:
 | 
											
												
													
														|  | -			_, err = io.ReadFull(decryptor, buffer[41:42])
 |  | 
 | 
											
												
													
														|  | -			if err != nil {
 |  | 
 | 
											
												
													
														|  | 
 |  | +			if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, 1)); err != nil {
 | 
											
												
													
														|  |  				return nil, newError("failed to read domain address").Base(err)
 |  |  				return nil, newError("failed to read domain address").Base(err)
 | 
											
												
													
														|  |  			}
 |  |  			}
 | 
											
												
													
														|  | -			domainLength := int(buffer[41])
 |  | 
 | 
											
												
													
														|  | 
 |  | +			domainLength := int(buffer.Byte(buffer.Len() - 1))
 | 
											
												
													
														|  |  			if domainLength == 0 {
 |  |  			if domainLength == 0 {
 | 
											
												
													
														|  |  				return nil, newError("zero length domain").Base(err)
 |  |  				return nil, newError("zero length domain").Base(err)
 | 
											
												
													
														|  |  			}
 |  |  			}
 | 
											
												
													
														|  | -			_, err = io.ReadFull(decryptor, buffer[42:42+domainLength])
 |  | 
 | 
											
												
													
														|  | -			if err != nil {
 |  | 
 | 
											
												
													
														|  | 
 |  | +			if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, domainLength)); err != nil {
 | 
											
												
													
														|  |  				return nil, newError("failed to read domain address").Base(err)
 |  |  				return nil, newError("failed to read domain address").Base(err)
 | 
											
												
													
														|  |  			}
 |  |  			}
 | 
											
												
													
														|  | -			bufferLen += 1 + domainLength
 |  | 
 | 
											
												
													
														|  | -			request.Address = net.DomainAddress(string(buffer[42 : 42+domainLength]))
 |  | 
 | 
											
												
													
														|  | 
 |  | +			request.Address = net.DomainAddress(string(buffer.BytesFrom(-domainLength)))
 | 
											
												
													
														|  |  		}
 |  |  		}
 | 
											
												
													
														|  |  	}
 |  |  	}
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  	if padingLen > 0 {
 |  |  	if padingLen > 0 {
 | 
											
												
													
														|  | -		_, err = io.ReadFull(decryptor, buffer[bufferLen:bufferLen+padingLen])
 |  | 
 | 
											
												
													
														|  | -		if err != nil {
 |  | 
 | 
											
												
													
														|  | 
 |  | +		if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, padingLen)); err != nil {
 | 
											
												
													
														|  |  			return nil, newError("failed to read padding").Base(err)
 |  |  			return nil, newError("failed to read padding").Base(err)
 | 
											
												
													
														|  |  		}
 |  |  		}
 | 
											
												
													
														|  | -		bufferLen += padingLen
 |  | 
 | 
											
												
													
														|  |  	}
 |  |  	}
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -	_, err = io.ReadFull(decryptor, buffer[bufferLen:bufferLen+4])
 |  | 
 | 
											
												
													
														|  | -	if err != nil {
 |  | 
 | 
											
												
													
														|  | 
 |  | +	if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, 4)); err != nil {
 | 
											
												
													
														|  |  		return nil, newError("failed to read checksum").Base(err)
 |  |  		return nil, newError("failed to read checksum").Base(err)
 | 
											
												
													
														|  |  	}
 |  |  	}
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  	fnv1a := fnv.New32a()
 |  |  	fnv1a := fnv.New32a()
 | 
											
												
													
														|  | -	common.Must2(fnv1a.Write(buffer[:bufferLen]))
 |  | 
 | 
											
												
													
														|  | 
 |  | +	common.Must2(fnv1a.Write(buffer.BytesTo(-4)))
 | 
											
												
													
														|  |  	actualHash := fnv1a.Sum32()
 |  |  	actualHash := fnv1a.Sum32()
 | 
											
												
													
														|  | -	expectedHash := serial.BytesToUint32(buffer[bufferLen : bufferLen+4])
 |  | 
 | 
											
												
													
														|  | 
 |  | +	expectedHash := serial.BytesToUint32(buffer.BytesFrom(-4))
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  	if actualHash != expectedHash {
 |  |  	if actualHash != expectedHash {
 | 
											
												
													
														|  |  		return nil, newError("invalid auth")
 |  |  		return nil, newError("invalid auth")
 |