Browse Source

fix test break

V2Ray 10 years ago
parent
commit
5af327f00b
7 changed files with 53 additions and 19 deletions
  1. 4 1
      id.go
  2. 2 2
      id_test.go
  3. 3 3
      io/vmess/vmess.go
  4. 7 4
      io/vmess/vmess_test.go
  5. 3 3
      net/vmess/vmessin.go
  6. 23 0
      testing/mocks/mockuserset.go
  7. 11 6
      userset.go

+ 4 - 1
id.go

@@ -54,7 +54,10 @@ func (v ID) TimeHash(timeSec int64) []byte {
 }
 }
 
 
 func (v ID) Hash(data []byte) []byte {
 func (v ID) Hash(data []byte) []byte {
-	return v.hasher.Sum(data)
+	v.hasher.Write(data)
+	hash := v.hasher.Sum(nil)
+	v.hasher.Reset()
+	return hash
 }
 }
 
 
 var byteGroups = []int{8, 4, 4, 4, 12}
 var byteGroups = []int{8, 4, 4, 4, 12}

+ 2 - 2
id_test.go

@@ -12,6 +12,6 @@ func TestUUIDToID(t *testing.T) {
 	uuid := "2418d087-648d-4990-86e8-19dca1d006d3"
 	uuid := "2418d087-648d-4990-86e8-19dca1d006d3"
 	expectedBytes := []byte{0x24, 0x18, 0xd0, 0x87, 0x64, 0x8d, 0x49, 0x90, 0x86, 0xe8, 0x19, 0xdc, 0xa1, 0xd0, 0x06, 0xd3}
 	expectedBytes := []byte{0x24, 0x18, 0xd0, 0x87, 0x64, 0x8d, 0x49, 0x90, 0x86, 0xe8, 0x19, 0xdc, 0xa1, 0xd0, 0x06, 0xd3}
 
 
-	actualBytes, _ := UUIDToID(uuid)
-	assert.Bytes(actualBytes.Bytes()).Named("UUID").Equals(expectedBytes)
+	actualBytes, _ := NewID(uuid)
+	assert.Bytes(actualBytes.Bytes).Named("UUID").Equals(expectedBytes)
 }
 }

+ 3 - 3
io/vmess/vmess.go

@@ -49,10 +49,10 @@ type VMessRequest struct {
 }
 }
 
 
 type VMessRequestReader struct {
 type VMessRequestReader struct {
-	vUserSet *core.UserSet
+	vUserSet core.UserSet
 }
 }
 
 
