v2ray 9 лет назад
Родитель
Сommit
791ac307a2

+ 123 - 0
common/protocol/user_validator.go

@@ -0,0 +1,123 @@
+package protocol
+
+import (
+	"hash"
+	"sync"
+	"time"
+)
+
+const (
+	updateIntervalSec = 10
+	cacheDurationSec  = 120
+)
+
+type IDHash func(key []byte) hash.Hash
+
+type idEntry struct {
+	id             *ID
+	userIdx        int
+	lastSec        Timestamp
+	lastSecRemoval Timestamp
+}
+
+type UserValidator interface {
+	Add(user *User) error
+	Get(timeHash []byte) (*User, Timestamp, bool)
+}
+
+type TimedUserValidator struct {
+	validUsers []*User
+	userHash   map[[16]byte]*indexTimePair
+	ids        []*idEntry
+	access     sync.RWMutex
+	hasher     IDHash
+}
+
+type indexTimePair struct {
+	index   int
+	timeSec Timestamp
+}
+
+func NewTimedUserValidator(hasher IDHash) UserValidator {
+	tus := &TimedUserValidator{
+		validUsers: make([]*User, 0, 16),
+		userHash:   make(map[[16]byte]*indexTimePair, 512),
+		access:     sync.RWMutex{},
+		ids:        make([]*idEntry, 0, 512),
+		hasher:     hasher,
+	}
+	go tus.updateUserHash(time.Tick(updateIntervalSec * time.Second))
+	return tus
+}
+
+func (this *TimedUserValidator) generateNewHashes(nowSec Timestamp, idx int, entry *idEntry) {
+	var hashValue [16]byte
+	var hashValueRemoval [16]byte
+	idHash := this.hasher(entry.id.Bytes())
+	for entry.lastSec <= nowSec {
+		idHash.Write(entry.lastSec.Bytes())
+		idHash.Sum(hashValue[:0])
+		idHash.Reset()
+
+		idHash.Write(entry.lastSecRemoval.Bytes())
+		idHash.Sum(hashValueRemoval[:0])
+		idHash.Reset()
+
+		this.access.Lock()
+		this.userHash[hashValue] = &indexTimePair{idx, entry.lastSec}
+		delete(this.userHash, hashValueRemoval)
+		this.access.Unlock()
+
+		entry.lastSec++
+		entry.lastSecRemoval++
+	}
+}
+
+func (this *TimedUserValidator) updateUserHash(tick <-chan time.Time) {
+	for now := range tick {
+		nowSec := Timestamp(now.Unix() + cacheDurationSec)
+		for _, entry := range this.ids {
+			this.generateNewHashes(nowSec, entry.userIdx, entry)
+		}
+	}
+}
+
+func (this *TimedUserValidator) Add(user *User) error {
+	idx := len(this.validUsers)
+	this.validUsers = append(this.validUsers, user)
+
+	nowSec := time.Now().Unix()
+
+	entry := &idEntry{
+		id:             user.ID,
+		userIdx:        idx,
+		lastSec:        Timestamp(nowSec - cacheDurationSec),
+		lastSecRemoval: Timestamp(nowSec - cacheDurationSec*3),
+	}
+	this.generateNewHashes(Timestamp(nowSec+cacheDurationSec), idx, entry)
+	this.ids = append(this.ids, entry)
+	for _, alterid := range user.AlterIDs {
+		entry := &idEntry{
+			id:             alterid,
+			userIdx:        idx,
+			lastSec:        Timestamp(nowSec - cacheDurationSec),
+			lastSecRemoval: Timestamp(nowSec - cacheDurationSec*3),
+		}
+		this.generateNewHashes(Timestamp(nowSec+cacheDurationSec), idx, entry)
+		this.ids = append(this.ids, entry)
+	}
+
+	return nil
+}
+
+func (this *TimedUserValidator) Get(userHash []byte) (*User, Timestamp, bool) {
+	defer this.access.RUnlock()
+	this.access.RLock()
+	var fixedSizeHash [16]byte
+	copy(fixedSizeHash[:], userHash)
+	pair, found := this.userHash[fixedSizeHash]
+	if found {
+		return this.validUsers[pair.index], pair.timeSec, true
+	}
+	return nil, 0, false
+}

