Jelajahi Sumber

handle switch account command in vmess out

v2ray 9 tahun lalu
induk
melakukan
baaef1dad5

+ 33 - 0
common/net/address.go

@@ -17,6 +17,7 @@ type Address interface {
 	IsDomain() bool // True if this Address is an domain address
 
 	String() string // String representation of this Address
+	Equals(Address) bool
 }
 
 func ParseAddress(addr string) Address {
@@ -91,6 +92,17 @@ func (this *IPv4Address) String() string {
 	return this.IP().String()
 }
 
+func (this *IPv4Address) Equals(another Address) bool {
+	anotherIPv4, ok := another.(*IPv4Address)
+	if !ok {
+		return false
+	}
+	return this[0] == anotherIPv4[0] &&
+		this[1] == anotherIPv4[1] &&
+		this[2] == anotherIPv4[2] &&
+		this[3] == anotherIPv4[3]
+}
+
 type IPv6Address [16]byte
 
 func (addr *IPv6Address) IP() net.IP {
@@ -117,6 +129,19 @@ func (this *IPv6Address) String() string {
 	return "[" + this.IP().String() + "]"
 }
 
+func (this *IPv6Address) Equals(another Address) bool {
+	anotherIPv6, ok := another.(*IPv6Address)
+	if !ok {
+		return false
+	}
+	for idx, v := range *this {
+		if anotherIPv6[idx] != v {
+			return false
+		}
+	}
+	return true
+}
+
 type DomainAddressImpl string
 
 func (addr *DomainAddressImpl) IP() net.IP {
@@ -142,3 +167,11 @@ func (addr *DomainAddressImpl) IsDomain() bool {
 func (this *DomainAddressImpl) String() string {
 	return this.Domain()
 }
+
+func (this *DomainAddressImpl) Equals(another Address) bool {
+	anotherDomain, ok := another.(*DomainAddressImpl)
+	if !ok {
+		return false
+	}
+	return this.Domain() == anotherDomain.Domain()
+}

+ 15 - 0
common/net/destination.go

@@ -7,6 +7,7 @@ type Destination interface {
 	Port() Port
 	String() string // String representation of the destination
 	NetAddr() string
+	Equals(Destination) bool
 
 	IsTCP() bool // True if destination is reachable via TCP
 	IsUDP() bool // True if destination is reachable via UDP
@@ -55,6 +56,13 @@ func (dest *tcpDestination) Port() Port {
 	return dest.port
 }
 
+func (dest *tcpDestination) Equals(another Destination) bool {
+	if !another.IsTCP() {
+		return false
+	}
+	return dest.Port() == another.Port() && dest.Address().Equals(another.Address())
+}
+
 type udpDestination struct {
 	address Address
 	port    Port
@@ -87,3 +95,10 @@ func (dest *udpDestination) IsUDP() bool {
 func (dest *udpDestination) Port() Port {
 	return dest.port
 }
+
+func (dest *udpDestination) Equals(another Destination) bool {
+	if !another.IsUDP() {
+		return false
+	}
+	return dest.Port() == another.Port() && dest.Address().Equals(another.Address())
+}

+ 15 - 7
proxy/vmess/command/accounts.go

@@ -6,6 +6,7 @@ import (
 	v2net "github.com/v2ray/v2ray-core/common/net"
 	"github.com/v2ray/v2ray-core/common/serial"
 	"github.com/v2ray/v2ray-core/common/uuid"
+	"github.com/v2ray/v2ray-core/proxy/vmess"
 	"github.com/v2ray/v2ray-core/transport"
 )
 
@@ -19,13 +20,15 @@ func init() {
 // 2 bytes: port
 // 16 bytes: uuid
 // 2 bytes: alterid
-// 8 bytes: time
+// 1 byte: level
+// 1 bytes: time
 type SwitchAccount struct {
 	Host     v2net.Address
 	Port     v2net.Port
 	ID       *uuid.UUID
 	AlterIds serial.Uint16Literal
-	ValidSec serial.Uint16Literal
+	Level    vmess.UserLevel
+	ValidMin byte
 }
 
 func (this *SwitchAccount) Marshal(writer io.Writer) {
@@ -45,9 +48,9 @@ func (this *SwitchAccount) Marshal(writer io.Writer) {
 	writer.Write(idBytes)
 
 	writer.Write(this.AlterIds.Bytes())
+	writer.Write([]byte{byte(this.Level)})
 
-	timeBytes := this.ValidSec.Bytes()
-	writer.Write(timeBytes)
+	writer.Write([]byte{this.ValidMin})
 }
 
 func (this *SwitchAccount) Unmarshal(data []byte) error {
@@ -71,10 +74,15 @@ func (this *SwitchAccount) Unmarshal(data []byte) error {
 		return transport.CorruptedPacket
 	}
 	this.AlterIds = serial.ParseUint16(data[alterIdStart : alterIdStart+2])
-	timeStart := alterIdStart + 2
-	if len(data) < timeStart+2 {
+	levelStart := alterIdStart + 2
+	if len(data) < levelStart {
 		return transport.CorruptedPacket
 	}
-	this.ValidSec = serial.ParseUint16(data[timeStart : timeStart+2])
+	this.Level = vmess.UserLevel(data[levelStart])
+	timeStart := levelStart + 1
+	if len(data) < timeStart {
+		return transport.CorruptedPacket
+	}
+	this.ValidMin = data[timeStart]
 	return nil
 }

+ 4 - 2
proxy/vmess/command/accounts_test.go

@@ -18,7 +18,8 @@ func TestSwitchAccount(t *testing.T) {
 		Port:     1234,
 		ID:       uuid.New(),
 		AlterIds: 1024,
-		ValidSec: 8080,
+		Level:    128,
+		ValidMin: 16,
 	}
 
 	cmd, err := CreateResponseCommand(1)
@@ -33,5 +34,6 @@ func TestSwitchAccount(t *testing.T) {
 	netassert.Port(sa.Port).Equals(sa2.Port)
 	assert.String(sa.ID).Equals(sa2.ID.String())
 	assert.Uint16(sa.AlterIds.Value()).Equals(sa2.AlterIds.Value())
-	assert.Uint16(sa.ValidSec.Value()).Equals(sa2.ValidSec.Value())
+	assert.Byte(byte(sa.Level)).Equals(byte(sa2.Level))
+	assert.Byte(sa.ValidMin).Equals(sa2.ValidMin)
 }

+ 4 - 0
proxy/vmess/id.go

@@ -21,6 +21,10 @@ type ID struct {
 	cmdKey [IDBytesLen]byte
 }
 
+func (this *ID) Equals(another *ID) bool {
+	return this.uuid.Equals(another.uuid)
+}
+
 func (this *ID) Bytes() []byte {
 	return this.uuid.Bytes()
 }

+ 5 - 5
proxy/vmess/inbound/command.go

@@ -19,20 +19,20 @@ func (this *VMessInboundHandler) generateCommand(buffer *alloc.Buffer) {
 			inboundHandler, ok := handler.(*VMessInboundHandler)
 			if ok {
 				user := inboundHandler.GetUser()
-				availableSecUint16 := uint16(65535)
-				if availableSec < 65535 {
-					availableSecUint16 = uint16(availableSec)
+				availableMin := availableSec / 60
+				if availableMin > 255 {
+					availableMin = 255
 				}
 
 				saCmd := &command.SwitchAccount{
 					Port:     inboundHandler.Port(),
 					ID:       user.ID.UUID(),
 					AlterIds: serial.Uint16Literal(len(user.AlterIDs)),
-					ValidSec: serial.Uint16Literal(availableSecUint16),
+					Level:    user.Level,
+					ValidMin: byte(availableMin),
 				}
 				saCmd.Marshal(commandBytes)
 			}
-
 		}
 	}
 

+ 28 - 2
proxy/vmess/outbound/command.go

@@ -1,5 +1,31 @@
 package outbound
 
-func handleCommand(command byte, data []byte) error {
-	return nil
+import (
+	"github.com/v2ray/v2ray-core/common/log"
+	v2net "github.com/v2ray/v2ray-core/common/net"
+	"github.com/v2ray/v2ray-core/proxy/vmess"
+	"github.com/v2ray/v2ray-core/proxy/vmess/command"
+)
+
+func (this *VMessOutboundHandler) handleSwitchAccount(cmd *command.SwitchAccount) {
+	user := vmess.NewUser(vmess.NewID(cmd.ID), cmd.Level, cmd.AlterIds.Value())
+	dest := v2net.TCPDestination(cmd.Host, cmd.Port)
+	this.receiverManager.AddDetour(NewReceiver(dest, user), cmd.ValidMin)
+}
+
+func (this *VMessOutboundHandler) handleCommand(cmdId byte, data []byte) {
+	cmd, err := command.CreateResponseCommand(cmdId)
+	if err != nil {
+		log.Warning("VMessOut: Unknown response command (", cmdId, "): ", err)
+		return
+	}
+	if err := cmd.Unmarshal(data); err != nil {
+		log.Warning("VMessOut: Failed to parse response command: ", err)
+		return
+	}
+	switch typedCommand := cmd.(type) {
+	case *command.SwitchAccount:
+		this.handleSwitchAccount(typedCommand)
+	default:
+	}
 }

+ 7 - 7
proxy/vmess/outbound/outbound.go

@@ -48,10 +48,10 @@ func (this *VMessOutboundHandler) Dispatch(firstPacket v2net.Packet, ray ray.Out
 	request.RequestKey = buffer.Value[16:32]
 	request.ResponseHeader = buffer.Value[32:36]
 
-	return startCommunicate(request, vNextAddress, ray, firstPacket)
+	return this.startCommunicate(request, vNextAddress, ray, firstPacket)
 }
 
-func startCommunicate(request *protocol.VMessRequest, dest v2net.Destination, ray ray.OutboundRay, firstPacket v2net.Packet) error {
+func (this *VMessOutboundHandler) startCommunicate(request *protocol.VMessRequest, dest v2net.Destination, ray ray.OutboundRay, firstPacket v2net.Packet) error {
 	var destIp net.IP
 	if dest.Address().IsIPv4() || dest.Address().IsIPv6() {
 		destIp = dest.Address().IP()
@@ -84,8 +84,8 @@ func startCommunicate(request *protocol.VMessRequest, dest v2net.Destination, ra
 	requestFinish.Lock()
 	responseFinish.Lock()
 
-	go handleRequest(conn, request, firstPacket, input, &requestFinish)
-	go handleResponse(conn, request, output, &responseFinish, (request.Command == protocol.CmdUDP))
+	go this.handleRequest(conn, request, firstPacket, input, &requestFinish)
+	go this.handleResponse(conn, request, output, &responseFinish, (request.Command == protocol.CmdUDP))
 
 	requestFinish.Lock()
 	conn.CloseWrite()
@@ -93,7 +93,7 @@ func startCommunicate(request *protocol.VMessRequest, dest v2net.Destination, ra
 	return nil
 }
 
-func handleRequest(conn net.Conn, request *protocol.VMessRequest, firstPacket v2net.Packet, input <-chan *alloc.Buffer, finish *sync.Mutex) {
+func (this *VMessOutboundHandler) handleRequest(conn net.Conn, request *protocol.VMessRequest, firstPacket v2net.Packet, input <-chan *alloc.Buffer, finish *sync.Mutex) {
 	defer finish.Unlock()
 	aesStream, err := v2crypto.NewAesEncryptionStream(request.RequestKey[:], request.RequestIV[:])
 	if err != nil {
@@ -143,7 +143,7 @@ func headerMatch(request *protocol.VMessRequest, responseHeader []byte) bool {
 	return (request.ResponseHeader[0] == responseHeader[0])
 }
 
-func handleResponse(conn net.Conn, request *protocol.VMessRequest, output chan<- *alloc.Buffer, finish *sync.Mutex, isUDP bool) {
+func (this *VMessOutboundHandler) handleResponse(conn net.Conn, request *protocol.VMessRequest, output chan<- *alloc.Buffer, finish *sync.Mutex, isUDP bool) {
 	defer finish.Unlock()
 	defer close(output)
 	responseKey := md5.Sum(request.RequestKey[:])
@@ -178,7 +178,7 @@ func handleResponse(conn net.Conn, request *protocol.VMessRequest, output chan<-
 		}
 		command := buffer.Value[2]
 		data := buffer.Value[4 : 4+dataLen]
-		go handleCommand(command, data)
+		go this.handleCommand(command, data)
 		responseBegin = 4 + dataLen
 	}
 

+ 117 - 9
proxy/vmess/outbound/receiver.go

@@ -2,40 +2,148 @@ package outbound
 
 import (
 	"math/rand"
+	"sync"
+	"time"
 
 	v2net "github.com/v2ray/v2ray-core/common/net"
 	"github.com/v2ray/v2ray-core/proxy/vmess"
 )
 
 type Receiver struct {
+	sync.RWMutex
 	Destination v2net.Destination
 	Accounts    []*vmess.User
 }
 
+func NewReceiver(dest v2net.Destination, users ...*vmess.User) *Receiver {
+	return &Receiver{
+		Destination: dest,
+		Accounts:    users,
+	}
+}
+
+func (this *Receiver) HasUser(user *vmess.User) bool {
+	this.RLock()
+	defer this.RUnlock()
+	for _, u := range this.Accounts {
+		// TODO: handle AlterIds difference.
+		if u.ID.Equals(user.ID) {
+			return true
+		}
+	}
+	return false
+}
+
+func (this *Receiver) AddUser(user *vmess.User) {
+	if this.HasUser(user) {
+		return
+	}
+	this.Lock()
+	this.Accounts = append(this.Accounts, user)
+	this.Unlock()
+}
+
+func (this *Receiver) PickUser() *vmess.User {
+	userLen := len(this.Accounts)
+	userIdx := 0
+	if userLen > 1 {
+		userIdx = rand.Intn(userLen)
+	}
+	return this.Accounts[userIdx]
+}
+
+type ExpiringReceiver struct {
+	*Receiver
+	until time.Time
+}
+
+func (this *ExpiringReceiver) Expired() bool {
+	return this.until.After(time.Now())
+}
+
 type ReceiverManager struct {
-	receivers []*Receiver
+	receivers    []*Receiver
+	detours      []*ExpiringReceiver
+	detourAccess sync.RWMutex
 }
 
 func NewReceiverManager(receivers []*Receiver) *ReceiverManager {
 	return &ReceiverManager{
 		receivers: receivers,
+		detours:   make([]*ExpiringReceiver, 0, 16),
 	}
 }
 
-func (this *ReceiverManager) PickReceiver() (v2net.Destination, *vmess.User) {
+func (this *ReceiverManager) AddDetour(rec *Receiver, availableMin byte) {
+	if availableMin < 2 {
+		return
+	}
+	this.detourAccess.RLock()
+	destExists := false
+	for _, r := range this.detours {
+		if r.Destination == rec.Destination {
+			destExists = true
+			// Destination exists, add new user if necessary
+			for _, u := range rec.Accounts {
+				r.AddUser(u)
+			}
+		}
+	}
+
+	this.detourAccess.RUnlock()
+	expRec := &ExpiringReceiver{
+		Receiver: rec,
+		until:    time.Now().Add(time.Duration(availableMin-1) * time.Minute),
+	}
+	if !destExists {
+		this.detourAccess.Lock()
+		this.detours = append(this.detours, expRec)
+		this.detourAccess.Unlock()
+	}
+}
+
+func (this *ReceiverManager) pickDetour() *Receiver {
+	if len(this.detours) == 0 {
+		return nil
+	}
+	this.detourAccess.RLock()
+	idx := 0
+	detourLen := len(this.detours)
+	if detourLen > 1 {
+		idx = rand.Intn(detourLen)
+	}
+	rec := this.detours[idx]
+	this.detourAccess.RUnlock()
+
+	if rec.Expired() {
+		this.detourAccess.Lock()
+		detourLen := len(this.detours)
+		this.detours[idx] = this.detours[detourLen-1]
+		this.detours = this.detours[:detourLen-1]
+		this.detourAccess.Unlock()
+		return nil
+	}
+
+	return rec.Receiver
+}
+
+func (this *ReceiverManager) pickStdReceiver() *Receiver {
 	receiverLen := len(this.receivers)
+
 	receiverIdx := 0
 	if receiverLen > 1 {
 		receiverIdx = rand.Intn(receiverLen)
 	}
 
-	receiver := this.receivers[receiverIdx]
+	return this.receivers[receiverIdx]
+}
 
-	userLen := len(receiver.Accounts)
-	userIdx := 0
-	if userLen > 1 {
-		userIdx = rand.Intn(userLen)
+func (this *ReceiverManager) PickReceiver() (v2net.Destination, *vmess.User) {
+	rec := this.pickDetour()
+	if rec == nil {
+		rec = this.pickStdReceiver()
 	}
-	user := receiver.Accounts[userIdx]
-	return receiver.Destination, user
+	user := rec.PickUser()
+
+	return rec.Destination, user
 }

+ 1 - 1
proxy/vmess/outbound/receiver_json_test.go

@@ -21,7 +21,7 @@ func TestConfigTargetParsing(t *testing.T) {
       {
         "id": "e641f5ad-9397-41e3-bf1a-e8740dfed019",
         "email": "love@v2ray.com",
-        "level": 999
+        "level": 255
       }
     ]
   }`

+ 20 - 2
proxy/vmess/user.go

@@ -4,10 +4,10 @@ import (
 	"math/rand"
 )
 
-type UserLevel int
+type UserLevel byte
 
 const (
-	UserLevelAdmin     = UserLevel(999)
+	UserLevelAdmin     = UserLevel(255)
 	UserLevelUntrusted = UserLevel(0)
 )
 
@@ -17,6 +17,24 @@ type User struct {
 	Level    UserLevel
 }
 
+func NewUser(id *ID, level UserLevel, alterIdCount uint16) *User {
+	u := &User{
+		ID:    id,
+		Level: level,
+	}
+	if alterIdCount > 0 {
+		u.AlterIDs = make([]*ID, alterIdCount)
+		prevId := id.UUID()
+		for idx, _ := range u.AlterIDs {
+			newid := prevId.Next()
+			// TODO: check duplicate
+			u.AlterIDs[idx] = NewID(newid)
+			prevId = newid
+		}
+	}
+	return u
+}
+
 func (this *User) AnyValidID() *ID {
 	if len(this.AlterIDs) == 0 {
 		return this.ID

+ 2 - 15
proxy/vmess/user_json.go

@@ -12,7 +12,7 @@ func (u *User) UnmarshalJSON(data []byte) error {
 	type rawUser struct {
 		IdString     string `json:"id"`
 		EmailString  string `json:"email"`
-		LevelInt     int    `json:"level"`
+		LevelByte    byte   `json:"level"`
 		AlterIdCount uint16 `json:"alterId"`
 	}
 	var rawUserValue rawUser
@@ -23,20 +23,7 @@ func (u *User) UnmarshalJSON(data []byte) error {
 	if err != nil {
 		return err
 	}
-	u.ID = NewID(id)
-	//u.Email = rawUserValue.EmailString
-	u.Level = UserLevel(rawUserValue.LevelInt)
-
-	if rawUserValue.AlterIdCount > 0 {
-		prevId := u.ID.UUID()
-		// TODO: check duplicate
-		u.AlterIDs = make([]*ID, rawUserValue.AlterIdCount)
-		for idx, _ := range u.AlterIDs {
-			newid := prevId.Next()
-			u.AlterIDs[idx] = NewID(newid)
-			prevId = newid
-		}
-	}
+	*u = *NewUser(NewID(id), UserLevel(rawUserValue.LevelByte), rawUserValue.AlterIdCount)
 
 	return nil
 }