Browse Source

Make isAEAD more efficient

RPRX 5 years ago
parent
commit
470dc8523b

+ 14 - 28
proxy/vmess/encoding/client.go

@@ -12,8 +12,6 @@ import (
 	"hash"
 	"hash"
 	"hash/fnv"
 	"hash/fnv"
 	"io"
 	"io"
-	"os"
-	vmessaead "v2ray.com/core/proxy/vmess/aead"
 
 
 	"golang.org/x/crypto/chacha20poly1305"
 	"golang.org/x/crypto/chacha20poly1305"
 
 
@@ -25,6 +23,7 @@ import (
 	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/common/serial"
 	"v2ray.com/core/common/serial"
 	"v2ray.com/core/proxy/vmess"
 	"v2ray.com/core/proxy/vmess"
+	vmessaead "v2ray.com/core/proxy/vmess/aead"
 )
 )
 
 
 func hashTimestamp(h hash.Hash, t protocol.Timestamp) []byte {
 func hashTimestamp(h hash.Hash, t protocol.Timestamp) []byte {
@@ -37,6 +36,7 @@ func hashTimestamp(h hash.Hash, t protocol.Timestamp) []byte {
 
 
 // ClientSession stores connection session info for VMess client.
 // ClientSession stores connection session info for VMess client.
 type ClientSession struct {
 type ClientSession struct {
+	isAEAD          bool
 	idHash          protocol.IDHash
 	idHash          protocol.IDHash
 	requestBodyKey  [16]byte
 	requestBodyKey  [16]byte
 	requestBodyIV   [16]byte
 	requestBodyIV   [16]byte
@@ -44,35 +44,23 @@ type ClientSession struct {
 	responseBodyIV  [16]byte
 	responseBodyIV  [16]byte
 	responseReader  io.Reader
 	responseReader  io.Reader
 	responseHeader  byte
 	responseHeader  byte
-
-	isAEADRequest bool
 }
 }
 
 
 // NewClientSession creates a new ClientSession.
 // NewClientSession creates a new ClientSession.
-func NewClientSession(idHash protocol.IDHash, ctx context.Context) *ClientSession {
-	randomBytes := make([]byte, 33) // 16 + 16 + 1
-	common.Must2(rand.Read(randomBytes))
-
-	session := &ClientSession{}
+func NewClientSession(isAEAD bool, idHash protocol.IDHash, ctx context.Context) *ClientSession {
 
 
-	session.isAEADRequest = false
-
-	if ctxValueAlterID := ctx.Value(vmess.AlterID); ctxValueAlterID != nil {
-		if ctxValueAlterID == 0 {
-			session.isAEADRequest = true
-		}
-	}
-
-	if vmessAeadDisable, vmessAeadDisableFound := os.LookupEnv("V2RAY_VMESS_AEAD_DISABLED"); vmessAeadDisableFound {
-		if vmessAeadDisable == "true" {
-			session.isAEADRequest = false
-		}
+	session := &ClientSession{
+		isAEAD: isAEAD,
+		idHash: idHash,
 	}
 	}
 
 
+	randomBytes := make([]byte, 33) // 16 + 16 + 1
+	common.Must2(rand.Read(randomBytes))
 	copy(session.requestBodyKey[:], randomBytes[:16])
 	copy(session.requestBodyKey[:], randomBytes[:16])
 	copy(session.requestBodyIV[:], randomBytes[16:32])
 	copy(session.requestBodyIV[:], randomBytes[16:32])
 	session.responseHeader = randomBytes[32]
 	session.responseHeader = randomBytes[32]
-	if !session.isAEADRequest {
+
+	if !session.isAEAD {
 		session.responseBodyKey = md5.Sum(session.requestBodyKey[:])
 		session.responseBodyKey = md5.Sum(session.requestBodyKey[:])
 		session.responseBodyIV = md5.Sum(session.requestBodyIV[:])
 		session.responseBodyIV = md5.Sum(session.requestBodyIV[:])
 	} else {
 	} else {
@@ -82,15 +70,13 @@ func NewClientSession(idHash protocol.IDHash, ctx context.Context) *ClientSessio
 		copy(session.responseBodyIV[:], BodyIV[:16])
 		copy(session.responseBodyIV[:], BodyIV[:16])
 	}
 	}
 
 
-	session.idHash = idHash
-
 	return session
 	return session
 }
 }
 
 
 func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) error {
 func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) error {
 	timestamp := protocol.NewTimestampGenerator(protocol.NowTime(), 30)()
 	timestamp := protocol.NewTimestampGenerator(protocol.NowTime(), 30)()
 	account := header.User.Account.(*vmess.MemoryAccount)
 	account := header.User.Account.(*vmess.MemoryAccount)
-	if !c.isAEADRequest {
+	if !c.isAEAD {
 		idHash := c.idHash(account.AnyValidID().Bytes())
 		idHash := c.idHash(account.AnyValidID().Bytes())
 		common.Must2(serial.WriteUint64(idHash, uint64(timestamp)))
 		common.Must2(serial.WriteUint64(idHash, uint64(timestamp)))
 		common.Must2(writer.Write(idHash.Sum(nil)))
 		common.Must2(writer.Write(idHash.Sum(nil)))
@@ -126,7 +112,7 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ
 		fnv1a.Sum(hashBytes[:0])
 		fnv1a.Sum(hashBytes[:0])
 	}
 	}
 
 
-	if !c.isAEADRequest {
+	if !c.isAEAD {
 		iv := hashTimestamp(md5.New(), timestamp)
 		iv := hashTimestamp(md5.New(), timestamp)
 		aesStream := crypto.NewAesEncryptionStream(account.ID.CmdKey(), iv[:])
 		aesStream := crypto.NewAesEncryptionStream(account.ID.CmdKey(), iv[:])
 		aesStream.XORKeyStream(buffer.Bytes(), buffer.Bytes())
 		aesStream.XORKeyStream(buffer.Bytes(), buffer.Bytes())
@@ -203,7 +189,7 @@ func (c *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write
 }
 }
 
 
 func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.ResponseHeader, error) {
 func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.ResponseHeader, error) {
-	if !c.isAEADRequest {
+	if !c.isAEAD {
 		aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey[:], c.responseBodyIV[:])
 		aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey[:], c.responseBodyIV[:])
 		c.responseReader = crypto.NewCryptionReader(aesStream, reader)
 		c.responseReader = crypto.NewCryptionReader(aesStream, reader)
 	} else {
 	} else {
@@ -274,7 +260,7 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon
 			header.Command = command
 			header.Command = command
 		}
 		}
 	}
 	}
-	if c.isAEADRequest {
+	if c.isAEAD {
 		aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey[:], c.responseBodyIV[:])
 		aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey[:], c.responseBodyIV[:])
 		c.responseReader = crypto.NewCryptionReader(aesStream, reader)
 		c.responseReader = crypto.NewCryptionReader(aesStream, reader)
 	}
 	}

