Parcourir la source

function to compare byte array

Darien Raymond il y a 7 ans
Parent
commit
29ad2cbbdb

+ 29 - 0
common/compare/bytes.go

@@ -0,0 +1,29 @@
+package compare
+
+import "v2ray.com/core/common/errors"
+
+func BytesEqualWithDetail(a []byte, b []byte) error {
+	if len(a) != len(b) {
+		return errors.New("mismatch array length ", len(a), " vs ", len(b))
+	}
+	for idx, v := range a {
+		if b[idx] != v {
+			return errors.New("mismatch array value at index [", idx, "]: ", v, " vs ", b[idx])
+		}
+	}
+	return nil
+}
+
+func BytesEqual(a []byte, b []byte) bool {
+	return BytesEqualWithDetail(a, b) == nil
+}
+
+func BytesAll(arr []byte, value byte) bool {
+	for _, v := range arr {
+		if v != value {
+			return false
+		}
+	}
+
+	return true
+}

+ 43 - 0
common/compare/bytes_test.go

@@ -0,0 +1,43 @@
+package compare_test
+
+import (
+	"testing"
+
+	. "v2ray.com/core/common/compare"
+)
+
+func TestBytesEqual(t *testing.T) {
+	testCases := []struct {
+		Input1 []byte
+		Input2 []byte
+		Result bool
+	}{
+		{
+			Input1: []byte{},
+			Input2: []byte{1},
+			Result: false,
+		},
+		{
+			Input1: nil,
+			Input2: []byte{},
+			Result: true,
+		},
+		{
+			Input1: []byte{1},
+			Input2: []byte{1},
+			Result: true,
+		},
+		{
+			Input1: []byte{1, 2},
+			Input2: []byte{1, 3},
+			Result: false,
+		},
+	}
+
+	for _, testCase := range testCases {
+		cmp := BytesEqual(testCase.Input1, testCase.Input2)
+		if cmp != testCase.Result {
+			t.Errorf("unexpected result %v from %v", cmp, testCase)
+		}
+	}
+}

+ 2 - 2
common/net/address.go

