Browse Source

Refactor vmess internal struct for better readability

V2Ray 10 năm trước cách đây
mục cha
commit
791ac780f0
7 tập tin đã thay đổi với 147 bổ sung229 xóa
  1. 52 153
      io/vmess/vmess.go
  2. 19 41
      io/vmess/vmess_test.go
  3. 45 0
      log/log.go
  4. 18 6
      net/vmess/vmessin.go
  5. 10 20
      net/vmess/vmessout.go
  6. 3 2
      vid.go
  7. 0 7
      vpoint.go

+ 52 - 153
io/vmess/vmess.go

@@ -11,7 +11,6 @@ import (
 	"io"
 	_ "log"
 	mrand "math/rand"
-	"net"
 
 	"github.com/v2ray/v2ray-core"
 	v2io "github.com/v2ray/v2ray-core/io"
@@ -33,127 +32,15 @@ var (
 // VMessRequest implements the request message of VMess protocol. It only contains
 // the header of a request message. The data part will be handled by conection
 // handler directly, in favor of data streaming.
-// 1 Version
-// 16 UserHash
-// 16 Request IV
-// 16 Request Key
-// 4 Response Header
-// 1 Command
-// 2 Port
-// 1 Address Type
-// 256 Target Address
-
-type VMessRequest [312]byte
-
-func (r *VMessRequest) Version() byte {
-	return r[0]
-}
-
-func (r *VMessRequest) SetVersion(version byte) *VMessRequest {
-	r[0] = version
-	return r
-}
-
-func (r *VMessRequest) UserHash() []byte {
-	return r[1:17]
-}
-
-func (r *VMessRequest) RequestIV() []byte {
-	return r[17:33]
-}
-
-func (r *VMessRequest) RequestKey() []byte {
-	return r[33:49]
-}
-
-func (r *VMessRequest) ResponseHeader() []byte {
-	return r[49:53]
-}
-
-func (r *VMessRequest) Command() byte {
-	return r[53]
-}
-
-func (r *VMessRequest) SetCommand(command byte) *VMessRequest {
-	r[53] = command
-	return r
-}
-
-func (r *VMessRequest) Port() uint16 {
-	return binary.BigEndian.Uint16(r.portBytes())
-}
-
-func (r *VMessRequest) portBytes() []byte {
-	return r[54:56]
-}
-
-func (r *VMessRequest) SetPort(port uint16) *VMessRequest {
-	binary.BigEndian.PutUint16(r.portBytes(), port)
-	return r
-}
-
-func (r *VMessRequest) targetAddressType() byte {
-	return r[56]
-}
-
-func (r *VMessRequest) Destination() v2net.VAddress {
-	switch r.targetAddressType() {
-	case addrTypeIPv4:
-		fallthrough
-	case addrTypeIPv6:
-		return v2net.IPAddress(r.targetAddressBytes(), r.Port())
-	case addrTypeDomain:
-		return v2net.DomainAddress(r.TargetAddress(), r.Port())
-	default:
-		panic("Unpexected address type")
-	}
-}
-
-func (r *VMessRequest) TargetAddress() string {
-	switch r.targetAddressType() {
-	case addrTypeIPv4:
-		return net.IP(r[57:61]).String()
-	case addrTypeIPv6:
-		return net.IP(r[57:73]).String()
-	case addrTypeDomain:
-		domainLength := int(r[57])
-		return string(r[58 : 58+domainLength])
-	default:
-		panic("Unexpected address type")
-	}
-}
-
-func (r *VMessRequest) targetAddressBytes() []byte {
-	switch r.targetAddressType() {
-	case addrTypeIPv4:
-		return r[57:61]
-	case addrTypeIPv6:
-		return r[57:73]
-	case addrTypeDomain:
-		domainLength := int(r[57])
-		return r[57 : 58+domainLength]
-	default:
-		panic("Unexpected address type")
-	}
-}
-
-func (r *VMessRequest) SetIPv4(ipv4 []byte) *VMessRequest {
-	r[56] = addrTypeIPv4
-	copy(r[57:], ipv4)
-	return r
-}
-
-func (r *VMessRequest) SetIPv6(ipv6 []byte) *VMessRequest {
-	r[56] = addrTypeIPv6
-	copy(r[57:], ipv6)
-	return r
-}
 
-func (r *VMessRequest) SetDomain(domain string) *VMessRequest {
-	r[56] = addrTypeDomain
-	r[57] = byte(len(domain))
-	copy(r[58:], []byte(domain))
-	return r
+type VMessRequest struct {
+	Version        byte
+	UserId         core.VID
+	RequestIV      [16]byte
+	RequestKey     [16]byte
+	ResponseHeader [4]byte
+	Command        byte
+	Address        v2net.VAddress
 }
 
 type VMessRequestReader struct {
@@ -169,26 +56,30 @@ func NewVMessRequestReader(vUserSet *core.VUserSet) *VMessRequestReader {
 func (r *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) {
 	request := new(VMessRequest)
 
-	nBytes, err := reader.Read(request[0:17] /* version + user hash */)
+	buffer := make([]byte, 256)
+	nBytes, err := reader.Read(buffer[0:1])
 	if err != nil {
 		return nil, err
 	}
-	if nBytes != 17 {
-		err = fmt.Errorf("Unexpected length of header %d", nBytes)
+	// TODO: verify version number
+	request.Version = buffer[0]
+
+	nBytes, err = reader.Read(buffer[:len(request.UserId)])
+	if err != nil {
 		return nil, err
 	}
-	// TODO: verify version number
-	userId, valid := r.vUserSet.IsValidUserId(request.UserHash())
+
+	userId, valid := r.vUserSet.IsValidUserId(buffer[:nBytes])
 	if !valid {
 		return nil, ErrorInvalidUser
 	}
+	request.UserId = *userId
 
 	decryptor, err := NewDecryptionReader(reader, userId.Hash([]byte("PWD")), make([]byte, blockSize))
 	if err != nil {
 		return nil, err
 	}
 
-	buffer := make([]byte, 300)
 	nBytes, err = decryptor.Read(buffer[0:1])
 	if err != nil {
 		return nil, err
@@ -204,15 +95,15 @@ func (r *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) {
 	}
 
 	// TODO: check number of bytes returned
-	_, err = decryptor.Read(request.RequestIV())
+	_, err = decryptor.Read(request.RequestIV[:])
 	if err != nil {
 		return nil, err
 	}
-	_, err = decryptor.Read(request.RequestKey())
+	_, err = decryptor.Read(request.RequestKey[:])
 	if err != nil {
 		return nil, err
 	}
-	_, err = decryptor.Read(request.ResponseHeader())
+	_, err = decryptor.Read(request.ResponseHeader[:])
 	if err != nil {
 		return nil, err
 	}
@@ -220,13 +111,13 @@ func (r *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) {
 	if err != nil {
 		return nil, err
 	}
-	request.SetCommand(buffer[0])
+	request.Command = buffer[0]
 
 	_, err = decryptor.Read(buffer[0:2])
 	if err != nil {
 		return nil, err
 	}
-	request.SetPort(binary.BigEndian.Uint16(buffer[0:2]))
+	port := binary.BigEndian.Uint16(buffer[0:2])
 
 	_, err = decryptor.Read(buffer[0:1])
 	if err != nil {
@@ -238,13 +129,13 @@ func (r *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) {
 		if err != nil {
 			return nil, err
 		}
-		request.SetIPv4(buffer[1:5])
+		request.Address = v2net.IPAddress(buffer[1:5], port)
 	case addrTypeIPv6:
 		_, err = decryptor.Read(buffer[1:17])
 		if err != nil {
 			return nil, err
 		}
-		request.SetIPv6(buffer[1:17])
+		request.Address = v2net.IPAddress(buffer[1:17], port)
 	case addrTypeDomain:
 		_, err = decryptor.Read(buffer[1:2])
 		if err != nil {
@@ -255,7 +146,7 @@ func (r *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) {
 		if err != nil {
 			return nil, err
 		}
-		request.SetDomain(string(buffer[2 : 2+domainLength]))
+		request.Address = v2net.DomainAddress(string(buffer[2:2+domainLength]), port)
 	}
 	_, err = decryptor.Read(buffer[0:1])
 	if err != nil {
@@ -271,19 +162,17 @@ func (r *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) {
 }
 
 type VMessRequestWriter struct {
-	vUserSet *core.VUserSet
 }
 
-func NewVMessRequestWriter(vUserSet *core.VUserSet) *VMessRequestWriter {
+func NewVMessRequestWriter() *VMessRequestWriter {
 	writer := new(VMessRequestWriter)
-	writer.vUserSet = vUserSet
 	return writer
 }
 
 func (w *VMessRequestWriter) Write(writer io.Writer, request *VMessRequest) error {
 	buffer := make([]byte, 0, 300)
-	buffer = append(buffer, request.Version())
-	buffer = append(buffer, request.UserHash()...)
+	buffer = append(buffer, request.Version)
+	buffer = append(buffer, request.UserId.Hash([]byte("ASK"))...)
 
 	encryptionBegin := len(buffer)
 
@@ -296,13 +185,27 @@ func (w *VMessRequestWriter) Write(writer io.Writer, request *VMessRequest) erro
 	buffer = append(buffer, byte(randomLength))
 	buffer = append(buffer, randomContent...)
 
-	buffer = append(buffer, request.RequestIV()...)
-	buffer = append(buffer, request.RequestKey()...)
-	buffer = append(buffer, request.ResponseHeader()...)
-	buffer = append(buffer, request.Command())
-	buffer = append(buffer, request.portBytes()...)
-	buffer = append(buffer, request.targetAddressType())
-	buffer = append(buffer, request.targetAddressBytes()...)
+	buffer = append(buffer, request.RequestIV[:]...)
+	buffer = append(buffer, request.RequestKey[:]...)
+	buffer = append(buffer, request.ResponseHeader[:]...)
+	buffer = append(buffer, request.Command)
+
+	portBytes := make([]byte, 2)
+	binary.BigEndian.PutUint16(portBytes, request.Address.Port)
+	buffer = append(buffer, portBytes...)
+
+	switch {
+	case request.Address.IsIPv4():
+		buffer = append(buffer, addrTypeIPv4)
+		buffer = append(buffer, request.Address.IP...)
+	case request.Address.IsIPv6():
+		buffer = append(buffer, addrTypeIPv6)
+		buffer = append(buffer, request.Address.IP...)
+	case request.Address.IsDomain():
+		buffer = append(buffer, addrTypeDomain)
+		buffer = append(buffer, byte(len(request.Address.Domain)))
+		buffer = append(buffer, []byte(request.Address.Domain)...)
+	}
 
 	paddingLength := blockSize - 1 - (len(buffer)-encryptionBegin)%blockSize
 	if paddingLength == 0 {
@@ -317,11 +220,7 @@ func (w *VMessRequestWriter) Write(writer io.Writer, request *VMessRequest) erro
 	buffer = append(buffer, paddingBuffer...)
 	encryptionEnd := len(buffer)
 
-	userId, valid := w.vUserSet.IsValidUserId(request.UserHash())
-	if !valid {
-		return ErrorInvalidUser
-	}
-	aesCipher, err := aes.NewCipher(userId.Hash([]byte("PWD")))
+	aesCipher, err := aes.NewCipher(request.UserId.Hash([]byte("PWD")))
 	if err != nil {
 		return err
 	}
@@ -344,6 +243,6 @@ type VMessResponse [4]byte
 
 func NewVMessResponse(request *VMessRequest) *VMessResponse {
 	response := new(VMessResponse)
-	copy(response[:], request.ResponseHeader())
+	copy(response[:], request.ResponseHeader[:])
 	return response
 }

+ 19 - 41
io/vmess/vmess_test.go

@@ -6,9 +6,13 @@ import (
 	"testing"
 
 	"github.com/v2ray/v2ray-core"
+	v2net "github.com/v2ray/v2ray-core/net"
+	"github.com/v2ray/v2ray-core/testing/unit"
 )
 
 func TestVMessSerialization(t *testing.T) {
+	assert := unit.Assert(t)
+
 	userId, err := core.UUIDToVID("2b2966ac-16aa-4fbf-8d81-c5f172a3da51")
 	if err != nil {
 		t.Fatal(err)
@@ -18,31 +22,29 @@ func TestVMessSerialization(t *testing.T) {
 	userSet.AddUser(core.VUser{userId})
 
 	request := new(VMessRequest)
-	request.SetVersion(byte(0x01))
-	userHash := userId.Hash([]byte("ASK"))
-	copy(request.UserHash(), userHash)
+	request.Version = byte(0x01)
+	request.UserId = userId
 
-	_, err = rand.Read(request.RequestIV())
+	_, err = rand.Read(request.RequestIV[:])
 	if err != nil {
 		t.Fatal(err)
 	}
 
-	_, err = rand.Read(request.RequestKey())
+	_, err = rand.Read(request.RequestKey[:])
 	if err != nil {
 		t.Fatal(err)
 	}
 
-	_, err = rand.Read(request.ResponseHeader())
+	_, err = rand.Read(request.ResponseHeader[:])
 	if err != nil {
 		t.Fatal(err)
 	}
 
-	request.SetCommand(byte(0x01))
-	request.SetPort(80)
-	request.SetDomain("v2ray.com")
+	request.Command = byte(0x01)
+	request.Address = v2net.DomainAddress("v2ray.com", 80)
 
 	buffer := bytes.NewBuffer(make([]byte, 0, 300))
-	requestWriter := NewVMessRequestWriter(userSet)
+	requestWriter := NewVMessRequestWriter()
 	err = requestWriter.Write(buffer, request)
 	if err != nil {
 		t.Fatal(err)
@@ -54,35 +56,11 @@ func TestVMessSerialization(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	if actualRequest.Version() != byte(0x01) {
-		t.Errorf("Expected Version 1, but got %d", actualRequest.Version())
-	}
-
-	if !bytes.Equal(request.UserHash(), actualRequest.UserHash()) {
-		t.Errorf("Expected user hash %v, but got %v", request.UserHash(), actualRequest.UserHash())
-	}
-
-	if !bytes.Equal(request.RequestIV(), actualRequest.RequestIV()) {
-		t.Errorf("Expected request IV %v, but got %v", request.RequestIV(), actualRequest.RequestIV())
-	}
-
-	if !bytes.Equal(request.RequestKey(), actualRequest.RequestKey()) {
-		t.Errorf("Expected request Key %v, but got %v", request.RequestKey(), actualRequest.RequestKey())
-	}
-
-	if !bytes.Equal(request.ResponseHeader(), actualRequest.ResponseHeader()) {
-		t.Errorf("Expected response header %v, but got %v", request.ResponseHeader(), actualRequest.ResponseHeader())
-	}
-
-	if actualRequest.Command() != byte(0x01) {
-		t.Errorf("Expected command 1, but got %d", actualRequest.Command())
-	}
-
-	if actualRequest.Port() != 80 {
-		t.Errorf("Expected port 80, but got %d", actualRequest.Port())
-	}
-
-	if actualRequest.TargetAddress() != "v2ray.com" {
-		t.Errorf("Expected target address v2ray.com, but got %s", actualRequest.TargetAddress())
-	}
+	assert.Byte(actualRequest.Version).Named("Version").Equals(byte(0x01))
+	assert.Bytes(actualRequest.UserId[:]).Named("UserId").Equals(request.UserId[:])
+	assert.Bytes(actualRequest.RequestIV[:]).Named("RequestIV").Equals(request.RequestIV[:])
+	assert.Bytes(actualRequest.RequestKey[:]).Named("RequestKey").Equals(request.RequestKey[:])
+	assert.Bytes(actualRequest.ResponseHeader[:]).Named("ResponseHeader").Equals(request.ResponseHeader[:])
+	assert.Byte(actualRequest.Command).Named("Command").Equals(request.Command)
+	assert.String(actualRequest.Address.String()).Named("Address").Equals(request.Address.String())
 }

+ 45 - 0
log/log.go

@@ -0,0 +1,45 @@
+package log
+
+import (
+	"errors"
+	"fmt"
+	"log"
+)
+
+const (
+	DebugLevel   = LogLevel(0)
+	InfoLevel    = LogLevel(1)
+	WarningLevel = LogLevel(2)
+	ErrorLevel   = LogLevel(3)
+)
+
+var logLevel = WarningLevel
+
+type LogLevel int
+
+func SetLogLevel(level LogLevel) {
+	logLevel = level
+}
+
+func writeLog(data string, level LogLevel) {
+	if level < logLevel {
+		return
+	}
+	log.Print(data)
+}
+
+func Info(format string, v ...interface{}) {
+	data := fmt.Sprintf(format, v)
+	writeLog("[Info]"+data, InfoLevel)
+}
+
+func Warning(format string, v ...interface{}) {
+	data := fmt.Sprintf(format, v)
+	writeLog("[Warning]"+data, WarningLevel)
+}
+
+func Error(format string, v ...interface{}) error {
+	data := fmt.Sprintf(format, v)
+	writeLog("[Error]"+data, ErrorLevel)
+	return errors.New(data)
+}

+ 18 - 6
net/vmess/vmessin.go

@@ -12,12 +12,14 @@ import (
 
 type VMessInboundHandler struct {
 	vPoint    *core.VPoint
+	clients   *core.VUserSet
 	accepting bool
 }
 
-func NewVMessInboundHandler(vp *core.VPoint) *VMessInboundHandler {
+func NewVMessInboundHandler(vp *core.VPoint, clients *core.VUserSet) *VMessInboundHandler {
 	handler := new(VMessInboundHandler)
 	handler.vPoint = vp
+	handler.clients = clients
 	return handler
 }
 
@@ -45,7 +47,7 @@ func (handler *VMessInboundHandler) AcceptConnections(listener net.Listener) err
 
 func (handler *VMessInboundHandler) HandleConnection(connection net.Conn) error {
 	defer connection.Close()
-	reader := vmessio.NewVMessRequestReader(handler.vPoint.UserSet)
+	reader := vmessio.NewVMessRequestReader(handler.clients)
 
 	request, err := reader.Read(connection)
 	if err != nil {
@@ -55,8 +57,8 @@ func (handler *VMessInboundHandler) HandleConnection(connection net.Conn) error
 	response := vmessio.NewVMessResponse(request)
 	connection.Write(response[:])
 
-	requestKey := request.RequestKey()
-	requestIV := request.RequestIV()
+	requestKey := request.RequestKey[:]
+	requestIV := request.RequestIV[:]
 	responseKey := md5.Sum(requestKey)
 	responseIV := md5.Sum(requestIV)
 
@@ -70,7 +72,7 @@ func (handler *VMessInboundHandler) HandleConnection(connection net.Conn) error
 		return err
 	}
 
-	ray := handler.vPoint.NewInboundConnectionAccepted(request.Destination())
+	ray := handler.vPoint.NewInboundConnectionAccepted(request.Address)
 	input := ray.InboundInput()
 	output := ray.InboundOutput()
 	finish := make(chan bool, 2)
@@ -112,8 +114,18 @@ func (handler *VMessInboundHandler) waitForFinish(finish <-chan bool) {
 }
 
 type VMessInboundHandlerFactory struct {
+	allowedClients *core.VUserSet
+}
+
+func NewVMessInboundHandlerFactory(clients []core.VUser) *VMessInboundHandlerFactory {
+	factory := new(VMessInboundHandlerFactory)
+	factory.allowedClients = core.NewVUserSet()
+	for _, user := range clients {
+		factory.allowedClients.AddUser(user)
+	}
+	return factory
 }
 
 func (factory *VMessInboundHandlerFactory) Create(vp *core.VPoint) *VMessInboundHandler {
-	return NewVMessInboundHandler(vp)
+	return NewVMessInboundHandler(vp, factory.allowedClients)
 }

+ 10 - 20
net/vmess/vmessout.go

@@ -45,23 +45,13 @@ func (handler *VMessOutboundHandler) Start(ray core.OutboundVRay) error {
 	vNextAddress, vNextUser := handler.pickVNext()
 
 	request := new(vmessio.VMessRequest)
-	request.SetVersion(vmessio.Version)
-	copy(request.UserHash(), vNextUser.Id.Hash([]byte("ASK")))
-	rand.Read(request.RequestIV())
-	rand.Read(request.RequestKey())
-	rand.Read(request.ResponseHeader())
-	request.SetCommand(byte(0x01))
-	request.SetPort(handler.dest.Port)
-
-	address := handler.dest
-	switch {
-	case address.IsIPv4():
-		request.SetIPv4(address.IP)
-	case address.IsIPv6():
-		request.SetIPv6(address.IP)
-	case address.IsDomain():
-		request.SetDomain(address.Domain)
-	}
+	request.Version = vmessio.Version
+	request.UserId = vNextUser.Id
+	rand.Read(request.RequestIV[:])
+	rand.Read(request.RequestKey[:])
+	rand.Read(request.ResponseHeader[:])
+	request.Command = byte(0x01)
+	request.Address = handler.dest
 
 	conn, err := net.Dial("tcp", vNextAddress.String())
 	if err != nil {
@@ -69,11 +59,11 @@ func (handler *VMessOutboundHandler) Start(ray core.OutboundVRay) error {
 	}
 	defer conn.Close()
 
-	requestWriter := vmessio.NewVMessRequestWriter(handler.vPoint.UserSet)
+	requestWriter := vmessio.NewVMessRequestWriter()
 	requestWriter.Write(conn, request)
 
-	requestKey := request.RequestKey()
-	requestIV := request.RequestIV()
+	requestKey := request.RequestKey[:]
+	requestIV := request.RequestIV[:]
 	responseKey := md5.Sum(requestKey)
 	responseIV := md5.Sum(requestIV)
 

+ 3 - 2
vid.go

@@ -3,7 +3,8 @@ package core
 import (
 	"crypto/md5"
 	"encoding/hex"
-	"fmt"
+
+	"github.com/v2ray/v2ray-core/log"
 )
 
 // The ID of en entity, in the form of an UUID.
@@ -23,7 +24,7 @@ var byteGroups = []int{8, 4, 4, 4, 12}
 func UUIDToVID(uuid string) (v VID, err error) {
 	text := []byte(uuid)
 	if len(text) < 32 {
-		err = fmt.Errorf("uuid: invalid UUID string: %s", text)
+		err = log.Error("uuid: invalid UUID string: %s", text)
 		return
 	}
 

+ 0 - 7
vpoint.go

@@ -9,7 +9,6 @@ import (
 // VPoint is an single server in V2Ray system.
 type VPoint struct {
 	Config     VConfig
-	UserSet    *VUserSet
 	ichFactory InboundConnectionHandlerFactory
 	ochFactory OutboundConnectionHandlerFactory
 }
@@ -19,12 +18,6 @@ type VPoint struct {
 func NewVPoint(config *VConfig, ichFactory InboundConnectionHandlerFactory, ochFactory OutboundConnectionHandlerFactory) (*VPoint, error) {
 	var vpoint = new(VPoint)
 	vpoint.Config = *config
-	vpoint.UserSet = NewVUserSet()
-
-	for _, user := range vpoint.Config.AllowedClients {
-		vpoint.UserSet.AddUser(user)
-	}
-
 	vpoint.ichFactory = ichFactory
 	vpoint.ochFactory = ochFactory