+ 3 - 3
proxy/vmess/encoding/encoding_test.go

@@ -43,7 +43,7 @@ func TestRequestSerialization(t *testing.T) {
 	}
 	}
 
 
 	buffer := buf.New()
 	buffer := buf.New()
-	client := NewClientSession(protocol.DefaultIDHash, context.TODO())
+	client := NewClientSession(true, protocol.DefaultIDHash, context.TODO())
 	common.Must(client.EncodeRequestHeader(expectedRequest, buffer))
 	common.Must(client.EncodeRequestHeader(expectedRequest, buffer))
 
 
 	buffer2 := buf.New()
 	buffer2 := buf.New()
@@ -93,7 +93,7 @@ func TestInvalidRequest(t *testing.T) {
 	}
 	}
 
 
 	buffer := buf.New()
 	buffer := buf.New()
-	client := NewClientSession(protocol.DefaultIDHash, context.TODO())
+	client := NewClientSession(true, protocol.DefaultIDHash, context.TODO())
 	common.Must(client.EncodeRequestHeader(expectedRequest, buffer))
 	common.Must(client.EncodeRequestHeader(expectedRequest, buffer))
 
 
 	buffer2 := buf.New()
 	buffer2 := buf.New()
@@ -134,7 +134,7 @@ func TestMuxRequest(t *testing.T) {
 	}
 	}
 
 
 	buffer := buf.New()
 	buffer := buf.New()