@@ -4,7 +4,7 @@ import (
 	"net"
 	"strings"
 
-	"v2ray.com/core/common/predicate"
+	"v2ray.com/core/common/compare"
 )
 
 var (
@@ -94,7 +94,7 @@ func IPAddress(ip []byte) Address {
 		var addr ipv4Address = [4]byte{ip[0], ip[1], ip[2], ip[3]}
 		return addr
 	case net.IPv6len:
-		if predicate.BytesAll(ip[0:10], 0) && predicate.BytesAll(ip[10:12], 0xff) {
+		if compare.BytesAll(ip[0:10], 0) && compare.BytesAll(ip[10:12], 0xff) {
 			return IPAddress(ip[12:16])
 		}
 		var addr ipv6Address = [16]byte{

+ 30 - 0
common/peer/latency.go

@@ -0,0 +1,30 @@
+package peer
+
+import (
+	"sync"
+)
+
+type Latency interface {
+	Value() uint64
+}
+
+type HasLatency interface {
+	ConnectionLatency() Latency
+	HandshakeLatency() Latency
+}
+
+type AverageLatency struct {
+	access sync.Mutex
+	value  uint64
+}
+
+func (al *AverageLatency) Update(newValue uint64) {
+	al.access.Lock()
+	defer al.access.Unlock()
+
+	al.value = (al.value + newValue*2) / 3
+}
+
+func (al *AverageLatency) Value() uint64 {
+	return al.value
+}

+ 1 - 0
common/peer/peer.go

@@ -0,0 +1 @@
+package peer

+ 0 - 10
common/predicate/arrays.go

@@ -1,10 +0,0 @@
-package predicate
-
-func BytesAll(array []byte, b byte) bool {
-	for _, val := range array {
-		if val != b {
-			return false
-		}
-	}
-	return true
-}

+ 0 - 39
common/predicate/predicate.go

@@ -1,39 +0,0 @@
-package predicate // import "v2ray.com/core/common/predicate"
-
-type Predicate func() bool
-
-func (v Predicate) And(predicate Predicate) Predicate {
-	return All(v, predicate)
-}
-
-func (v Predicate) Or(predicate Predicate) Predicate {
-	return Any(v, predicate)
-}
-
-func All(predicates ...Predicate) Predicate {
-	return func() bool {
-		for _, p := range predicates {
-			if !p() {
-				return false
-			}
-		}
-		return true
-	}
-}
-
-func Any(predicates ...Predicate) Predicate {
-	return func() bool {
-		for _, p := range predicates {
-			if p() {
-				return true
-			}
-		}
-		return false
-	}
-}
-
-func Not(predicate Predicate) Predicate {
-	return func() bool {
-		return !predicate()
-	}
-}

+ 2 - 3
common/protocol/id.go

@@ -56,7 +56,7 @@ func NewID(uuid uuid.UUID) *ID {
 	return id
 }
 
-func nextId(u *uuid.UUID) uuid.UUID {
+func nextID(u *uuid.UUID) uuid.UUID {
 	md5hash := md5.New()
 	common.Must2(md5hash.Write(u.Bytes()))
 	common.Must2(md5hash.Write([]byte("16167dc8-16b6-4e6d-b8bb-65dd68113a81")))
@@ -74,8 +74,7 @@ func NewAlterIDs(primary *ID, alterIDCount uint16) []*ID {
 	alterIDs := make([]*ID, alterIDCount)
 	prevID := primary.UUID()
 	for idx := range alterIDs {
-		newid := nextId(&prevID)
-		// TODO: check duplicates
+		newid := nextID(&prevID)
 		alterIDs[idx] = NewID(newid)
 		prevID = newid
 	}

+ 2 - 2
common/protocol/id_test.go

@@ -3,7 +3,7 @@ package protocol_test
 import (
 	"testing"
 
-	"v2ray.com/core/common/predicate"
+	"v2ray.com/core/common/compare"
 	. "v2ray.com/core/common/protocol"
 	"v2ray.com/core/common/uuid"
 	. "v2ray.com/ext/assert"
@@ -13,7 +13,7 @@ func TestCmdKey(t *testing.T) {
 	assert := With(t)
 
 	id := NewID(uuid.New())
-	assert(predicate.BytesAll(id.CmdKey(), 0), IsFalse)
+	assert(compare.BytesAll(id.CmdKey(), 0), IsFalse)
 }
 
 func TestIdEquals(t *testing.T) {

+ 2 - 2
proxy/mtproto/server.go

@@ -7,9 +7,9 @@ import (
 	"v2ray.com/core"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
+	"v2ray.com/core/common/compare"
 	"v2ray.com/core/common/crypto"
 	"v2ray.com/core/common/net"
-	"v2ray.com/core/common/predicate"
 	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/common/session"
 	"v2ray.com/core/common/signal"
@@ -85,7 +85,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
 	decryptor := crypto.NewAesCTRStream(auth.DecodingKey[:], auth.DecodingNonce[:])
 	decryptor.XORKeyStream(auth.Header[:], auth.Header[:])
 
-	if !predicate.BytesAll(auth.Header[56:60], 0xef) {
+	if !compare.BytesAll(auth.Header[56:60], 0xef) {
 		return newError("invalid connection type: ", auth.Header[56:60])
 	}
 

+ 14 - 6
testing/scenarios/command_test.go

@@ -18,6 +18,7 @@ import (
 	"v2ray.com/core/app/router"
 	"v2ray.com/core/app/stats"
 	statscmd "v2ray.com/core/app/stats/command"
+	"v2ray.com/core/common/compare"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/common/serial"
@@ -101,12 +102,17 @@ func TestCommanderRemoveHandler(t *testing.T) {
 	servers, err := InitializeServerConfigs(clientConfig)
 	assert(err, IsNil)
 
+	defer CloseAllServers(servers)
+
 	{
 		conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{
 			IP:   []byte{127, 0, 0, 1},
 			Port: int(clientPort),
 		})
-		assert(err, IsNil)
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer conn.Close() // nolint: errcheck
 
 		payload := "commander request."
 		nBytes, err := conn.Write([]byte(payload))
@@ -116,8 +122,9 @@ func TestCommanderRemoveHandler(t *testing.T) {
 		response := make([]byte, 1024)
 		nBytes, err = conn.Read(response)
 		assert(err, IsNil)
-		assert(response[:nBytes], Equals, xor([]byte(payload)))
-		assert(conn.Close(), IsNil)
+		if err := compare.BytesEqualWithDetail(response[:nBytes], xor([]byte(payload))); err != nil {
+			t.Fatal(err)
+		}
 	}
 
 	cmdConn, err := grpc.Dial(fmt.Sprintf("127.0.0.1:%d", cmdPort), grpc.WithInsecure(), grpc.WithBlock())
@@ -137,8 +144,6 @@ func TestCommanderRemoveHandler(t *testing.T) {
 		})
 		assert(err, IsNotNil)
 	}
-
-	CloseAllServers(servers)
 }
 
 func TestCommanderAddRemoveUser(t *testing.T) {
@@ -487,7 +492,10 @@ func TestCommanderStats(t *testing.T) {
 	}
 
 	servers, err := InitializeServerConfigs(serverConfig, clientConfig)
-	assert(err, IsNil)
+	if err != nil {
+		t.Fatal("Failed to create all servers", err)
+	}
+	defer CloseAllServers(servers)
 
 	conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{
 		IP:   []byte{127, 0, 0, 1},

+ 7 - 6
transport/internet/kcp/connection.go

@@ -8,7 +8,6 @@ import (
 	"time"
 
 	"v2ray.com/core/common/buf"
-	"v2ray.com/core/common/predicate"
 	"v2ray.com/core/common/signal"
 	"v2ray.com/core/common/signal/semaphore"
 )
@@ -119,13 +118,13 @@ func (info *RoundTripInfo) SmoothedTime() uint32 {
 
 type Updater struct {
 	interval        int64
-	shouldContinue  predicate.Predicate
-	shouldTerminate predicate.Predicate
+	shouldContinue  func() bool
+	shouldTerminate func() bool
 	updateFunc      func()
 	notifier        *semaphore.Instance
 }
 
-func NewUpdater(interval uint32, shouldContinue predicate.Predicate, shouldTerminate predicate.Predicate, updateFunc func()) *Updater {
+func NewUpdater(interval uint32, shouldContinue func() bool, shouldTerminate func() bool, updateFunc func()) *Updater {
 	u := &Updater{
 		interval:        int64(time.Duration(interval) * time.Millisecond),
 		shouldContinue:  shouldContinue,
@@ -230,12 +229,14 @@ func NewConnection(meta ConnMetadata, writer PacketWriter, closer io.Closer, con
 	}
 	conn.dataUpdater = NewUpdater(
 		config.GetTTIValue(),
-		predicate.Not(isTerminating).And(predicate.Any(conn.sendingWorker.UpdateNecessary, conn.receivingWorker.UpdateNecessary)),
+		func() bool {
+			return !isTerminating() && (conn.sendingWorker.UpdateNecessary() || conn.receivingWorker.UpdateNecessary())
+		},
 		isTerminating,
 		conn.updateTask)
 	conn.pingUpdater = NewUpdater(
 		5000, // 5 seconds
-		predicate.Not(isTerminated),
+		func() bool { return !isTerminated() },
 		isTerminated,
 		conn.updateTask)
 	conn.pingUpdater.WakeUp()