Ver Fonte

implement remove user in vmess

Darien Raymond há 7 anos atrás
pai
commit
87ba7dd0d1

+ 1 - 0
common/protocol/user_validator.go

@@ -3,4 +3,5 @@ package protocol
 type UserValidator interface {
 	Add(user *User) error
 	Get(timeHash []byte) (*User, Timestamp, bool)
+	Remove(email string) bool
 }

+ 57 - 23
proxy/vmess/inbound/inbound.go

@@ -5,6 +5,7 @@ package inbound
 import (
 	"context"
 	"io"
+	"strings"
 	"sync"
 	"time"
 
@@ -25,7 +26,7 @@ import (
 )
 
 type userByEmail struct {
-	sync.RWMutex
+	sync.Mutex
 	cache           map[string]*protocol.User
 	defaultLevel    uint32
 	defaultAlterIDs uint16
@@ -34,7 +35,7 @@ type userByEmail struct {
 func newUserByEmail(users []*protocol.User, config *DefaultConfig) *userByEmail {
 	cache := make(map[string]*protocol.User)
 	for _, user := range users {
-		cache[user.Email] = user
+		cache[strings.ToLower(user.Email)] = user
 	}
 	return &userByEmail{
 		cache:           cache,
@@ -43,33 +44,59 @@ func newUserByEmail(users []*protocol.User, config *DefaultConfig) *userByEmail
 	}
 }
 
+func (v *userByEmail) addNoLock(u *protocol.User) bool {
+	email := strings.ToLower(u.Email)
+	user, found := v.cache[email]
+	if found {
+		return false
+	}
+	v.cache[email] = user
+	return true
+}
+
+func (v *userByEmail) Add(u *protocol.User) bool {
+	v.Lock()
+	defer v.Unlock()
+
+	return v.addNoLock(u)
+}
+
 func (v *userByEmail) Get(email string) (*protocol.User, bool) {
-	var user *protocol.User
-	var found bool
-	v.RLock()
-	user, found = v.cache[email]
-	v.RUnlock()
+	email = strings.ToLower(email)
+
+	v.Lock()
+	defer v.Unlock()
+
+	user, found := v.cache[email]
 	if !found {
-		v.Lock()
-		user, found = v.cache[email]
-		if !found {
-			id := uuid.New()
-			account := &vmess.Account{
-				Id:      id.String(),
-				AlterId: uint32(v.defaultAlterIDs),
-			}
-			user = &protocol.User{
-				Level:   v.defaultLevel,
-				Email:   email,
-				Account: serial.ToTypedMessage(account),
-			}
-			v.cache[email] = user
+		id := uuid.New()
+		account := &vmess.Account{
+			Id:      id.String(),
+			AlterId: uint32(v.defaultAlterIDs),
 		}
-		v.Unlock()
+		user = &protocol.User{
+			Level:   v.defaultLevel,
+			Email:   email,
+			Account: serial.ToTypedMessage(account),
+		}
+		v.cache[email] = user
 	}
 	return user, found
 }
 
+func (v *userByEmail) Remove(email string) bool {
+	email = strings.ToLower(email)
+
+	v.Lock()
+	defer v.Unlock()
+
+	if _, found := v.cache[email]; !found {
+		return false
+	}
+	delete(v.cache, email)
+	return true
+}
+
 // Handler is an inbound connection handler that handles messages in VMess protocol.
 type Handler struct {
 	policyManager         core.PolicyManager
@@ -129,11 +156,18 @@ func (h *Handler) GetUser(email string) *protocol.User {
 }
 
 func (h *Handler) AddUser(ctx context.Context, user *protocol.User) error {
+	if !h.usersByEmail.Add(user) {
+		return newError("User ", user.Email, " already exists.")
+	}
 	return h.clients.Add(user)
 }
 
 func (h *Handler) RemoveUser(ctx context.Context, email string) error {
-	return newError("not implemented")
+	if !h.usersByEmail.Remove(email) {
+		return newError("User ", email, " not found.")
+	}
+	h.clients.Remove(email)
+	return nil
 }
 
 func transferRequest(timer signal.ActivityUpdater, session *encoding.ServerSession, request *protocol.RequestHeader, input io.Reader, output ray.OutputStream) error {

+ 65 - 47
proxy/vmess/vmess.go

@@ -8,6 +8,7 @@ package vmess
 //go:generate go run $GOPATH/src/v2ray.com/core/common/errors/errorgen/main.go -pkg vmess -path Proxy,VMess
 
 import (
+	"strings"
 	"sync"
 	"time"
 
@@ -21,34 +22,32 @@ const (
 	cacheDurationSec = 120
 )
 
-type idEntry struct {
-	id      *protocol.ID
-	userIdx int
+type user struct {
+	user    *protocol.User
+	account *InternalAccount
 	lastSec protocol.Timestamp
 }
 
 type TimedUserValidator struct {
 	sync.RWMutex
-	validUsers []*protocol.User
-	userHash   map[[16]byte]indexTimePair
-	ids        []*idEntry
-	hasher     protocol.IDHash
-	baseTime   protocol.Timestamp
-	task       *signal.PeriodicTask
+	users    []*user
+	userHash map[[16]byte]indexTimePair
+	hasher   protocol.IDHash
+	baseTime protocol.Timestamp
+	task     *signal.PeriodicTask
 }
 
 type indexTimePair struct {
-	index   int
+	user    *user
 	timeInc uint32
 }
 
 func NewTimedUserValidator(hasher protocol.IDHash) protocol.UserValidator {
 	tuv := &TimedUserValidator{
-		validUsers: make([]*protocol.User, 0, 16),
-		userHash:   make(map[[16]byte]indexTimePair, 512),
-		ids:        make([]*idEntry, 0, 512),
-		hasher:     hasher,
-		baseTime:   protocol.Timestamp(time.Now().Unix() - cacheDurationSec*3),
+		users:    make([]*user, 0, 16),
+		userHash: make(map[[16]byte]indexTimePair, 1024),
+		hasher:   hasher,
+		baseTime: protocol.Timestamp(time.Now().Unix() - cacheDurationSec*3),
 	}
 	tuv.task = &signal.PeriodicTask{
 		Interval: updateInterval,
@@ -61,21 +60,27 @@ func NewTimedUserValidator(hasher protocol.IDHash) protocol.UserValidator {
 	return tuv
 }
 
-func (v *TimedUserValidator) generateNewHashes(nowSec protocol.Timestamp, idx int, entry *idEntry) {
+func (v *TimedUserValidator) generateNewHashes(nowSec protocol.Timestamp, user *user) {
 	var hashValue [16]byte
-	idHash := v.hasher(entry.id.Bytes())
-	for entry.lastSec <= nowSec {
-		common.Must2(idHash.Write(entry.lastSec.Bytes(nil)))
-		idHash.Sum(hashValue[:0])
-		idHash.Reset()
-
-		v.userHash[hashValue] = indexTimePair{
-			index:   idx,
-			timeInc: uint32(entry.lastSec - v.baseTime),
+	genHashForID := func(id *protocol.ID) {
+		idHash := v.hasher(id.Bytes())
+		for ts := user.lastSec; ts <= nowSec; ts++ {
+			common.Must2(idHash.Write(ts.Bytes(nil)))
+			idHash.Sum(hashValue[:0])
+			idHash.Reset()
+
+			v.userHash[hashValue] = indexTimePair{
+				user:    user,
+				timeInc: uint32(ts - v.baseTime),
+			}
 		}
+	}
 
-		entry.lastSec++
+	genHashForID(user.account.ID)
+	for _, id := range user.account.AlterIDs {
+		genHashForID(id)
 	}
+	user.lastSec = nowSec
 }
 
 func (v *TimedUserValidator) removeExpiredHashes(expire uint32) {
@@ -92,8 +97,8 @@ func (v *TimedUserValidator) updateUserHash() {
 	v.Lock()
 	defer v.Unlock()
 
-	for _, entry := range v.ids {
-		v.generateNewHashes(nowSec, entry.userIdx, entry)
+	for _, user := range v.users {
+		v.generateNewHashes(nowSec, user)
 	}
 
 	expire := protocol.Timestamp(now.Unix() - cacheDurationSec*3)
@@ -102,13 +107,11 @@ func (v *TimedUserValidator) updateUserHash() {
 	}
 }
 
-func (v *TimedUserValidator) Add(user *protocol.User) error {
+func (v *TimedUserValidator) Add(u *protocol.User) error {
 	v.Lock()
 	defer v.Unlock()
 
-	idx := len(v.validUsers)
-	v.validUsers = append(v.validUsers, user)
-	rawAccount, err := user.GetTypedAccount()
+	rawAccount, err := u.GetTypedAccount()
 	if err != nil {
 		return err
 	}
@@ -116,22 +119,13 @@ func (v *TimedUserValidator) Add(user *protocol.User) error {
 
 	nowSec := time.Now().Unix()
 
-	entry := &idEntry{
-		id:      account.ID,
-		userIdx: idx,
+	uu := &user{
+		user:    u,
+		account: account,
 		lastSec: protocol.Timestamp(nowSec - cacheDurationSec),
 	}
-	v.generateNewHashes(protocol.Timestamp(nowSec+cacheDurationSec), idx, entry)
-	v.ids = append(v.ids, entry)
-	for _, alterid := range account.AlterIDs {
-		entry := &idEntry{
-			id:      alterid,
-			userIdx: idx,
-			lastSec: protocol.Timestamp(nowSec - cacheDurationSec),
-		}
-		v.generateNewHashes(protocol.Timestamp(nowSec+cacheDurationSec), idx, entry)
-		v.ids = append(v.ids, entry)
-	}
+	v.users = append(v.users, uu)
+	v.generateNewHashes(protocol.Timestamp(nowSec+cacheDurationSec), uu)
 
 	return nil
 }
@@ -144,11 +138,35 @@ func (v *TimedUserValidator) Get(userHash []byte) (*protocol.User, protocol.Time
 	copy(fixedSizeHash[:], userHash)
 	pair, found := v.userHash[fixedSizeHash]
 	if found {
-		return v.validUsers[pair.index], protocol.Timestamp(pair.timeInc) + v.baseTime, true
+		return pair.user.user, protocol.Timestamp(pair.timeInc) + v.baseTime, true
 	}
 	return nil, 0, false
 }
 
+func (v *TimedUserValidator) Remove(email string) bool {
+	v.Lock()
+	defer v.Unlock()
+
+	email = strings.ToLower(email)
+	idx := -1
+	for i, u := range v.users {
+		if strings.ToLower(u.user.Email) == email {
+			idx = i
+			break
+		}
+	}
+	if idx == -1 {
+		return false
+	}
+	ulen := len(v.users)
+	if idx < len(v.users) {
+		v.users[idx] = v.users[ulen-1]
+		v.users[ulen-1] = nil
+		v.users = v.users[:ulen-1]
+	}
+	return true
+}
+
 // Close implements common.Closable.
 func (v *TimedUserValidator) Close() error {
 	return v.task.Close()

+ 58 - 0
proxy/vmess/vmess_test.go

@@ -0,0 +1,58 @@
+package vmess_test
+
+import (
+	"testing"
+	"time"
+
+	"v2ray.com/core/common"
+	"v2ray.com/core/common/serial"
+	"v2ray.com/core/common/uuid"
+
+	"v2ray.com/core/common/protocol"
+	. "v2ray.com/core/proxy/vmess"
+	. "v2ray.com/ext/assert"
+)
+
+func TestUserValidator(t *testing.T) {
+	assert := With(t)
+
+	hasher := protocol.DefaultIDHash
+	v := NewTimedUserValidator(hasher)
+	defer common.Close(v)
+
+	id := uuid.New()
+	user := &protocol.User{
+		Email: "test",
+		Account: serial.ToTypedMessage(&Account{
+			Id:      id.String(),
+			AlterId: 8,
+		}),
+	}
+	common.Must(v.Add(user))
+
+	{
+		ts := protocol.Timestamp(time.Now().Unix())
+		idHash := hasher(id.Bytes())
+		idHash.Write(ts.Bytes(nil))
+		userHash := idHash.Sum(nil)
+
+		euser, ets, found := v.Get(userHash)
+		assert(found, IsTrue)
+		assert(euser.Email, Equals, user.Email)
+		assert(int64(ets), Equals, int64(ts))
+	}
+
+	{
+		ts := protocol.Timestamp(time.Now().Add(time.Second * 500).Unix())
+		idHash := hasher(id.Bytes())
+		idHash.Write(ts.Bytes(nil))
+		userHash := idHash.Sum(nil)
+
+		euser, _, found := v.Get(userHash)
+		assert(found, IsFalse)
+		assert(euser, IsNil)
+	}
+
+	assert(v.Remove(user.Email), IsTrue)
+	assert(v.Remove(user.Email), IsFalse)
+}