فهرست منبع

rewrite vmess encoding using buf

Darien Raymond 8 سال پیش
والد
کامیت
02685094d3
2فایلهای تغییر یافته به همراه63 افزوده شده و 72 حذف شده
  1. 36 33
      proxy/vmess/encoding/client.go
  2. 27 39
      proxy/vmess/encoding/server.go

+ 36 - 33
proxy/vmess/encoding/client.go

@@ -71,57 +71,61 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ
 	common.Must2(idHash.Write(timestamp.Bytes(nil)))
 	common.Must2(writer.Write(idHash.Sum(nil)))
 
-	buffer := make([]byte, 0, 512)
-	buffer = append(buffer, Version)
-	buffer = append(buffer, c.requestBodyIV...)
-	buffer = append(buffer, c.requestBodyKey...)
-	buffer = append(buffer, c.responseHeader, byte(header.Option))
+	buffer := buf.New()
+	defer buffer.Release()
+
+	buffer.AppendBytes(Version)
+	buffer.Append(c.requestBodyIV)
+	buffer.Append(c.requestBodyKey)
+	buffer.AppendBytes(c.responseHeader, byte(header.Option))
+
 	padingLen := dice.Roll(16)
 	if header.Security.Is(protocol.SecurityType_LEGACY) {
 		// Disable padding in legacy mode for a smooth transition.
 		padingLen = 0
 	}
 	security := byte(padingLen<<4) | byte(header.Security)
-	buffer = append(buffer, security, byte(0), byte(header.Command))
+	buffer.AppendBytes(security, byte(0), byte(header.Command))
 
 	if header.Command != protocol.RequestCommandMux {
-		buffer = header.Port.Bytes(buffer)
+		common.Must(buffer.AppendSupplier(serial.WriteUint16(header.Port.Value())))
 
 		switch header.Address.Family() {
 		case net.AddressFamilyIPv4:
-			buffer = append(buffer, byte(protocol.AddressTypeIPv4))
-			buffer = append(buffer, header.Address.IP()...)
+			buffer.AppendBytes(byte(protocol.AddressTypeIPv4))
+			buffer.Append(header.Address.IP())
 		case net.AddressFamilyIPv6:
-			buffer = append(buffer, byte(protocol.AddressTypeIPv6))
-			buffer = append(buffer, header.Address.IP()...)
+			buffer.AppendBytes(byte(protocol.AddressTypeIPv6))
+			buffer.Append(header.Address.IP())
 		case net.AddressFamilyDomain:
 			domain := header.Address.Domain()
 			if protocol.IsDomainTooLong(domain) {
 				return newError("long domain not supported: ", domain)
 			}
 			nDomain := len(domain)
-			buffer = append(buffer, byte(protocol.AddressTypeDomain), byte(nDomain))
-			buffer = append(buffer, domain...)
+			buffer.AppendBytes(byte(protocol.AddressTypeDomain), byte(nDomain))
+			common.Must(buffer.AppendSupplier(serial.WriteString(domain)))
 		}
 	}
 
 	if padingLen > 0 {
-		pading := make([]byte, padingLen)
-		common.Must2(rand.Read(pading))
-		buffer = append(buffer, pading...)
+		common.Must(buffer.AppendSupplier(buf.ReadFullFrom(rand.Reader, padingLen)))
 	}
 
 	fnv1a := fnv.New32a()
-	common.Must2(fnv1a.Write(buffer))
+	common.Must2(fnv1a.Write(buffer.Bytes()))
 
-	buffer = fnv1a.Sum(buffer)
+	common.Must(buffer.AppendSupplier(func(b []byte) (int, error) {
+		fnv1a.Sum(b[:0])
+		return fnv1a.Size(), nil
+	}))
 
 	timestampHash := md5.New()
 	common.Must2(timestampHash.Write(hashTimestamp(timestamp)))
 	iv := timestampHash.Sum(nil)
 	aesStream := crypto.NewAesEncryptionStream(account.(*vmess.InternalAccount).ID.CmdKey(), iv)
-	aesStream.XORKeyStream(buffer, buffer)
-	common.Must2(writer.Write(buffer))
+	aesStream.XORKeyStream(buffer.Bytes(), buffer.Bytes())
+	common.Must2(writer.Write(buffer.Bytes()))
 	return nil
 }
 
