Browse Source

Added TestsEnabled Settings to enable VMessAEAD test

Shelikhoo 5 years ago
parent
commit
d06a4d1f50

+ 5 - 3
infra/conf/vmess.go

@@ -14,9 +14,10 @@ import (
 )
 )
 
 
 type VMessAccount struct {
 type VMessAccount struct {
-	ID       string `json:"id"`
-	AlterIds uint16 `json:"alterId"`
-	Security string `json:"security"`
+	ID           string `json:"id"`
+	AlterIds     uint16 `json:"alterId"`
+	Security     string `json:"security"`
+	TestsEnabled string `json:"testsEnabled"`
 }
 }
 
 
 // Build implements Buildable
 // Build implements Buildable
@@ -40,6 +41,7 @@ func (a *VMessAccount) Build() *vmess.Account {
 		SecuritySettings: &protocol.SecurityConfig{
 		SecuritySettings: &protocol.SecurityConfig{
 			Type: st,
 			Type: st,
 		},
 		},
+		TestsEnabled: a.TestsEnabled,
 	}
 	}
 }
 }
 
 

+ 6 - 3
proxy/vmess/account.go

@@ -16,6 +16,8 @@ type MemoryAccount struct {
 	AlterIDs []*protocol.ID
 	AlterIDs []*protocol.ID
 	// Security type of the account. Used for client connections.
 	// Security type of the account. Used for client connections.
 	Security protocol.SecurityType
 	Security protocol.SecurityType
+
+	TestsEnabled string
 }
 }
 
 
 // AnyValidID returns an ID that is either the main ID or one of the alternative IDs if any.
 // AnyValidID returns an ID that is either the main ID or one of the alternative IDs if any.
@@ -44,8 +46,9 @@ func (a *Account) AsAccount() (protocol.Account, error) {
 	}
 	}
 	protoID := protocol.NewID(id)
 	protoID := protocol.NewID(id)
 	return &MemoryAccount{
 	return &MemoryAccount{
-		ID:       protoID,
-		AlterIDs: protocol.NewAlterIDs(protoID, uint16(a.AlterId)),
-		Security: a.SecuritySettings.GetSecurityType(),
+		ID:           protoID,
+		AlterIDs:     protocol.NewAlterIDs(protoID, uint16(a.AlterId)),
+		Security:     a.SecuritySettings.GetSecurityType(),
+		TestsEnabled: a.TestsEnabled,
 	}, nil
 	}, nil
 }
 }

+ 11 - 2
proxy/vmess/encoding/client.go

