Browse Source

VMess AEAD Experiment

Shelikhoo 5 years ago
parent
commit
9bf07b1f26

+ 50 - 0
common/antireplay/antireplay.go

@@ -0,0 +1,50 @@
+package antireplay
+
+import (
+	cuckoo "github.com/seiflotfy/cuckoofilter"
+	"sync"
+	"time"
+)
+
+func NewAntiReplayWindow(AntiReplayTime int64) *AntiReplayWindow {
+	arw := &AntiReplayWindow{}
+	arw.AntiReplayTime = AntiReplayTime
+	return arw
+}
+
+type AntiReplayWindow struct {
+	lock           sync.Mutex
+	poolA          *cuckoo.Filter
+	poolB          *cuckoo.Filter
+	lastSwapTime   int64
+	PoolSwap       bool
+	AntiReplayTime int64
+}
+
+func (aw *AntiReplayWindow) Check(sum []byte) bool {
+	aw.lock.Lock()
+
+	if aw.lastSwapTime == 0 {
+		aw.lastSwapTime = time.Now().Unix()
+		aw.poolA = cuckoo.NewFilter(100000)
+		aw.poolB = cuckoo.NewFilter(100000)
+	}
+
+	tnow := time.Now().Unix()
+	timediff := tnow - aw.lastSwapTime
+
+	if timediff >= aw.AntiReplayTime {
+		if aw.PoolSwap {
+			aw.PoolSwap = false
+			aw.poolA.Reset()
+		} else {
+			aw.PoolSwap = true
+			aw.poolB.Reset()
+		}
+		aw.lastSwapTime = tnow
+	}
+
+	ret := aw.poolA.InsertUnique(sum) && aw.poolB.InsertUnique(sum)
+	aw.lock.Unlock()
+	return ret
+}

+ 0 - 2
common/protocol/headers.go