@@ -197,32 +201,31 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon
 	aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey, c.responseBodyIV)
 	c.responseReader = crypto.NewCryptionReader(aesStream, reader)
 
-	buffer := make([]byte, 256)
+	buffer := buf.New()
+	defer buffer.Release()
 
-	_, err := io.ReadFull(c.responseReader, buffer[:4])
-	if err != nil {
+	if err := buffer.AppendSupplier(buf.ReadFullFrom(c.responseReader, 4)); err != nil {
 		log.Trace(newError("failed to read response header").Base(err))
 		return nil, err
 	}
 
-	if buffer[0] != c.responseHeader {
-		return nil, newError("unexpected response header. Expecting ", int(c.responseHeader), " but actually ", int(buffer[0]))
+	if buffer.Byte(0) != c.responseHeader {
+		return nil, newError("unexpected response header. Expecting ", int(c.responseHeader), " but actually ", int(buffer.Byte(0)))
 	}
 
 	header := &protocol.ResponseHeader{
-		Option: bitmask.Byte(buffer[1]),
+		Option: bitmask.Byte(buffer.Byte(1)),
 	}
 
-	if buffer[2] != 0 {
-		cmdID := buffer[2]
-		dataLen := int(buffer[3])
-		_, err := io.ReadFull(c.responseReader, buffer[:dataLen])
-		if err != nil {
+	if buffer.Byte(2) != 0 {
+		cmdID := buffer.Byte(2)
+		dataLen := int(buffer.Byte(3))
+
+		if err := buffer.Reset(buf.ReadFullFrom(c.responseReader, dataLen)); err != nil {
 			log.Trace(newError("failed to read response command").Base(err))
 			return nil, err
 		}
-		data := buffer[:dataLen]
-		command, err := UnmarshalCommand(cmdID, data)
+		command, err := UnmarshalCommand(cmdID, buffer.Bytes())
 		if err == nil {
 			header.Command = command
 		}

+ 27 - 39
proxy/vmess/encoding/server.go

@@ -115,14 +115,14 @@ func NewServerSession(validator protocol.UserValidator, sessionHistory *SessionH
 }
 
 func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.RequestHeader, error) {
-	buffer := make([]byte, 512)
+	buffer := buf.New()
+	defer buffer.Release()
 
-	_, err := io.ReadFull(reader, buffer[:protocol.IDBytesLen])
-	if err != nil {
+	if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, protocol.IDBytesLen)); err != nil {
 		return nil, newError("failed to read request header").Base(err)
 	}
 
-	user, timestamp, valid := s.userValidator.Get(buffer[:protocol.IDBytesLen])
+	user, timestamp, valid := s.userValidator.Get(buffer.Bytes())
 	if !valid {
 		return nil, newError("invalid user")
 	}
@@ -139,23 +139,21 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
 	aesStream := crypto.NewAesDecryptionStream(vmessAccount.ID.CmdKey(), iv)
 	decryptor := crypto.NewCryptionReader(aesStream, reader)
 
-	nBytes, err := io.ReadFull(decryptor, buffer[:41])
-	if err != nil {
+	if err := buffer.Reset(buf.ReadFullFrom(decryptor, 41)); err != nil {
 		return nil, newError("failed to read request header").Base(err)
 	}
-	bufferLen := nBytes
 
 	request := &protocol.RequestHeader{
 		User:    user,
-		Version: buffer[0],
+		Version: buffer.Byte(0),
 	}
 
 	if request.Version != Version {
 		return nil, newError("invalid protocol version ", request.Version)
 	}
 
-	s.requestBodyIV = append([]byte(nil), buffer[1:17]...)   // 16 bytes
-	s.requestBodyKey = append([]byte(nil), buffer[17:33]...) // 16 bytes
+	s.requestBodyIV = append([]byte(nil), buffer.BytesRange(1, 17)...)   // 16 bytes
+	s.requestBodyKey = append([]byte(nil), buffer.BytesRange(17, 33)...) // 16 bytes
 	var sid sessionId
 	copy(sid.user[:], vmessAccount.ID.Bytes())
 	copy(sid.key[:], s.requestBodyKey)
@@ -165,66 +163,56 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
 	}
 	s.sessionHistory.add(sid)
 
-	s.responseHeader = buffer[33]             // 1 byte
-	request.Option = bitmask.Byte(buffer[34]) // 1 byte
-	padingLen := int(buffer[35] >> 4)
-	request.Security = protocol.NormSecurity(protocol.Security(buffer[35] & 0x0F))
+	s.responseHeader = buffer.Byte(33)             // 1 byte
+	request.Option = bitmask.Byte(buffer.Byte(34)) // 1 byte
+	padingLen := int(buffer.Byte(35) >> 4)
+	request.Security = protocol.NormSecurity(protocol.Security(buffer.Byte(35) & 0x0F))
 	// 1 bytes reserved
-	request.Command = protocol.RequestCommand(buffer[37])
+	request.Command = protocol.RequestCommand(buffer.Byte(37))
 
 	if request.Command != protocol.RequestCommandMux {
-		request.Port = net.PortFromBytes(buffer[38:40])
+		request.Port = net.PortFromBytes(buffer.BytesRange(38, 40))
 
-		switch protocol.AddressType(buffer[40]) {
+		switch protocol.AddressType(buffer.Byte(40)) {
 		case protocol.AddressTypeIPv4:
-			_, err = io.ReadFull(decryptor, buffer[41:45]) // 4 bytes
-			bufferLen += 4
-			if err != nil {
+			if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, 4)); err != nil {
 				return nil, newError("failed to read IPv4 address").Base(err)
 			}
-			request.Address = net.IPAddress(buffer[41:45])
+			request.Address = net.IPAddress(buffer.BytesFrom(-4))
 		case protocol.AddressTypeIPv6:
-			_, err = io.ReadFull(decryptor, buffer[41:57]) // 16 bytes
-			bufferLen += 16
-			if err != nil {
+			if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, 16)); err != nil {
 				return nil, newError("failed to read IPv6 address").Base(err)
 			}
-			request.Address = net.IPAddress(buffer[41:57])
+			request.Address = net.IPAddress(buffer.BytesFrom(-16))
 		case protocol.AddressTypeDomain:
-			_, err = io.ReadFull(decryptor, buffer[41:42])
-			if err != nil {
+			if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, 1)); err != nil {
 				return nil, newError("failed to read domain address").Base(err)
 			}
-			domainLength := int(buffer[41])
+			domainLength := int(buffer.Byte(buffer.Len() - 1))
 			if domainLength == 0 {
 				return nil, newError("zero length domain").Base(err)
 			}
-			_, err = io.ReadFull(decryptor, buffer[42:42+domainLength])
-			if err != nil {
+			if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, domainLength)); err != nil {
 				return nil, newError("failed to read domain address").Base(err)
 			}
-			bufferLen += 1 + domainLength
-			request.Address = net.DomainAddress(string(buffer[42 : 42+domainLength]))
+			request.Address = net.DomainAddress(string(buffer.BytesFrom(-domainLength)))
 		}
 	}
 
 	if padingLen > 0 {
-		_, err = io.ReadFull(decryptor, buffer[bufferLen:bufferLen+padingLen])
-		if err != nil {
+		if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, padingLen)); err != nil {
 			return nil, newError("failed to read padding").Base(err)
 		}
-		bufferLen += padingLen
 	}
 
-	_, err = io.ReadFull(decryptor, buffer[bufferLen:bufferLen+4])
-	if err != nil {
+	if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, 4)); err != nil {
 		return nil, newError("failed to read checksum").Base(err)
 	}
 
 	fnv1a := fnv.New32a()
-	common.Must2(fnv1a.Write(buffer[:bufferLen]))
+	common.Must2(fnv1a.Write(buffer.BytesTo(-4)))
 	actualHash := fnv1a.Sum32()
-	expectedHash := serial.BytesToUint32(buffer[bufferLen : bufferLen+4])
+	expectedHash := serial.BytesToUint32(buffer.BytesFrom(-4))
 
 	if actualHash != expectedHash {
 		return nil, newError("invalid auth")