-func NewVMessRequestReader(vUserSet *core.UserSet) *VMessRequestReader {
+func NewVMessRequestReader(vUserSet core.UserSet) *VMessRequestReader {
 	reader := new(VMessRequestReader)
 	reader := new(VMessRequestReader)
 	reader.vUserSet = vUserSet
 	reader.vUserSet = vUserSet
 	return reader
 	return reader
@@ -74,7 +74,7 @@ func (r *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	userId, valid := r.vUserSet.IsValidUserId(buffer[:nBytes])
+	userId, valid := r.vUserSet.GetUser(buffer[:nBytes])
 	if !valid {
 	if !valid {
 		return nil, ErrorInvalidUser
 		return nil, ErrorInvalidUser
 	}
 	}

+ 7 - 4
io/vmess/vmess_test.go

@@ -7,18 +7,19 @@ import (
 
 
 	"github.com/v2ray/v2ray-core"
 	"github.com/v2ray/v2ray-core"
 	v2net "github.com/v2ray/v2ray-core/net"
 	v2net "github.com/v2ray/v2ray-core/net"
+	"github.com/v2ray/v2ray-core/testing/mocks"
 	"github.com/v2ray/v2ray-core/testing/unit"
 	"github.com/v2ray/v2ray-core/testing/unit"
 )
 )
 
 
 func TestVMessSerialization(t *testing.T) {
 func TestVMessSerialization(t *testing.T) {
 	assert := unit.Assert(t)
 	assert := unit.Assert(t)
 
 
-	userId, err := core.UUIDToID("2b2966ac-16aa-4fbf-8d81-c5f172a3da51")
+	userId, err := core.NewID("2b2966ac-16aa-4fbf-8d81-c5f172a3da51")
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 
 
-	userSet := core.NewUserSet()
+	userSet := mocks.MockUserSet{[]core.ID{}, make(map[string]int)}
 	userSet.AddUser(core.User{userId})
 	userSet.AddUser(core.User{userId})
 
 
 	request := new(VMessRequest)
 	request := new(VMessRequest)
@@ -50,14 +51,16 @@ func TestVMessSerialization(t *testing.T) {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 
 
-	requestReader := NewVMessRequestReader(userSet)
+	userSet.UserHashes[string(buffer.Bytes()[1:17])] = 0
+
+	requestReader := NewVMessRequestReader(&userSet)
 	actualRequest, err := requestReader.Read(buffer)
 	actualRequest, err := requestReader.Read(buffer)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 
 
 	assert.Byte(actualRequest.Version).Named("Version").Equals(byte(0x01))
 	assert.Byte(actualRequest.Version).Named("Version").Equals(byte(0x01))
-	assert.Bytes(actualRequest.UserId[:]).Named("UserId").Equals(request.UserId[:])
+	assert.String(actualRequest.UserId.String).Named("UserId").Equals(request.UserId.String)
 	assert.Bytes(actualRequest.RequestIV[:]).Named("RequestIV").Equals(request.RequestIV[:])
 	assert.Bytes(actualRequest.RequestIV[:]).Named("RequestIV").Equals(request.RequestIV[:])
 	assert.Bytes(actualRequest.RequestKey[:]).Named("RequestKey").Equals(request.RequestKey[:])
 	assert.Bytes(actualRequest.RequestKey[:]).Named("RequestKey").Equals(request.RequestKey[:])
 	assert.Bytes(actualRequest.ResponseHeader[:]).Named("ResponseHeader").Equals(request.ResponseHeader[:])
 	assert.Bytes(actualRequest.ResponseHeader[:]).Named("ResponseHeader").Equals(request.ResponseHeader[:])

+ 3 - 3
net/vmess/vmessin.go

@@ -15,11 +15,11 @@ import (
 
 
 type VMessInboundHandler struct {
 type VMessInboundHandler struct {
 	vPoint    *core.Point
 	vPoint    *core.Point
-	clients   *core.UserSet
+	clients   core.UserSet
 	accepting bool
 	accepting bool
 }
 }
 
 
-func NewVMessInboundHandler(vp *core.Point, clients *core.UserSet) *VMessInboundHandler {
+func NewVMessInboundHandler(vp *core.Point, clients core.UserSet) *VMessInboundHandler {
 	handler := new(VMessInboundHandler)
 	handler := new(VMessInboundHandler)
 	handler.vPoint = vp
 	handler.vPoint = vp
 	handler.clients = clients
 	handler.clients = clients
@@ -121,7 +121,7 @@ func (factory *VMessInboundHandlerFactory) Create(vp *core.Point, rawConfig []by
 	if err != nil {
 	if err != nil {
 		panic(log.Error("Failed to load VMess inbound config: %v", err))
 		panic(log.Error("Failed to load VMess inbound config: %v", err))
 	}
 	}
-	allowedClients := core.NewUserSet()
+	allowedClients := core.NewTimedUserSet()
 	for _, client := range config.AllowedClients {
 	for _, client := range config.AllowedClients {
 		user, err := client.ToUser()
 		user, err := client.ToUser()
 		if err != nil {
 		if err != nil {

+ 23 - 0
testing/mocks/mockuserset.go

@@ -0,0 +1,23 @@
+package mocks
+
+import (
+	"github.com/v2ray/v2ray-core"
+)
+
+type MockUserSet struct {
+	UserIds    []core.ID
+	UserHashes map[string]int
+}
+
+func (us *MockUserSet) AddUser(user core.User) error {
+	us.UserIds = append(us.UserIds, user.Id)
+	return nil
+}
+
+func (us *MockUserSet) GetUser(userhash []byte) (*core.ID, bool) {
+	idx, found := us.UserHashes[string(userhash)]
+	if found {
+		return &us.UserIds[idx], true
+	}
+	return nil, false
+}

+ 11 - 6
userset.go

@@ -9,7 +9,12 @@ const (
 	cacheDurationSec  = 120
 	cacheDurationSec  = 120
 )
 )
 
 
-type UserSet struct {
+type UserSet interface {
+	AddUser(user User) error
+	GetUser(timeHash []byte) (*ID, bool)
+}
+
+type TimedUserSet struct {
 	validUserIds []ID
 	validUserIds []ID
 	userHashes   map[string]int
 	userHashes   map[string]int
 }
 }
@@ -19,8 +24,8 @@ type hashEntry struct {
 	timeSec int64
 	timeSec int64
 }
 }
 
 
-func NewUserSet() *UserSet {
-	vuSet := new(UserSet)
+func NewTimedUserSet() UserSet {
+	vuSet := new(TimedUserSet)
 	vuSet.validUserIds = make([]ID, 0, 16)
 	vuSet.validUserIds = make([]ID, 0, 16)
 	vuSet.userHashes = make(map[string]int)
 	vuSet.userHashes = make(map[string]int)
 
 
@@ -28,7 +33,7 @@ func NewUserSet() *UserSet {
 	return vuSet
 	return vuSet
 }
 }
 
 
-func (us *UserSet) updateUserHash(tick <-chan time.Time) {
+func (us *TimedUserSet) updateUserHash(tick <-chan time.Time) {
 	now := time.Now().UTC()
 	now := time.Now().UTC()
 	lastSec := now.Unix() - cacheDurationSec
 	lastSec := now.Unix() - cacheDurationSec
 
 
@@ -57,13 +62,13 @@ func (us *UserSet) updateUserHash(tick <-chan time.Time) {
 	}
 	}
 }
 }
 
 
-func (us *UserSet) AddUser(user User) error {
+func (us *TimedUserSet) AddUser(user User) error {
 	id := user.Id
 	id := user.Id
 	us.validUserIds = append(us.validUserIds, id)
 	us.validUserIds = append(us.validUserIds, id)
 	return nil
 	return nil
 }
 }
 
 
-func (us UserSet) IsValidUserId(userHash []byte) (*ID, bool) {
+func (us TimedUserSet) GetUser(userHash []byte) (*ID, bool) {
 	idIndex, found := us.userHashes[string(userHash)]
 	idIndex, found := us.userHashes[string(userHash)]
 	if found {
 	if found {
 		return &us.validUserIds[idIndex], true
 		return &us.validUserIds[idIndex], true