@@ -38,8 +38,6 @@ const (
 	RequestOptionChunkMasking bitmask.Byte = 0x04
 
 	RequestOptionGlobalPadding bitmask.Byte = 0x08
-
-	RequestOptionEarlyChecksum bitmask.Byte = 0x16
 )
 
 type RequestHeader struct {

+ 2 - 0
go.mod

@@ -1,12 +1,14 @@
 module v2ray.com/core
 
 require (
+	github.com/dgryski/go-metro v0.0.0-20180109044635-280f6062b5bc // indirect
 	github.com/golang/mock v1.2.0
 	github.com/golang/protobuf v1.3.2
 	github.com/google/go-cmp v0.2.0
 	github.com/gorilla/websocket v1.4.1
 	github.com/miekg/dns v1.1.4
 	github.com/refraction-networking/utls v0.0.0-20190909200633-43c36d3c1f57
+	github.com/seiflotfy/cuckoofilter v0.0.0-20200511222245-56093a4d3841
 	go.starlark.net v0.0.0-20190919145610-979af19b165c
 	golang.org/x/crypto v0.0.0-20191029031824-8986dd9e96cf
 	golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3

+ 4 - 0
go.sum

@@ -1,6 +1,8 @@
 cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
 github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
+github.com/dgryski/go-metro v0.0.0-20180109044635-280f6062b5bc h1:8WFBn63wegobsYAX0YjD+8suexZDga5CctH4CCTx2+8=
+github.com/dgryski/go-metro v0.0.0-20180109044635-280f6062b5bc/go.mod h1:c9O8+fpSOX1DM8cPNSkX/qsBWdkD4yd2dpciOWQjpBw=
 github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58=
 github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
 github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
@@ -16,6 +18,8 @@ github.com/miekg/dns v1.1.4 h1:rCMZsU2ScVSYcAsOXgmC6+AKOK+6pmQTOcw03nfwYV0=
 github.com/miekg/dns v1.1.4/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
 github.com/refraction-networking/utls v0.0.0-20190909200633-43c36d3c1f57 h1:SL1K0QAuC1b54KoY1pjPWe6kSlsFHwK9/oC960fKrTY=
 github.com/refraction-networking/utls v0.0.0-20190909200633-43c36d3c1f57/go.mod h1:tz9gX959MEFfFN5whTIocCLUG57WiILqtdVxI8c6Wj0=
+github.com/seiflotfy/cuckoofilter v0.0.0-20200511222245-56093a4d3841 h1:pnfutQFsV7ySmHUeX6ANGfPsBo29RctUvDn8G3rmJVw=
+github.com/seiflotfy/cuckoofilter v0.0.0-20200511222245-56093a4d3841/go.mod h1:ET5mVvNjwaGXRgZxO9UZr7X+8eAf87AfIYNwRSp9s4Y=
 go.starlark.net v0.0.0-20190919145610-979af19b165c h1:WR7X1xgXJlXhQBdorVc9Db3RhwG+J/kp6bLuMyJjfVw=
 go.starlark.net v0.0.0-20190919145610-979af19b165c/go.mod h1:c1/X6cHgvdXj6pUlmWKMkuqRnW4K8x2vwt6JAaaircg=
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M=

+ 114 - 0
proxy/vmess/aead/authid.go

@@ -0,0 +1,114 @@
+package aead
+
+import (
+	"bytes"
+	"crypto/aes"
+	"crypto/cipher"
+	rand3 "crypto/rand"
+	"encoding/binary"
+	"errors"
+	"hash/crc32"
+	"io"
+	"math"
+	"time"
+	"v2ray.com/core/common"
+	antiReplayWindow "v2ray.com/core/common/antireplay"
+)
+
+func CreateAuthID(cmdKey []byte, time int64) [16]byte {
+	buf := bytes.NewBuffer(nil)
+	common.Must(binary.Write(buf, binary.BigEndian, time))
+	var zero uint32
+	common.Must2(io.CopyN(buf, rand3.Reader, 4))
+	zero = crc32.ChecksumIEEE(buf.Bytes())
+	common.Must(binary.Write(buf, binary.BigEndian, zero))
+	aesBlock := NewCipherFromKey(cmdKey)
+	if buf.Len() != 16 {
+		panic("Size unexpected")
+	}
+	var result [16]byte
+	aesBlock.Encrypt(result[:], buf.Bytes())
+	return result
+}
+
+func NewCipherFromKey(cmdKey []byte) cipher.Block {
+	aesBlock, err := aes.NewCipher(KDF16(cmdKey, "AES Auth ID Encryption"))
+	if err != nil {
+		panic(err)
+	}
+	return aesBlock
+}
+
+type AuthIDDecoder struct {
+	s cipher.Block
+}
+
+func NewAuthIDDecoder(cmdKey []byte) *AuthIDDecoder {
+	return &AuthIDDecoder{NewCipherFromKey(cmdKey)}
+}
+
+func (aidd *AuthIDDecoder) Decode(data [16]byte) (int64, uint32, int32, []byte) {
+	aidd.s.Decrypt(data[:], data[:])
+	var t int64
+	var zero uint32
+	var rand int32
+	reader := bytes.NewReader(data[:])
+	common.Must(binary.Read(reader, binary.BigEndian, &t))
+	common.Must(binary.Read(reader, binary.BigEndian, &rand))
+	common.Must(binary.Read(reader, binary.BigEndian, &zero))
+	return t, zero, rand, data[:]
+}
+
+func NewAuthIDDecoderHolder() *AuthIDDecoderHolder {
+	return &AuthIDDecoderHolder{make(map[string]*AuthIDDecoderItem), antiReplayWindow.NewAntiReplayWindow(120)}
+}
+
+type AuthIDDecoderHolder struct {
+	aidhi map[string]*AuthIDDecoderItem
+	apw   *antiReplayWindow.AntiReplayWindow
+}
+
+type AuthIDDecoderItem struct {
+	dec    *AuthIDDecoder
+	ticket interface{}
+}
+
+func NewAuthIDDecoderItem(key [16]byte, ticket interface{}) *AuthIDDecoderItem {
+	return &AuthIDDecoderItem{
+		dec:    NewAuthIDDecoder(key[:]),
+		ticket: ticket,
+	}
+}
+
+func (a *AuthIDDecoderHolder) AddUser(key [16]byte, ticket interface{}) {
+	a.aidhi[string(key[:])] = NewAuthIDDecoderItem(key, ticket)
+}
+
+func (a *AuthIDDecoderHolder) RemoveUser(key [16]byte) {
+	delete(a.aidhi, string(key[:]))
+}
+
+func (a *AuthIDDecoderHolder) Match(AuthID [16]byte) (interface{}, error) {
+	if !a.apw.Check(AuthID[:]) {
+		return nil, errReplay
+	}
+	for _, v := range a.aidhi {
+
+		t, z, r, d := v.dec.Decode(AuthID)
+		if z != crc32.ChecksumIEEE(d[:12]) {
+			continue
+		}
+		if math.Abs(float64(t-time.Now().Unix())) > 120 {
+			continue
+		}
+		_ = r
+
+		return v.ticket, nil
+
+	}
+	return nil, errNotFound
+}
+
+var errNotFound = errors.New("user do not exist")
+
+var errReplay = errors.New("replayed request")

+ 141 - 0
proxy/vmess/aead/encrypt.go

@@ -0,0 +1,141 @@
+package aead
+
+import (
+	"bytes"
+	"crypto/aes"
+	"crypto/cipher"
+	"crypto/hmac"
+	"crypto/rand"
+	"encoding/binary"
+	"errors"
+	"io"
+	"time"
+	"v2ray.com/core/common"
+)
+
+func SealVMessAEADHeader(key [16]byte, data []byte) []byte {
+	authid := CreateAuthID(key[:], time.Now().Unix())
+
+	nonce := make([]byte, 8)
+	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
+		panic(err.Error())
+	}
+
+	lengthbuf := bytes.NewBuffer(nil)
+
+	var HeaderDataLen uint16
+	HeaderDataLen = uint16(len(data))
+
+	common.Must(binary.Write(lengthbuf, binary.BigEndian, HeaderDataLen))
+
+	authidCheck := KDF16(key[:], "VMess AuthID Check Value", string(authid[:]), string(lengthbuf.Bytes()), string(nonce))
+
+	lengthbufb := lengthbuf.Bytes()
+
+	LengthMask := KDF16(key[:], "VMess AuthID Mask Value", string(authid[:]), string(nonce[:]))[:2]
+
+	lengthbufb[0] = lengthbufb[0] ^ LengthMask[0]
+	lengthbufb[1] = lengthbufb[1] ^ LengthMask[1]
+
+	HeaderAEADKey := KDF16(key[:], "VMess Header AEAD Key", string(authid[:]), string(nonce))
+
+	HeaderAEADNonce := KDF(key[:], "VMess Header AEAD Nonce", string(authid[:]), string(nonce))[:12]
+
+	block, err := aes.NewCipher(HeaderAEADKey)
+	if err != nil {
+		panic(err.Error())
+	}
+
+	headerAEAD, err := cipher.NewGCM(block)
+
+	if err != nil {
+		panic(err.Error())
+	}
+
+	headerSealed := headerAEAD.Seal(nil, HeaderAEADNonce, data, authid[:])
+
+	var outPutBuf = bytes.NewBuffer(nil)
+
+	common.Must2(outPutBuf.Write(authid[:])) //16
+
+	common.Must2(outPutBuf.Write(authidCheck)) //16
+
+	common.Must2(outPutBuf.Write(lengthbufb)) //2
+
+	common.Must2(outPutBuf.Write(nonce)) //8
+
+	common.Must2(outPutBuf.Write(headerSealed))
+
+	return outPutBuf.Bytes()
+}
+
+func OpenVMessAEADHeader(key [16]byte, authid [16]byte, data io.Reader) ([]byte, bool, error, int) {
+	var authidCheck [16]byte
+	var lengthbufb [2]byte
+	var nonce [8]byte
+
+	n, err := io.ReadFull(data, authidCheck[:])
+	if err != nil {
+		return nil, false, err, n
+	}
+
+	n2, err := io.ReadFull(data, lengthbufb[:])
+	if err != nil {
+		return nil, false, err, n + n2
+	}
+
+	n4, err := io.ReadFull(data, nonce[:])
+	if err != nil {
+		return nil, false, err, n + n2 + n4
+	}
+
+	//Unmask Length
+
+	LengthMask := KDF16(key[:], "VMess AuthID Mask Value", string(authid[:]), string(nonce[:]))[:2]
+
+	lengthbufb[0] = lengthbufb[0] ^ LengthMask[0]
+	lengthbufb[1] = lengthbufb[1] ^ LengthMask[1]
+
+	authidCheckV := KDF16(key[:], "VMess AuthID Check Value", string(authid[:]), string(lengthbufb[:]), string(nonce[:]))
+
+	if !hmac.Equal(authidCheckV, authidCheck[:]) {
+		return nil, true, errCheckMismatch, n + n2 + n4
+	}
+
+	var length uint16
+
+	common.Must(binary.Read(bytes.NewReader(lengthbufb[:]), binary.BigEndian, &length))
+
+	HeaderAEADKey := KDF16(key[:], "VMess Header AEAD Key", string(authid[:]), string(nonce[:]))
+
+	HeaderAEADNonce := KDF(key[:], "VMess Header AEAD Nonce", string(authid[:]), string(nonce[:]))[:12]
+
+	//16 == AEAD Tag size
+	header := make([]byte, length+16)
+
+	n3, err := io.ReadFull(data, header)
+	if err != nil {
+		return nil, false, err, n + n2 + n3 + n4
+	}
+
+	block, err := aes.NewCipher(HeaderAEADKey)
+	if err != nil {
+		panic(err.Error())
+	}
+
+	headerAEAD, err := cipher.NewGCM(block)
+
+	if err != nil {
+		panic(err.Error())
+	}
+
+	out, erropenAEAD := headerAEAD.Open(nil, HeaderAEADNonce, header, authid[:])
+
+	if erropenAEAD != nil {
+		return nil, true, erropenAEAD, n + n2 + n3 + n4
+	}
+
+	return out, false, nil, n + n2 + n3 + n4
+}
+
+var errCheckMismatch = errors.New("check verify failed")

+ 26 - 0
proxy/vmess/aead/kdf.go

@@ -0,0 +1,26 @@
+package aead
+
+import (
+	"crypto/hmac"
+	"crypto/sha256"
+	"hash"
+)
+
+func KDF(key []byte, path ...string) []byte {
+	hmacf := hmac.New(func() hash.Hash {
+		return sha256.New()
+	}, []byte("VMess AEAD KDF"))
+
+	for _, v := range path {
+		hmacf = hmac.New(func() hash.Hash {
+			return hmacf
+		}, []byte(v))
+	}
+	hmacf.Write(key)
+	return hmacf.Sum(nil)
+}
+
+func KDF16(key []byte, path ...string) []byte {
+	r := KDF(key, path...)
+	return r[:16]
+}

+ 93 - 12
proxy/vmess/encoding/client.go

@@ -1,12 +1,19 @@
 package encoding
 
 import (
+	"bytes"
+	"crypto/aes"
+	"crypto/cipher"
 	"crypto/md5"
 	"crypto/rand"
+	"crypto/sha256"
 	"encoding/binary"
+	"fmt"
 	"hash"
 	"hash/fnv"
 	"io"
+	"os"
+	vmessaead "v2ray.com/core/proxy/vmess/aead"
 
 	"golang.org/x/crypto/chacha20poly1305"
 
@@ -37,6 +44,8 @@ type ClientSession struct {
 	responseBodyIV  [16]byte
 	responseReader  io.Reader
 	responseHeader  byte
+
+	isAEADRequest bool
 }
 
 // NewClientSession creates a new ClientSession.
@@ -45,11 +54,29 @@ func NewClientSession(idHash protocol.IDHash) *ClientSession {
 	common.Must2(rand.Read(randomBytes))
 
 	session := &ClientSession{}
+
+	session.isAEADRequest = false
+
+	if vmessexp, vmessexp_found := os.LookupEnv("VMESSAEADEXPERIMENT"); vmessexp_found {
+		if vmessexp == "y" {
+			session.isAEADRequest = true
+			fmt.Println("=======VMESSAEADEXPERIMENT ENABLED========")
+		}
+	}
+
 	copy(session.requestBodyKey[:], randomBytes[:16])
 	copy(session.requestBodyIV[:], randomBytes[16:32])
 	session.responseHeader = randomBytes[32]
-	session.responseBodyKey = md5.Sum(session.requestBodyKey[:])
-	session.responseBodyIV = md5.Sum(session.requestBodyIV[:])
+	if !session.isAEADRequest {
+		session.responseBodyKey = md5.Sum(session.requestBodyKey[:])
+		session.responseBodyIV = md5.Sum(session.requestBodyIV[:])
+	} else {
+		BodyKey := sha256.Sum256(session.requestBodyKey[:])
+		copy(session.responseBodyKey[:], BodyKey[:16])
+		BodyIV := sha256.Sum256(session.requestBodyKey[:])
+		copy(session.responseBodyIV[:], BodyIV[:16])
+	}
+
 	session.idHash = idHash
 
 	return session
@@ -58,9 +85,11 @@ func NewClientSession(idHash protocol.IDHash) *ClientSession {
 func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) error {
 	timestamp := protocol.NewTimestampGenerator(protocol.NowTime(), 30)()
 	account := header.User.Account.(*vmess.MemoryAccount)
-	idHash := c.idHash(account.AnyValidID().Bytes())
-	common.Must2(serial.WriteUint64(idHash, uint64(timestamp)))
-	common.Must2(writer.Write(idHash.Sum(nil)))
+	if !c.isAEADRequest {
+		idHash := c.idHash(account.AnyValidID().Bytes())
+		common.Must2(serial.WriteUint64(idHash, uint64(timestamp)))
+		common.Must2(writer.Write(idHash.Sum(nil)))
+	}
 
 	buffer := buf.New()
 	defer buffer.Release()
@@ -92,10 +121,18 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ
 		fnv1a.Sum(hashBytes[:0])
 	}
 
-	iv := hashTimestamp(md5.New(), timestamp)
-	aesStream := crypto.NewAesEncryptionStream(account.ID.CmdKey(), iv[:])
-	aesStream.XORKeyStream(buffer.Bytes(), buffer.Bytes())
-	common.Must2(writer.Write(buffer.Bytes()))
+	if !c.isAEADRequest {
+		iv := hashTimestamp(md5.New(), timestamp)
+		aesStream := crypto.NewAesEncryptionStream(account.ID.CmdKey(), iv[:])
+		aesStream.XORKeyStream(buffer.Bytes(), buffer.Bytes())
+		common.Must2(writer.Write(buffer.Bytes()))
+	} else {
+		var fixedLengthCmdKey [16]byte
+		copy(fixedLengthCmdKey[:], account.ID.CmdKey())
+		vmessout := vmessaead.SealVMessAEADHeader(fixedLengthCmdKey, buffer.Bytes())
+		common.Must2(io.Copy(writer, bytes.NewReader(vmessout)))
+	}
+
 	return nil
 }
 
@@ -161,8 +198,49 @@ func (c *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write
 }
 
 func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.ResponseHeader, error) {
-	aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey[:], c.responseBodyIV[:])
-	c.responseReader = crypto.NewCryptionReader(aesStream, reader)
+	if !c.isAEADRequest {
+		aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey[:], c.responseBodyIV[:])
+		c.responseReader = crypto.NewCryptionReader(aesStream, reader)
+	} else {
+		resph := vmessaead.KDF16(c.responseBodyKey[:], "AEAD Resp Header Len Key")
+		respi := vmessaead.KDF(c.responseBodyIV[:], "AEAD Resp Header Len IV")[:12]
+
+		aesblock := common.Must2(aes.NewCipher(resph)).(cipher.Block)
+		aeadHeader := common.Must2(cipher.NewGCM(aesblock)).(cipher.AEAD)
+
+		var AEADLen [18]byte
+		var lenresp int
+
+		var lenrespr uint16
+
+		if _, err := io.ReadFull(reader, AEADLen[:]); err != nil {
+			return nil, newError("Unable to Read Header Len").Base(err)
+		}
+		if AEADLend, err := aeadHeader.Open(nil, respi, AEADLen[:], nil); err != nil {
+			return nil, newError("Failed To Decrypt Length").Base(err)
+		} else {
+			common.Must(binary.Read(bytes.NewReader(AEADLend), binary.BigEndian, &lenrespr))
+			lenresp = int(lenrespr)
+		}
+
+		resphc := vmessaead.KDF16(c.responseBodyKey[:], "AEAD Resp Header Key")
+		respic := vmessaead.KDF(c.responseBodyIV[:], "AEAD Resp Header IV")[:12]
+
+		aesblockc := common.Must2(aes.NewCipher(resphc)).(cipher.Block)
+		aeadHeaderc := common.Must2(cipher.NewGCM(aesblockc)).(cipher.AEAD)
+
+		respPayload := make([]byte, lenresp+16)
+
+		if _, err := io.ReadFull(reader, respPayload); err != nil {
+			return nil, newError("Unable to Read Header Data").Base(err)
+		}
+
+		if AEADData, err := aeadHeaderc.Open(nil, respic, respPayload, nil); err != nil {
+			return nil, newError("Failed To Decrypt Payload").Base(err)
+		} else {
+			c.responseReader = bytes.NewReader(AEADData)
+		}
+	}
 
 	buffer := buf.StackNew()
 	defer buffer.Release()