+ 4 - 4
proxy/vmess/inbound/inbound.go

@@ -66,7 +66,7 @@ type VMessInboundHandler struct {
 	sync.Mutex
 	packetDispatcher      dispatcher.PacketDispatcher
 	inboundHandlerManager proxyman.InboundHandlerManager
-	clients               protocol.UserSet
+	clients               proto.UserValidator
 	usersByEmail          *userByEmail
 	accepting             bool
 	listener              *hub.TCPHub
@@ -91,7 +91,7 @@ func (this *VMessInboundHandler) Close() {
 func (this *VMessInboundHandler) GetUser(email string) *proto.User {
 	user, existing := this.usersByEmail.Get(email)
 	if !existing {
-		this.clients.AddUser(user)
+		this.clients.Add(user)
 	}
 	return user
 }
@@ -211,9 +211,9 @@ func init() {
 			}
 			config := rawConfig.(*Config)
 
-			allowedClients := protocol.NewTimedUserSet()
+			allowedClients := proto.NewTimedUserValidator(protocol.IDHash)
 			for _, user := range config.AllowedUsers {
-				allowedClients.AddUser(user)
+				allowedClients.Add(user)
 			}
 
 			handler := &VMessInboundHandler{

+ 2 - 1
proxy/vmess/outbound/outbound.go

@@ -14,6 +14,7 @@ import (
 	v2io "github.com/v2ray/v2ray-core/common/io"
 	"github.com/v2ray/v2ray-core/common/log"
 	v2net "github.com/v2ray/v2ray-core/common/net"
+	proto "github.com/v2ray/v2ray-core/common/protocol"
 	"github.com/v2ray/v2ray-core/proxy"
 	"github.com/v2ray/v2ray-core/proxy/internal"
 	vmessio "github.com/v2ray/v2ray-core/proxy/vmess/io"
@@ -106,7 +107,7 @@ func (this *VMessOutboundHandler) handleRequest(conn net.Conn, request *protocol
 
 	buffer := alloc.NewBuffer().Clear()
 	defer buffer.Release()
-	buffer, err = request.ToBytes(protocol.NewRandomTimestampGenerator(protocol.Timestamp(time.Now().Unix()), 30), buffer)
+	buffer, err = request.ToBytes(protocol.NewRandomTimestampGenerator(proto.Timestamp(time.Now().Unix()), 30), buffer)
 	if err != nil {
 		log.Error("VMessOut: Failed to serialize VMess request: ", err)
 		return

+ 7 - 5
proxy/vmess/protocol/rand.go

@@ -2,25 +2,27 @@ package protocol
 
 import (
 	"math/rand"
+
+	"github.com/v2ray/v2ray-core/common/protocol"
 )
 
 type RandomTimestampGenerator interface {
-	Next() Timestamp
+	Next() protocol.Timestamp
 }
 
 type RealRandomTimestampGenerator struct {
-	base  Timestamp
+	base  protocol.Timestamp
 	delta int
 }
 
-func NewRandomTimestampGenerator(base Timestamp, delta int) RandomTimestampGenerator {
+func NewRandomTimestampGenerator(base protocol.Timestamp, delta int) RandomTimestampGenerator {
 	return &RealRandomTimestampGenerator{
 		base:  base,
 		delta: delta,
 	}
 }
 
-func (this *RealRandomTimestampGenerator) Next() Timestamp {
+func (this *RealRandomTimestampGenerator) Next() protocol.Timestamp {
 	rangeInDelta := rand.Intn(this.delta*2) - this.delta
-	return this.base + Timestamp(rangeInDelta)
+	return this.base + protocol.Timestamp(rangeInDelta)
 }

+ 2 - 1
proxy/vmess/protocol/rand_test.go

@@ -4,6 +4,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/v2ray/v2ray-core/common/protocol"
 	. "github.com/v2ray/v2ray-core/proxy/vmess/protocol"
 	v2testing "github.com/v2ray/v2ray-core/testing"
 	"github.com/v2ray/v2ray-core/testing/assert"
@@ -14,7 +15,7 @@ func TestGenerateRandomInt64InRange(t *testing.T) {
 
 	base := time.Now().Unix()
 	delta := 100
-	generator := NewRandomTimestampGenerator(Timestamp(base), delta)
+	generator := NewRandomTimestampGenerator(protocol.Timestamp(base), delta)
 
 	for i := 0; i < 100; i++ {
 		v := int64(generator.Next())

+ 4 - 5
proxy/vmess/protocol/testing/mockuserset.go

@@ -1,22 +1,21 @@
 package mocks
 
 import (
-	proto "github.com/v2ray/v2ray-core/common/protocol"
-	"github.com/v2ray/v2ray-core/proxy/vmess/protocol"
+	"github.com/v2ray/v2ray-core/common/protocol"
 )
 
 type MockUserSet struct {
-	Users      []*proto.User
+	Users      []*protocol.User
 	UserHashes map[string]int
 	Timestamps map[string]protocol.Timestamp
 }
 
-func (us *MockUserSet) AddUser(user *proto.User) error {
+func (us *MockUserSet) Add(user *protocol.User) error {
 	us.Users = append(us.Users, user)
 	return nil
 }
 
-func (us *MockUserSet) GetUser(userhash []byte) (*proto.User, protocol.Timestamp, bool) {
+func (us *MockUserSet) Get(userhash []byte) (*protocol.User, protocol.Timestamp, bool) {
 	idx, found := us.UserHashes[string(userhash)]
 	if found {
 		return us.Users[idx], us.Timestamps[string(userhash)], true

+ 5 - 6
proxy/vmess/protocol/testing/static_userset.go

@@ -1,21 +1,20 @@
 package mocks
 
 import (
-	proto "github.com/v2ray/v2ray-core/common/protocol"
+	"github.com/v2ray/v2ray-core/common/protocol"
 	"github.com/v2ray/v2ray-core/common/uuid"
-	"github.com/v2ray/v2ray-core/proxy/vmess/protocol"
 )
 
 type StaticUserSet struct {
 }
 
-func (us *StaticUserSet) AddUser(user *proto.User) error {
+func (us *StaticUserSet) Add(user *protocol.User) error {
 	return nil
 }
 
-func (us *StaticUserSet) GetUser(userhash []byte) (*proto.User, protocol.Timestamp, bool) {
+func (us *StaticUserSet) Get(userhash []byte) (*protocol.User, protocol.Timestamp, bool) {
 	id, _ := uuid.ParseString("703e9102-eb57-499c-8b59-faf4f371bb21")
-	return &proto.User{
-		ID: proto.NewID(id),
+	return &protocol.User{
+		ID: protocol.NewID(id),
 	}, 0, true
 }

+ 0 - 137
proxy/vmess/protocol/userset.go

@@ -1,137 +0,0 @@
-package protocol
-
-import (
-	"sync"
-	"time"
-
-	proto "github.com/v2ray/v2ray-core/common/protocol"
-	"github.com/v2ray/v2ray-core/common/serial"
-)
-
-const (
-	updateIntervalSec = 10
-	cacheDurationSec  = 120
-)
-
-type Timestamp int64
-
-func (this Timestamp) Bytes() []byte {
-	return serial.Int64Literal(this).Bytes()
-}
-
-func (this Timestamp) HashBytes() []byte {
-	once := this.Bytes()
-	bytes := make([]byte, 0, 32)
-	bytes = append(bytes, once...)
-	bytes = append(bytes, once...)
-	bytes = append(bytes, once...)
-	bytes = append(bytes, once...)
-	return bytes
-}
-
-type idEntry struct {
-	id             *proto.ID
-	userIdx        int
-	lastSec        Timestamp
-	lastSecRemoval Timestamp
-}
-
-type UserSet interface {
-	AddUser(user *proto.User) error
-	GetUser(timeHash []byte) (*proto.User, Timestamp, bool)
-}
-
-type TimedUserSet struct {
-	validUsers []*proto.User
-	userHash   map[[16]byte]*indexTimePair
-	ids        []*idEntry
-	access     sync.RWMutex
-}
-
-type indexTimePair struct {
-	index   int
-	timeSec Timestamp
-}
-
-func NewTimedUserSet() UserSet {
-	tus := &TimedUserSet{
-		validUsers: make([]*proto.User, 0, 16),
-		userHash:   make(map[[16]byte]*indexTimePair, 512),
-		access:     sync.RWMutex{},
-		ids:        make([]*idEntry, 0, 512),
-	}
-	go tus.updateUserHash(time.Tick(updateIntervalSec * time.Second))
-	return tus
-}
-
-func (us *TimedUserSet) generateNewHashes(nowSec Timestamp, idx int, entry *idEntry) {
-	var hashValue [16]byte
-	var hashValueRemoval [16]byte
-	idHash := IDHash(entry.id.Bytes())
-	for entry.lastSec <= nowSec {
-		idHash.Write(entry.lastSec.Bytes())
-		idHash.Sum(hashValue[:0])
-		idHash.Reset()
-
-		idHash.Write(entry.lastSecRemoval.Bytes())
-		idHash.Sum(hashValueRemoval[:0])
-		idHash.Reset()
-
-		us.access.Lock()
-		us.userHash[hashValue] = &indexTimePair{idx, entry.lastSec}
-		delete(us.userHash, hashValueRemoval)
-		us.access.Unlock()
-
-		entry.lastSec++
-		entry.lastSecRemoval++
-	}
-}
-
-func (us *TimedUserSet) updateUserHash(tick <-chan time.Time) {
-	for now := range tick {
-		nowSec := Timestamp(now.Unix() + cacheDurationSec)
-		for _, entry := range us.ids {
-			us.generateNewHashes(nowSec, entry.userIdx, entry)
-		}
-	}
-}
-
-func (us *TimedUserSet) AddUser(user *proto.User) error {
-	idx := len(us.validUsers)
-	us.validUsers = append(us.validUsers, user)
-
-	nowSec := time.Now().Unix()
-
-	entry := &idEntry{
-		id:             user.ID,
-		userIdx:        idx,
-		lastSec:        Timestamp(nowSec - cacheDurationSec),
-		lastSecRemoval: Timestamp(nowSec - cacheDurationSec*3),
-	}
-	us.generateNewHashes(Timestamp(nowSec+cacheDurationSec), idx, entry)
-	us.ids = append(us.ids, entry)
-	for _, alterid := range user.AlterIDs {
-		entry := &idEntry{
-			id:             alterid,
-			userIdx:        idx,
-			lastSec:        Timestamp(nowSec - cacheDurationSec),
-			lastSecRemoval: Timestamp(nowSec - cacheDurationSec*3),
-		}
-		us.generateNewHashes(Timestamp(nowSec+cacheDurationSec), idx, entry)
-		us.ids = append(us.ids, entry)
-	}
-
-	return nil
-}
-
-func (us *TimedUserSet) GetUser(userHash []byte) (*proto.User, Timestamp, bool) {
-	defer us.access.RUnlock()
-	us.access.RLock()
-	var fixedSizeHash [16]byte
-	copy(fixedSizeHash[:], userHash)
-	pair, found := us.userHash[fixedSizeHash]
-	if found {
-		return us.validUsers[pair.index], pair.timeSec, true
-	}
-	return nil, 0, false
-}

+ 15 - 5
proxy/vmess/protocol/vmess.go

@@ -31,6 +31,16 @@ const (
 	blockSize = 16
 )
 
+func hashTimestamp(t proto.Timestamp) []byte {
+	once := t.Bytes()
+	bytes := make([]byte, 0, 32)
+	bytes = append(bytes, once...)
+	bytes = append(bytes, once...)
+	bytes = append(bytes, once...)
+	bytes = append(bytes, once...)
+	return bytes
+}
+
 // VMessRequest implements the request message of VMess protocol. It only contains the header of a
 // request message. The data part will be handled by connection handler directly, in favor of data
 // streaming.
@@ -61,11 +71,11 @@ func (this *VMessRequest) IsChunkStream() bool {
 
 // VMessRequestReader is a parser to read VMessRequest from a byte stream.
 type VMessRequestReader struct {
-	vUserSet UserSet
+	vUserSet proto.UserValidator
 }
 
 // NewVMessRequestReader creates a new VMessRequestReader with a given UserSet
-func NewVMessRequestReader(vUserSet UserSet) *VMessRequestReader {
+func NewVMessRequestReader(vUserSet proto.UserValidator) *VMessRequestReader {
 	return &VMessRequestReader{
 		vUserSet: vUserSet,
 	}
@@ -82,13 +92,13 @@ func (this *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) {
 		return nil, err
 	}
 
-	userObj, timeSec, valid := this.vUserSet.GetUser(buffer.Value[:nBytes])
+	userObj, timeSec, valid := this.vUserSet.Get(buffer.Value[:nBytes])
 	if !valid {
 		return nil, proxy.ErrorInvalidAuthentication
 	}
 
 	timestampHash := TimestampHash()
-	timestampHash.Write(timeSec.HashBytes())
+	timestampHash.Write(hashTimestamp(timeSec))
 	iv := timestampHash.Sum(nil)
 	aesStream, err := v2crypto.NewAesDecryptionStream(userObj.ID.CmdKey(), iv)
 	if err != nil {
@@ -223,7 +233,7 @@ func (this *VMessRequest) ToBytes(timestampGenerator RandomTimestampGenerator, b
 	encryptionEnd += 4
 
 	timestampHash := md5.New()
-	timestampHash.Write(timestamp.HashBytes())
+	timestampHash.Write(hashTimestamp(timestamp))
 	iv := timestampHash.Sum(nil)
 	aesStream, err := v2crypto.NewAesEncryptionStream(this.User.ID.CmdKey(), iv)
 	if err != nil {

+ 8 - 8
proxy/vmess/protocol/vmess_test.go

@@ -17,10 +17,10 @@ import (
 )
 
 type FakeTimestampGenerator struct {
-	timestamp Timestamp
+	timestamp proto.Timestamp
 }
 
-func (this *FakeTimestampGenerator) Next() Timestamp {
+func (this *FakeTimestampGenerator) Next() proto.Timestamp {
 	return this.timestamp
 }
 
@@ -36,8 +36,8 @@ func TestVMessSerialization(t *testing.T) {
 		ID: userId,
 	}
 
-	userSet := protocoltesting.MockUserSet{[]*proto.User{}, make(map[string]int), make(map[string]Timestamp)}
-	userSet.AddUser(testUser)
+	userSet := protocoltesting.MockUserSet{[]*proto.User{}, make(map[string]int), make(map[string]proto.Timestamp)}
+	userSet.Add(testUser)
 
 	request := new(VMessRequest)
 	request.Version = byte(0x01)
@@ -54,7 +54,7 @@ func TestVMessSerialization(t *testing.T) {
 	request.Address = v2net.DomainAddress("v2ray.com")
 	request.Port = v2net.Port(80)
 
-	mockTime := Timestamp(1823730)
+	mockTime := proto.Timestamp(1823730)
 
 	buffer, err := request.ToBytes(&FakeTimestampGenerator{timestamp: mockTime}, nil)
 	if err != nil {
@@ -92,12 +92,12 @@ func BenchmarkVMessRequestWriting(b *testing.B) {
 	assert.Error(err).IsNil()
 
 	userId := proto.NewID(id)
-	userSet := protocoltesting.MockUserSet{[]*proto.User{}, make(map[string]int), make(map[string]Timestamp)}
+	userSet := protocoltesting.MockUserSet{[]*proto.User{}, make(map[string]int), make(map[string]proto.Timestamp)}
 
 	testUser := &proto.User{
 		ID: userId,
 	}
-	userSet.AddUser(testUser)
+	userSet.Add(testUser)
 
 	request := new(VMessRequest)
 	request.Version = byte(0x01)
@@ -114,6 +114,6 @@ func BenchmarkVMessRequestWriting(b *testing.B) {
 	request.Port = v2net.Port(80)
 
 	for i := 0; i < b.N; i++ {
-		request.ToBytes(NewRandomTimestampGenerator(Timestamp(time.Now().Unix()), 30), nil)
+		request.ToBytes(NewRandomTimestampGenerator(proto.Timestamp(time.Now().Unix()), 30), nil)
 	}
 }