-	client := NewClientSession(protocol.DefaultIDHash, context.TODO())
+	client := NewClientSession(true, protocol.DefaultIDHash, context.TODO())
 	common.Must(client.EncodeRequestHeader(expectedRequest, buffer))
 	common.Must(client.EncodeRequestHeader(expectedRequest, buffer))
 
 
 	buffer2 := buf.New()
 	buffer2 := buf.New()

+ 15 - 6
proxy/vmess/outbound/outbound.go

@@ -6,6 +6,7 @@ package outbound
 
 
 import (
 import (
 	"context"
 	"context"
+	"os"
 	"time"
 	"time"
 
 
 	"v2ray.com/core"
 	"v2ray.com/core"
@@ -30,6 +31,7 @@ type Handler struct {
 	serverList    *protocol.ServerList
 	serverList    *protocol.ServerList
 	serverPicker  protocol.ServerPicker
 	serverPicker  protocol.ServerPicker
 	policyManager policy.Manager
 	policyManager policy.Manager
+	aead_disabled bool
 }
 }
 
 
 // New creates a new VMess outbound handler.
 // New creates a new VMess outbound handler.
@@ -50,16 +52,20 @@ func New(ctx context.Context, config *Config) (*Handler, error) {
 		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
 		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
 	}
 	}
 
 
+	if disabled, _ := os.LookupEnv("V2RAY_VMESS_AEAD_DISABLED"); disabled == "true" {
+		handler.aead_disabled = true
+	}
+
 	return handler, nil
 	return handler, nil
 }
 }
 
 
 // Process implements proxy.Outbound.Process().
 // Process implements proxy.Outbound.Process().
-func (v *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
+func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
 	var rec *protocol.ServerSpec
 	var rec *protocol.ServerSpec
 	var conn internet.Connection
 	var conn internet.Connection
 
 
 	err := retry.ExponentialBackoff(5, 200).On(func() error {
 	err := retry.ExponentialBackoff(5, 200).On(func() error {
-		rec = v.serverPicker.PickServer()
+		rec = h.serverPicker.PickServer()
 		rawConn, err := dialer.Dial(ctx, rec.Destination())
 		rawConn, err := dialer.Dial(ctx, rec.Destination())
 		if err != nil {
 		if err != nil {
 			return err
 			return err
@@ -113,10 +119,13 @@ func (v *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 	input := link.Reader
 	input := link.Reader
 	output := link.Writer
 	output := link.Writer
 
 
-	ctx = context.WithValue(ctx, vmess.AlterID, len(account.AlterIDs))
+	isAEAD := false
+	if !h.aead_disabled && len(account.AlterIDs) == 0 {
+		isAEAD = true
+	}
 
 
-	session := encoding.NewClientSession(protocol.DefaultIDHash, ctx)
-	sessionPolicy := v.policyManager.ForLevel(request.User.Level)
+	session := encoding.NewClientSession(isAEAD, protocol.DefaultIDHash, ctx)
+	sessionPolicy := h.policyManager.ForLevel(request.User.Level)
 
 
 	ctx, cancel := context.WithCancel(ctx)
 	ctx, cancel := context.WithCancel(ctx)
 	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
 	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
@@ -159,7 +168,7 @@ func (v *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 		if err != nil {
 		if err != nil {
 			return newError("failed to read header").Base(err)
 			return newError("failed to read header").Base(err)
 		}
 		}
-		v.handleCommand(rec.Destination(), header.Command)
+		h.handleCommand(rec.Destination(), header.Command)
 
 
 		bodyReader := session.DecodeResponseBody(request, reader)
 		bodyReader := session.DecodeResponseBody(request, reader)
 
 

+ 1 - 0
proxy/vmess/vmessCtxInterface.go

@@ -1,3 +1,4 @@
 package vmess
 package vmess
 
 
+// example
 const AlterID = "VMessCtxInterface_AlterID"
 const AlterID = "VMessCtxInterface_AlterID"