@@ -192,7 +270,10 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon
 			header.Command = command
 		}
 	}
-
+	if c.isAEADRequest {
+		aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey[:], c.responseBodyIV[:])
+		c.responseReader = crypto.NewCryptionReader(aesStream, reader)
+	}
 	return header, nil
 }
 

+ 125 - 14
proxy/vmess/encoding/server.go

@@ -1,7 +1,11 @@
 package encoding
 
 import (
+	"bytes"
+	"crypto/aes"
+	"crypto/cipher"
 	"crypto/md5"
+	"crypto/sha256"
 	"encoding/binary"
 	"hash/fnv"
 	"io"
@@ -9,6 +13,7 @@ import (
 	"sync"
 	"time"
 	"v2ray.com/core/common/dice"
+	vmessaead "v2ray.com/core/proxy/vmess/aead"
 
 	"golang.org/x/crypto/chacha20poly1305"
 
@@ -99,6 +104,10 @@ type ServerSession struct {
 	responseBodyIV  [16]byte
 	responseWriter  io.Writer
 	responseHeader  byte
+
+	isAEADRequest bool
+
+	isAEADForced bool
 }
 
 // NewServerSession creates a new ServerSession, using the given UserValidator.
@@ -153,17 +162,44 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
 		return nil, newError("failed to read request header").Base(err)
 	}
 
-	user, timestamp, valid := s.userValidator.Get(buffer.Bytes())
-	if !valid {
+	var decryptor io.Reader
+	var vmessAccount *vmess.MemoryAccount
+
+	user, foundAEAD := s.userValidator.GetAEAD(buffer.Bytes())
+
+	var fixedSizeAuthID [16]byte
+	copy(fixedSizeAuthID[:], buffer.Bytes())
+
+	if foundAEAD == true {
+		vmessAccount = user.Account.(*vmess.MemoryAccount)
+		var fixedSizeCmdKey [16]byte
+		copy(fixedSizeCmdKey[:], vmessAccount.ID.CmdKey())
+		aeadData, shouldDrain, errorReason, bytesRead := vmessaead.OpenVMessAEADHeader(fixedSizeCmdKey, fixedSizeAuthID, reader)
+		if errorReason != nil {
+			if shouldDrain {
+				readSizeRemain -= bytesRead
+				return nil, drainConnection(newError("AEAD read failed").Base(errorReason))
+			} else {
+				return nil, drainConnection(newError("AEAD read failed, drain skiped").Base(errorReason))
+			}
+		}
+		decryptor = bytes.NewReader(aeadData)
+		s.isAEADRequest = true
+	} else if !s.isAEADForced {
+		userLegacy, timestamp, valid, userValidationError := s.userValidator.Get(buffer.Bytes())
+		if !valid || userValidationError != nil {
+			return nil, drainConnection(newError("invalid user").Base(userValidationError))
+		}
+		user = userLegacy
+		iv := hashTimestamp(md5.New(), timestamp)
+		vmessAccount = userLegacy.Account.(*vmess.MemoryAccount)
+
+		aesStream := crypto.NewAesDecryptionStream(vmessAccount.ID.CmdKey(), iv[:])
+		decryptor = crypto.NewCryptionReader(aesStream, reader)
+	} else {
 		return nil, drainConnection(newError("invalid user"))
 	}
 
-	iv := hashTimestamp(md5.New(), timestamp)
-	vmessAccount := user.Account.(*vmess.MemoryAccount)
-
-	aesStream := crypto.NewAesDecryptionStream(vmessAccount.ID.CmdKey(), iv[:])
-	decryptor := crypto.NewCryptionReader(aesStream, reader)
-
 	readSizeRemain -= int(buffer.Len())
 	buffer.Clear()
 	if _, err := buffer.ReadFullFrom(decryptor, 38); err != nil {
@@ -182,7 +218,16 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
 	sid.key = s.requestBodyKey
 	sid.nonce = s.requestBodyIV
 	if !s.sessionHistory.addIfNotExits(sid) {
-		return nil, drainConnection(newError("duplicated session id, possibly under replay attack"))
+		if !s.isAEADRequest {
+			drainErr := s.userValidator.BurnTaintFuse(fixedSizeAuthID[:])
+			if drainErr != nil {
+				return nil, drainConnection(newError("duplicated session id, possibly under replay attack, and failed to taint userHash").Base(drainErr))
+			}
+			return nil, drainConnection(newError("duplicated session id, possibly under replay attack, userHash tainted"))
+		} else {
+			return nil, newError("duplicated session id, possibly under replay attack, but this is a AEAD request")
+		}
+
 	}
 
 	s.responseHeader = buffer.Byte(33)             // 1 byte
@@ -205,11 +250,25 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
 
 	if padingLen > 0 {
 		if _, err := buffer.ReadFullFrom(decryptor, int32(padingLen)); err != nil {
+			if !s.isAEADRequest {
+				burnErr := s.userValidator.BurnTaintFuse(fixedSizeAuthID[:])
+				if burnErr != nil {
+					return nil, newError("failed to read padding, failed to taint userHash").Base(burnErr).Base(err)
+				}
+				return nil, newError("failed to read padding, userHash tainted").Base(err)
+			}
 			return nil, newError("failed to read padding").Base(err)
 		}
 	}
 
 	if _, err := buffer.ReadFullFrom(decryptor, 4); err != nil {
+		if !s.isAEADRequest {
+			burnErr := s.userValidator.BurnTaintFuse(fixedSizeAuthID[:])
+			if burnErr != nil {
+				return nil, newError("failed to read checksum, failed to taint userHash").Base(burnErr).Base(err)
+			}
+			return nil, newError("failed to read checksum, userHash tainted").Base(err)
+		}
 		return nil, newError("failed to read checksum").Base(err)
 	}
 
@@ -219,8 +278,18 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
 	expectedHash := binary.BigEndian.Uint32(buffer.BytesFrom(-4))
 
 	if actualHash != expectedHash {
-		//It is possible that we are under attack described in https://github.com/v2ray/v2ray-core/issues/2523
-		return nil, drainConnection(newError("invalid auth"))
+		if !s.isAEADRequest {
+			Autherr := newError("invalid auth, legacy userHash tainted")
+			burnErr := s.userValidator.BurnTaintFuse(fixedSizeAuthID[:])
+			if burnErr != nil {
+				Autherr = newError("invalid auth, can't taint legacy userHash").Base(burnErr)
+			}
+			//It is possible that we are under attack described in https://github.com/v2ray/v2ray-core/issues/2523
+			return nil, drainConnection(Autherr)
+		} else {
+			return nil, newError("invalid auth, but this is a AEAD request")
+		}
+
 	}
 
 	if request.Address == nil {
@@ -299,18 +368,60 @@ func (s *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reade
 
 // EncodeResponseHeader writes encoded response header into the given writer.
 func (s *ServerSession) EncodeResponseHeader(header *protocol.ResponseHeader, writer io.Writer) {
-	s.responseBodyKey = md5.Sum(s.requestBodyKey[:])
-	s.responseBodyIV = md5.Sum(s.requestBodyIV[:])
+	var encryptionWriter io.Writer
+	if !s.isAEADRequest {
+		s.responseBodyKey = md5.Sum(s.requestBodyKey[:])
+		s.responseBodyIV = md5.Sum(s.requestBodyIV[:])
+	} else {
+		BodyKey := sha256.Sum256(s.requestBodyKey[:])
+		copy(s.responseBodyKey[:], BodyKey[:16])
+		BodyIV := sha256.Sum256(s.requestBodyKey[:])
+		copy(s.responseBodyIV[:], BodyIV[:16])
+	}
 
 	aesStream := crypto.NewAesEncryptionStream(s.responseBodyKey[:], s.responseBodyIV[:])
-	encryptionWriter := crypto.NewCryptionWriter(aesStream, writer)
+	encryptionWriter = crypto.NewCryptionWriter(aesStream, writer)
 	s.responseWriter = encryptionWriter
 
+	aeadBuffer := bytes.NewBuffer(nil)
+
+	if s.isAEADRequest {
+		encryptionWriter = aeadBuffer
+	}
+
 	common.Must2(encryptionWriter.Write([]byte{s.responseHeader, byte(header.Option)}))
 	err := MarshalCommand(header.Command, encryptionWriter)
 	if err != nil {
 		common.Must2(encryptionWriter.Write([]byte{0x00, 0x00}))
 	}
+
+	if s.isAEADRequest {
+
+		resph := vmessaead.KDF16(s.responseBodyKey[:], "AEAD Resp Header Len Key")
+		respi := vmessaead.KDF(s.responseBodyIV[:], "AEAD Resp Header Len IV")[:12]
+
+		aesblock := common.Must2(aes.NewCipher(resph)).(cipher.Block)
+		aeadHeader := common.Must2(cipher.NewGCM(aesblock)).(cipher.AEAD)
+
+		aeadlenBuf := bytes.NewBuffer(nil)
+
+		var aeadLen uint16
+		aeadLen = uint16(aeadBuffer.Len())
+
+		common.Must(binary.Write(aeadlenBuf, binary.BigEndian, aeadLen))
+
+		sealedLen := aeadHeader.Seal(nil, respi, aeadlenBuf.Bytes(), nil)
+		common.Must2(io.Copy(writer, bytes.NewReader(sealedLen)))
+
+		resphc := vmessaead.KDF16(s.responseBodyKey[:], "AEAD Resp Header Key")
+		respic := vmessaead.KDF(s.responseBodyIV[:], "AEAD Resp Header IV")[:12]
+
+		aesblockc := common.Must2(aes.NewCipher(resphc)).(cipher.Block)
+		aeadHeaderc := common.Must2(cipher.NewGCM(aesblockc)).(cipher.AEAD)
+
+		sealed := aeadHeaderc.Seal(nil, respic, aeadBuffer.Bytes(), nil)
+		common.Must2(io.Copy(writer, bytes.NewReader(sealed)))
+	}
 }
 
 // EncodeResponseBody returns a Writer that auto-encrypt content written by caller.

+ 64 - 15
proxy/vmess/validator.go

@@ -8,6 +8,7 @@ import (
 	"sync"
 	"time"
 	"v2ray.com/core/common/dice"
+	"v2ray.com/core/proxy/vmess/aead"
 
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/protocol"
@@ -28,27 +29,33 @@ type user struct {
 // TimedUserValidator is a user Validator based on time.
 type TimedUserValidator struct {
 	sync.RWMutex
-	users         []*user
-	userHash      map[[16]byte]indexTimePair
-	hasher        protocol.IDHash
-	baseTime      protocol.Timestamp
-	task          *task.Periodic
+	users    []*user
+	userHash map[[16]byte]indexTimePair
+	hasher   protocol.IDHash
+	baseTime protocol.Timestamp
+	task     *task.Periodic
+
 	behaviorSeed  uint64
 	behaviorFused bool
+
+	aeadDecoderHolder *aead.AuthIDDecoderHolder
 }
 
 type indexTimePair struct {
 	user    *user
 	timeInc uint32
+
+	taintedFuse *bool
 }
 
 // NewTimedUserValidator creates a new TimedUserValidator.
 func NewTimedUserValidator(hasher protocol.IDHash) *TimedUserValidator {
 	tuv := &TimedUserValidator{
-		users:    make([]*user, 0, 16),
-		userHash: make(map[[16]byte]indexTimePair, 1024),
-		hasher:   hasher,
-		baseTime: protocol.Timestamp(time.Now().Unix() - cacheDurationSec*2),
+		users:             make([]*user, 0, 16),
+		userHash:          make(map[[16]byte]indexTimePair, 1024),
+		hasher:            hasher,
+		baseTime:          protocol.Timestamp(time.Now().Unix() - cacheDurationSec*2),
+		aeadDecoderHolder: aead.NewAuthIDDecoderHolder(),
 	}
 	tuv.task = &task.Periodic{
 		Interval: updateInterval,
@@ -76,8 +83,9 @@ func (v *TimedUserValidator) generateNewHashes(nowSec protocol.Timestamp, user *
 			idHash.Reset()
 
 			v.userHash[hashValue] = indexTimePair{
-				user:    user,
-				timeInc: uint32(ts - v.baseTime),
+				user:        user,
+				timeInc:     uint32(ts - v.baseTime),
+				taintedFuse: new(bool),
 			}
 		}
 	}
@@ -128,15 +136,19 @@ func (v *TimedUserValidator) Add(u *protocol.MemoryUser) error {
 	v.users = append(v.users, uu)
 	v.generateNewHashes(protocol.Timestamp(nowSec), uu)
 
+	account := uu.user.Account.(*MemoryAccount)
 	if v.behaviorFused == false {
-		account := uu.user.Account.(*MemoryAccount)
 		v.behaviorSeed = crc64.Update(v.behaviorSeed, crc64.MakeTable(crc64.ECMA), account.ID.Bytes())
 	}
 
+	var cmdkeyfl [16]byte
+	copy(cmdkeyfl[:], account.ID.CmdKey())
+	v.aeadDecoderHolder.AddUser(cmdkeyfl, u)
+
 	return nil
 }
 
-func (v *TimedUserValidator) Get(userHash []byte) (*protocol.MemoryUser, protocol.Timestamp, bool) {
+func (v *TimedUserValidator) Get(userHash []byte) (*protocol.MemoryUser, protocol.Timestamp, bool, error) {
 	defer v.RUnlock()
 	v.RLock()
 
@@ -148,9 +160,25 @@ func (v *TimedUserValidator) Get(userHash []byte) (*protocol.MemoryUser, protoco
 	if found {
 		var user protocol.MemoryUser
 		user = pair.user.user
-		return &user, protocol.Timestamp(pair.timeInc) + v.baseTime, true
+		if *pair.taintedFuse == false {
+			return &user, protocol.Timestamp(pair.timeInc) + v.baseTime, true, nil
+		}
+		return nil, 0, false, ErrTainted
+	}
+	return nil, 0, false, ErrNotFound
+}
+
+func (v *TimedUserValidator) GetAEAD(userHash []byte) (*protocol.MemoryUser, bool) {
+	defer v.RUnlock()
+	v.RLock()
+	var userHashFL [16]byte
+	copy(userHashFL[:], userHash)
+
+	userd, err := v.aeadDecoderHolder.Match(userHashFL)
+	if err != nil {
+		return nil, false
 	}
-	return nil, 0, false
+	return userd.(*protocol.MemoryUser), true
 }
 
 func (v *TimedUserValidator) Remove(email string) bool {
@@ -162,6 +190,9 @@ func (v *TimedUserValidator) Remove(email string) bool {
 	for i, u := range v.users {
 		if strings.EqualFold(u.user.Email, email) {
 			idx = i
+			var cmdkeyfl [16]byte
+			copy(cmdkeyfl[:], u.user.Account.(*MemoryAccount).ID.CmdKey())
+			v.aeadDecoderHolder.RemoveUser(cmdkeyfl)
 			break
 		}
 	}
@@ -191,3 +222,21 @@ func (v *TimedUserValidator) GetBehaviorSeed() uint64 {
 	}
 	return v.behaviorSeed
 }
+
+func (v *TimedUserValidator) BurnTaintFuse(userHash []byte) error {
+	v.Lock()
+	defer v.Unlock()
+	var userHashFL [16]byte
+	copy(userHashFL[:], userHash)
+
+	pair, found := v.userHash[userHashFL]
+	if found {
+		*pair.taintedFuse = true
+		return nil
+	}
+	return ErrNotFound
+}
+
+var ErrNotFound = newError("Not Found")
+
+var ErrTainted = newError("ErrTainted")

+ 2 - 2
proxy/vmess/validator_test.go

@@ -39,7 +39,7 @@ func TestUserValidator(t *testing.T) {
 			common.Must2(serial.WriteUint64(idHash, uint64(ts)))
 			userHash := idHash.Sum(nil)
 
-			euser, ets, found := v.Get(userHash)
+			euser, ets, found, _ := v.Get(userHash)
 			if !found {
 				t.Fatal("user not found")
 			}
@@ -67,7 +67,7 @@ func TestUserValidator(t *testing.T) {
 			common.Must2(serial.WriteUint64(idHash, uint64(ts)))
 			userHash := idHash.Sum(nil)
 
-			euser, _, found := v.Get(userHash)
+			euser, _, found, _ := v.Get(userHash)
 			if found || euser != nil {
 				t.Error("unexpected user")
 			}