Przeglądaj źródła

rewrite hashing logic in vmess

v2ray 9 lat temu
rodzic
commit
349b02084c

+ 2 - 1
proxy/vmess/outbound/outbound.go

@@ -5,6 +5,7 @@ import (
 	"crypto/rand"
 	"net"
 	"sync"
+	"time"
 
 	"github.com/v2ray/v2ray-core/app"
 	"github.com/v2ray/v2ray-core/common/alloc"
@@ -104,7 +105,7 @@ func handleRequest(conn net.Conn, request *protocol.VMessRequest, firstPacket v2
 
 	buffer := alloc.NewBuffer().Clear()
 	defer buffer.Release()
-	buffer, err = request.ToBytes(user.NewTimeHash(user.HMACHash{}), user.GenerateRandomInt64InRange, buffer)
+	buffer, err = request.ToBytes(user.NewRandomTimestampGenerator(user.Timestamp(time.Now().Unix()), 30), buffer)
 	if err != nil {
 		log.Error("VMessOut: Failed to serialize VMess request: %v", err)
 		return

+ 15 - 0
proxy/vmess/protocol/user/hash.go

@@ -0,0 +1,15 @@
+package user
+
+import (
+	"crypto/hmac"
+	"crypto/md5"
+	"hash"
+)
+
+func TimestampHash() hash.Hash {
+	return md5.New()
+}
+
+func IDHash(key []byte) hash.Hash {
+	return hmac.New(md5.New, key)
+}

+ 0 - 51
proxy/vmess/protocol/user/idhash.go

@@ -1,51 +0,0 @@
-package user
-
-import (
-	"crypto/hmac"
-	"crypto/md5"
-)
-
-type CounterHash interface {
-	Hash(key []byte, counter int64) []byte
-}
-
-type StringHash interface {
-	Hash(key []byte, data []byte) []byte
-}
-
-type TimeHash struct {
-	baseHash StringHash
-}
-
-func NewTimeHash(baseHash StringHash) CounterHash {
-	return TimeHash{
-		baseHash: baseHash,
-	}
-}
-
-func (h TimeHash) Hash(key []byte, counter int64) []byte {
-	counterBytes := int64ToBytes(counter)
-	return h.baseHash.Hash(key, counterBytes)
-}
-
-type HMACHash struct {
-}
-
-func (h HMACHash) Hash(key []byte, data []byte) []byte {
-	hash := hmac.New(md5.New, key)
-	hash.Write(data)
-	return hash.Sum(nil)
-}
-
-func int64ToBytes(value int64) []byte {
-	return []byte{
-		byte(value >> 56),
-		byte(value >> 48),
-		byte(value >> 40),
-		byte(value >> 32),
-		byte(value >> 24),
-		byte(value >> 16),
-		byte(value >> 8),
-		byte(value),
-	}
-}

+ 0 - 15
proxy/vmess/protocol/user/inthash.go

@@ -1,15 +0,0 @@
-package user
-
-import (
-	"crypto/md5"
-)
-
-func Int64Hash(value int64) []byte {
-	md5hash := md5.New()
-	buffer := int64ToBytes(value)
-	md5hash.Write(buffer)
-	md5hash.Write(buffer)
-	md5hash.Write(buffer)
-	md5hash.Write(buffer)
-	return md5hash.Sum(nil)
-}

+ 18 - 4
proxy/vmess/protocol/user/rand.go

@@ -4,9 +4,23 @@ import (
 	"math/rand"
 )
 
-type RandomInt64InRange func(base int64, delta int) int64
+type RandomTimestampGenerator interface {
+	Next() Timestamp
+}
+
+type RealRandomTimestampGenerator struct {
+	base  Timestamp
+	delta int
+}
+
+func NewRandomTimestampGenerator(base Timestamp, delta int) RandomTimestampGenerator {
+	return &RealRandomTimestampGenerator{
+		base:  base,
+		delta: delta,
+	}
+}
 
-func GenerateRandomInt64InRange(base int64, delta int) int64 {
-	rangeInDelta := rand.Intn(delta*2) - delta
-	return base + int64(rangeInDelta)
+func (this *RealRandomTimestampGenerator) Next() Timestamp {
+	rangeInDelta := rand.Intn(this.delta*2) - this.delta
+	return this.base + Timestamp(rangeInDelta)
 }

+ 6 - 1
proxy/vmess/protocol/user/rand_test.go

@@ -10,11 +10,16 @@ import (
 
 func TestGenerateRandomInt64InRange(t *testing.T) {
 	v2testing.Current(t)
+
 	base := time.Now().Unix()
 	delta := 100
+	generator := &RealRandomTimestampGenerator{
+		base:  Timestamp(base),
+		delta: delta,
+	}
 
 	for i := 0; i < 100; i++ {
-		v := GenerateRandomInt64InRange(base, delta)
+		v := int64(generator.Next())
 		assert.Int64(v).AtMost(base + int64(delta))
 		assert.Int64(v).AtLeast(base - int64(delta))
 	}

+ 3 - 2
proxy/vmess/protocol/user/testing/mocks/mockuserset.go

@@ -2,12 +2,13 @@ package mocks
 
 import (
 	"github.com/v2ray/v2ray-core/proxy/vmess"
+	"github.com/v2ray/v2ray-core/proxy/vmess/protocol/user"
 )
 
 type MockUserSet struct {
 	Users      []vmess.User
 	UserHashes map[string]int
-	Timestamps map[string]int64
+	Timestamps map[string]user.Timestamp
 }
 
 func (us *MockUserSet) AddUser(user vmess.User) error {
@@ -15,7 +16,7 @@ func (us *MockUserSet) AddUser(user vmess.User) error {
 	return nil
 }
 
-func (us *MockUserSet) GetUser(userhash []byte) (vmess.User, int64, bool) {
+func (us *MockUserSet) GetUser(userhash []byte) (vmess.User, user.Timestamp, bool) {
 	idx, found := us.UserHashes[string(userhash)]
 	if found {
 		return us.Users[idx], us.Timestamps[string(userhash)], true

+ 2 - 1
proxy/vmess/protocol/user/testing/mocks/static_userset.go

@@ -3,6 +3,7 @@ package mocks
 import (
 	"github.com/v2ray/v2ray-core/common/uuid"
 	"github.com/v2ray/v2ray-core/proxy/vmess"
+	"github.com/v2ray/v2ray-core/proxy/vmess/protocol/user"
 )
 
 type StaticUser struct {
@@ -32,7 +33,7 @@ func (us *StaticUserSet) AddUser(user vmess.User) error {
 	return nil
 }
 
-func (us *StaticUserSet) GetUser(userhash []byte) (vmess.User, int64, bool) {
+func (us *StaticUserSet) GetUser(userhash []byte) (vmess.User, user.Timestamp, bool) {
 	id, _ := uuid.ParseString("703e9102-eb57-499c-8b59-faf4f371bb21")
 	return &StaticUser{id: vmess.NewID(id)}, 0, true
 }

+ 30 - 12
proxy/vmess/protocol/user/userset.go

@@ -5,6 +5,7 @@ import (
 	"time"
 
 	"github.com/v2ray/v2ray-core/common/collect"
+	"github.com/v2ray/v2ray-core/common/serial"
 	"github.com/v2ray/v2ray-core/proxy/vmess"
 )
 
@@ -13,16 +14,32 @@ const (
 	cacheDurationSec  = 120
 )
 
+type Timestamp int64
+
+func (this Timestamp) Bytes() []byte {
+	return serial.Int64Literal(this).Bytes()
+}
+
+func (this Timestamp) HashBytes() []byte {
+	once := this.Bytes()
+	bytes := make([]byte, 0, 32)
+	bytes = append(bytes, once...)
+	bytes = append(bytes, once...)
+	bytes = append(bytes, once...)
+	bytes = append(bytes, once...)
+	return bytes
+}
+
 type idEntry struct {
 	id      *vmess.ID
 	userIdx int
-	lastSec int64
+	lastSec Timestamp
 	hashes  *collect.SizedQueue
 }
 
 type UserSet interface {
 	AddUser(user vmess.User) error
-	GetUser(timeHash []byte) (vmess.User, int64, bool)
+	GetUser(timeHash []byte) (vmess.User, Timestamp, bool)
 }
 
 type TimedUserSet struct {
@@ -34,7 +51,7 @@ type TimedUserSet struct {
 
 type indexTimePair struct {
 	index   int
-	timeSec int64
+	timeSec Timestamp
 }
 
 func NewTimedUserSet() UserSet {
@@ -48,10 +65,11 @@ func NewTimedUserSet() UserSet {
 	return tus
 }
 
-func (us *TimedUserSet) generateNewHashes(nowSec int64, idx int, entry *idEntry) {
-	idHash := NewTimeHash(HMACHash{})
+func (us *TimedUserSet) generateNewHashes(nowSec Timestamp, idx int, entry *idEntry) {
 	for entry.lastSec <= nowSec {
-		idHashSlice := idHash.Hash(entry.id.Bytes(), entry.lastSec)
+		idHash := IDHash(entry.id.Bytes())
+		idHash.Write(entry.lastSec.Bytes())
+		idHashSlice := idHash.Sum(nil)
 		hashValue := string(idHashSlice)
 		us.access.Lock()
 		us.userHash[hashValue] = indexTimePair{idx, entry.lastSec}
@@ -69,7 +87,7 @@ func (us *TimedUserSet) generateNewHashes(nowSec int64, idx int, entry *idEntry)
 
 func (us *TimedUserSet) updateUserHash(tick <-chan time.Time) {
 	for now := range tick {
-		nowSec := now.Unix() + cacheDurationSec
+		nowSec := Timestamp(now.Unix() + cacheDurationSec)
 		for _, entry := range us.ids {
 			us.generateNewHashes(nowSec, entry.userIdx, entry)
 		}
@@ -85,26 +103,26 @@ func (us *TimedUserSet) AddUser(user vmess.User) error {
 	entry := &idEntry{
 		id:      user.ID(),
 		userIdx: idx,
-		lastSec: nowSec - cacheDurationSec,
+		lastSec: Timestamp(nowSec - cacheDurationSec),
 		hashes:  collect.NewSizedQueue(2*cacheDurationSec + 1),
 	}
-	us.generateNewHashes(nowSec+cacheDurationSec, idx, entry)
+	us.generateNewHashes(Timestamp(nowSec+cacheDurationSec), idx, entry)
 	us.ids = append(us.ids, entry)
 	for _, alterid := range user.AlterIDs() {
 		entry := &idEntry{
 			id:      alterid,
 			userIdx: idx,
-			lastSec: nowSec - cacheDurationSec,
+			lastSec: Timestamp(nowSec - cacheDurationSec),
 			hashes:  collect.NewSizedQueue(2*cacheDurationSec + 1),
 		}
-		us.generateNewHashes(nowSec+cacheDurationSec, idx, entry)
+		us.generateNewHashes(Timestamp(nowSec+cacheDurationSec), idx, entry)
 		us.ids = append(us.ids, entry)
 	}
 
 	return nil
 }
 
-func (us *TimedUserSet) GetUser(userHash []byte) (vmess.User, int64, bool) {
+func (us *TimedUserSet) GetUser(userHash []byte) (vmess.User, Timestamp, bool) {
 	defer us.access.RUnlock()
 	us.access.RLock()
 	pair, found := us.userHash[string(userHash)]

+ 14 - 7
proxy/vmess/protocol/vmess.go

@@ -2,10 +2,10 @@
 package protocol
 
 import (
+	"crypto/md5"
 	"encoding/binary"
 	"hash/fnv"
 	"io"
-	"time"
 
 	"github.com/v2ray/v2ray-core/common/alloc"
 	v2crypto "github.com/v2ray/v2ray-core/common/crypto"
@@ -81,7 +81,10 @@ func (this *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) {
 		return nil, proxy.InvalidAuthentication
 	}
 
-	aesStream, err := v2crypto.NewAesDecryptionStream(userObj.ID().CmdKey(), user.Int64Hash(timeSec))
+	timestampHash := user.TimestampHash()
+	timestampHash.Write(timeSec.HashBytes())
+	iv := timestampHash.Sum(nil)
+	aesStream, err := v2crypto.NewAesDecryptionStream(userObj.ID().CmdKey(), iv)
 	if err != nil {
 		log.Debug("VMess: Failed to create AES stream: %v", err)
 		return nil, err
@@ -169,15 +172,16 @@ func (this *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) {
 }
 
 // ToBytes returns a VMessRequest in the form of byte array.
-func (this *VMessRequest) ToBytes(idHash user.CounterHash, randomRangeInt64 user.RandomInt64InRange, buffer *alloc.Buffer) (*alloc.Buffer, error) {
+func (this *VMessRequest) ToBytes(timestampGenerator user.RandomTimestampGenerator, buffer *alloc.Buffer) (*alloc.Buffer, error) {
 	if buffer == nil {
 		buffer = alloc.NewSmallBuffer().Clear()
 	}
 
-	counter := randomRangeInt64(time.Now().Unix(), 30)
-	hash := idHash.Hash(this.User.AnyValidID().Bytes(), counter)
+	timestamp := timestampGenerator.Next()
+	idHash := user.IDHash(this.User.AnyValidID().Bytes())
+	idHash.Write(timestamp.Bytes())
 
-	buffer.Append(hash)
+	buffer.Append(idHash.Sum(nil))
 
 	encryptionBegin := buffer.Len()
 
@@ -209,7 +213,10 @@ func (this *VMessRequest) ToBytes(idHash user.CounterHash, randomRangeInt64 user
 	buffer.AppendBytes(byte(fnvHash>>24), byte(fnvHash>>16), byte(fnvHash>>8), byte(fnvHash))
 	encryptionEnd += 4
 
-	aesStream, err := v2crypto.NewAesEncryptionStream(this.User.ID().CmdKey(), user.Int64Hash(counter))
+	timestampHash := md5.New()
+	timestampHash.Write(timestamp.HashBytes())
+	iv := timestampHash.Sum(nil)
+	aesStream, err := v2crypto.NewAesEncryptionStream(this.User.ID().CmdKey(), iv)
 	if err != nil {
 		return nil, err
 	}

+ 14 - 5
proxy/vmess/protocol/vmess_test.go

@@ -5,6 +5,7 @@ import (
 	"crypto/rand"
 	"io"
 	"testing"
+	"time"
 
 	v2net "github.com/v2ray/v2ray-core/common/net"
 	"github.com/v2ray/v2ray-core/common/uuid"
@@ -15,6 +16,14 @@ import (
 	"github.com/v2ray/v2ray-core/testing/assert"
 )
 
+type FakeTimestampGenerator struct {
+	timestamp user.Timestamp
+}
+
+func (this *FakeTimestampGenerator) Next() user.Timestamp {
+	return this.timestamp
+}
+
 type TestUser struct {
 	id    *vmess.ID
 	level vmess.UserLevel
@@ -48,7 +57,7 @@ func TestVMessSerialization(t *testing.T) {
 		id: userId,
 	}
 
-	userSet := mocks.MockUserSet{[]vmess.User{}, make(map[string]int), make(map[string]int64)}
+	userSet := mocks.MockUserSet{[]vmess.User{}, make(map[string]int), make(map[string]user.Timestamp)}
 	userSet.AddUser(testUser)
 
 	request := new(VMessRequest)
@@ -66,9 +75,9 @@ func TestVMessSerialization(t *testing.T) {
 	request.Address = v2net.DomainAddress("v2ray.com")
 	request.Port = v2net.Port(80)
 
-	mockTime := int64(1823730)
+	mockTime := user.Timestamp(1823730)
 
-	buffer, err := request.ToBytes(user.NewTimeHash(user.HMACHash{}), func(base int64, delta int) int64 { return mockTime }, nil)
+	buffer, err := request.ToBytes(&FakeTimestampGenerator{timestamp: mockTime}, nil)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -104,7 +113,7 @@ func BenchmarkVMessRequestWriting(b *testing.B) {
 	assert.Error(err).IsNil()
 
 	userId := vmess.NewID(id)
-	userSet := mocks.MockUserSet{[]vmess.User{}, make(map[string]int), make(map[string]int64)}
+	userSet := mocks.MockUserSet{[]vmess.User{}, make(map[string]int), make(map[string]user.Timestamp)}
 
 	testUser := &TestUser{
 		id: userId,
@@ -126,6 +135,6 @@ func BenchmarkVMessRequestWriting(b *testing.B) {
 	request.Port = v2net.Port(80)
 
 	for i := 0; i < b.N; i++ {
-		request.ToBytes(user.NewTimeHash(user.HMACHash{}), user.GenerateRandomInt64InRange, nil)
+		request.ToBytes(user.NewRandomTimestampGenerator(user.Timestamp(time.Now().Unix()), 30), nil)
 	}
 }