@@ -2,6 +2,7 @@ package encoding
 
 
 import (
 import (
 	"bytes"
 	"bytes"
+	"context"
 	"crypto/aes"
 	"crypto/aes"
 	"crypto/cipher"
 	"crypto/cipher"
 	"crypto/md5"
 	"crypto/md5"
@@ -13,6 +14,7 @@ import (
 	"hash/fnv"
 	"hash/fnv"
 	"io"
 	"io"
 	"os"
 	"os"
+	"strings"
 	vmessaead "v2ray.com/core/proxy/vmess/aead"
 	vmessaead "v2ray.com/core/proxy/vmess/aead"
 
 
 	"golang.org/x/crypto/chacha20poly1305"
 	"golang.org/x/crypto/chacha20poly1305"
@@ -49,13 +51,20 @@ type ClientSession struct {
 }
 }
 
 
 // NewClientSession creates a new ClientSession.
 // NewClientSession creates a new ClientSession.
-func NewClientSession(idHash protocol.IDHash) *ClientSession {
+func NewClientSession(idHash protocol.IDHash, ctx context.Context) *ClientSession {
 	randomBytes := make([]byte, 33) // 16 + 16 + 1
 	randomBytes := make([]byte, 33) // 16 + 16 + 1
 	common.Must2(rand.Read(randomBytes))
 	common.Must2(rand.Read(randomBytes))
 
 
 	session := &ClientSession{}
 	session := &ClientSession{}
 
 
-	session.isAEADRequest = true
+	session.isAEADRequest = false
+
+	if ctxValueTestsEnabled := ctx.Value(vmess.TestsEnabled); ctxValueTestsEnabled != nil {
+		testsEnabled := ctxValueTestsEnabled.(string)
+		if strings.Contains(testsEnabled, "VMessAEAD") {
+			session.isAEADRequest = true
+		}
+	}
 
 
 	if vmessexp, vmessexp_found := os.LookupEnv("VMESSAEADEXPERIMENT"); vmessexp_found {
 	if vmessexp, vmessexp_found := os.LookupEnv("VMESSAEADEXPERIMENT"); vmessexp_found {
 		if vmessexp == "y" {
 		if vmessexp == "y" {

+ 4 - 3
proxy/vmess/encoding/encoding_test.go

@@ -1,6 +1,7 @@
 package encoding_test
 package encoding_test
 
 
 import (
 import (
+	"context"
 	"testing"
 	"testing"
 
 
 	"github.com/google/go-cmp/cmp"
 	"github.com/google/go-cmp/cmp"
@@ -42,7 +43,7 @@ func TestRequestSerialization(t *testing.T) {
 	}
 	}
 
 
 	buffer := buf.New()
 	buffer := buf.New()
-	client := NewClientSession(protocol.DefaultIDHash)
+	client := NewClientSession(protocol.DefaultIDHash, context.TODO())
 	common.Must(client.EncodeRequestHeader(expectedRequest, buffer))
 	common.Must(client.EncodeRequestHeader(expectedRequest, buffer))
 
 
 	buffer2 := buf.New()
 	buffer2 := buf.New()
@@ -92,7 +93,7 @@ func TestInvalidRequest(t *testing.T) {
 	}
 	}
 
 
 	buffer := buf.New()
 	buffer := buf.New()
-	client := NewClientSession(protocol.DefaultIDHash)
+	client := NewClientSession(protocol.DefaultIDHash, context.TODO())
 	common.Must(client.EncodeRequestHeader(expectedRequest, buffer))
 	common.Must(client.EncodeRequestHeader(expectedRequest, buffer))
 
 
 	buffer2 := buf.New()
 	buffer2 := buf.New()
@@ -133,7 +134,7 @@ func TestMuxRequest(t *testing.T) {
 	}
 	}
 
 
 	buffer := buf.New()
 	buffer := buf.New()
-	client := NewClientSession(protocol.DefaultIDHash)
+	client := NewClientSession(protocol.DefaultIDHash, context.TODO())
 	common.Must(client.EncodeRequestHeader(expectedRequest, buffer))
 	common.Must(client.EncodeRequestHeader(expectedRequest, buffer))
 
 
 	buffer2 := buf.New()
 	buffer2 := buf.New()

+ 5 - 2
proxy/vmess/outbound/outbound.go

@@ -89,9 +89,10 @@ func (v *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 		command = protocol.RequestCommandMux
 		command = protocol.RequestCommandMux
 	}
 	}
 
 
+	user := rec.PickUser()
 	request := &protocol.RequestHeader{
 	request := &protocol.RequestHeader{
 		Version: encoding.Version,
 		Version: encoding.Version,
-		User:    rec.PickUser(),
+		User:    user,
 		Command: command,
 		Command: command,
 		Address: target.Address,
 		Address: target.Address,
 		Port:    target.Port,
 		Port:    target.Port,
@@ -112,7 +113,9 @@ func (v *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 	input := link.Reader
 	input := link.Reader
 	output := link.Writer
 	output := link.Writer
 
 
-	session := encoding.NewClientSession(protocol.DefaultIDHash)
+	ctx = context.WithValue(ctx, vmess.TestsEnabled, user.Account.(*vmess.MemoryAccount).TestsEnabled)
+
+	session := encoding.NewClientSession(protocol.DefaultIDHash, ctx)
 	sessionPolicy := v.policyManager.ForLevel(request.User.Level)
 	sessionPolicy := v.policyManager.ForLevel(request.User.Level)
 
 
 	ctx, cancel := context.WithCancel(ctx)
 	ctx, cancel := context.WithCancel(ctx)

+ 3 - 0
proxy/vmess/vmessCtxInterface.go

@@ -0,0 +1,3 @@
+package vmess
+
+const TestsEnabled = "VMessCtxInterface_TestsEnabled"