Forráskód Böngészése

introduce in-memory user

Darien Raymond 7 éve
szülő
commit
54e1bb96cc

+ 5 - 1
app/proxyman/command/command.go

@@ -39,7 +39,11 @@ func (op *AddUserOperation) ApplyInbound(ctx context.Context, handler core.Inbou
 	if !ok {
 		return newError("proxy is not a UserManager")
 	}
-	return um.AddUser(ctx, op.User)
+	mUser, err := op.User.ToMemoryUser()
+	if err != nil {
+		return newError("failed to parse user").Base(err)
+	}
+	return um.AddUser(ctx, mUser)
 }
 
 // ApplyInbound implements InboundOperation.

+ 2 - 2
app/router/condition_test.go

@@ -126,11 +126,11 @@ func TestRoutingRule(t *testing.T) {
 			},
 			test: []ruleTest{
 				{
-					input:  protocol.ContextWithUser(context.Background(), &protocol.User{Email: "admin@v2ray.com"}),
+					input:  protocol.ContextWithUser(context.Background(), &protocol.MemoryUser{Email: "admin@v2ray.com"}),
 					output: true,
 				},
 				{
-					input:  protocol.ContextWithUser(context.Background(), &protocol.User{Email: "love@v2ray.com"}),
+					input:  protocol.ContextWithUser(context.Background(), &protocol.MemoryUser{Email: "love@v2ray.com"}),
 					output: false,
 				},
 				{

+ 3 - 3
common/protocol/context.go

@@ -12,17 +12,17 @@ const (
 )
 
 // ContextWithUser returns a context combined with a User.
-func ContextWithUser(ctx context.Context, user *User) context.Context {
+func ContextWithUser(ctx context.Context, user *MemoryUser) context.Context {
 	return context.WithValue(ctx, userKey, user)
 }
 
 // UserFromContext extracts a User from the given context, if any.
-func UserFromContext(ctx context.Context) *User {
+func UserFromContext(ctx context.Context) *MemoryUser {
 	v := ctx.Value(userKey)
 	if v == nil {
 		return nil
 	}
-	return v.(*User)
+	return v.(*MemoryUser)
 }
 
 func ContextWithRequestHeader(ctx context.Context, request *RequestHeader) context.Context {

+ 1 - 1
common/protocol/headers.go

@@ -47,7 +47,7 @@ type RequestHeader struct {
 	Security SecurityType
 	Port     net.Port
 	Address  net.Address
-	User     *User
+	User     *MemoryUser
 }
 
 func (h *RequestHeader) Destination() net.Destination {

+ 16 - 13
common/protocol/server_spec.go

@@ -46,11 +46,11 @@ func (s *timeoutValidStrategy) Invalidate() {
 type ServerSpec struct {
 	sync.RWMutex
 	dest  net.Destination
-	users []*User
+	users []*MemoryUser
 	valid ValidationStrategy
 }
 
-func NewServerSpec(dest net.Destination, valid ValidationStrategy, users ...*User) *ServerSpec {
+func NewServerSpec(dest net.Destination, valid ValidationStrategy, users ...*MemoryUser) *ServerSpec {
 	return &ServerSpec{
 		dest:  dest,
 		users: users,
@@ -58,33 +58,36 @@ func NewServerSpec(dest net.Destination, valid ValidationStrategy, users ...*Use
 	}
 }
 
-func NewServerSpecFromPB(spec ServerEndpoint) *ServerSpec {
+func NewServerSpecFromPB(spec ServerEndpoint) (*ServerSpec, error) {
 	dest := net.TCPDestination(spec.Address.AsAddress(), net.Port(spec.Port))
-	return NewServerSpec(dest, AlwaysValid(), spec.User...)
+	mUsers := make([]*MemoryUser, len(spec.User))
+	for idx, u := range spec.User {
+		mUser, err := u.ToMemoryUser()
+		if err != nil {
+			return nil, err
+		}
+		mUsers[idx] = mUser
+	}
+	return NewServerSpec(dest, AlwaysValid(), mUsers...), nil
 }
 
 func (s *ServerSpec) Destination() net.Destination {
 	return s.dest
 }
 
-func (s *ServerSpec) HasUser(user *User) bool {
+func (s *ServerSpec) HasUser(user *MemoryUser) bool {
 	s.RLock()
 	defer s.RUnlock()
 
-	accountA, err := user.GetTypedAccount()
-	if err != nil {
-		return false
-	}
 	for _, u := range s.users {
-		accountB, err := u.GetTypedAccount()
-		if err == nil && accountA.Equals(accountB) {
+		if u.Account.Equals(user.Account) {
 			return true
 		}
 	}
 	return false
 }
 
-func (s *ServerSpec) AddUser(user *User) {
+func (s *ServerSpec) AddUser(user *MemoryUser) {
 	if s.HasUser(user) {
 		return
 	}
@@ -95,7 +98,7 @@ func (s *ServerSpec) AddUser(user *User) {
 	s.users = append(s.users, user)
 }
 
-func (s *ServerSpec) PickUser() *User {
+func (s *ServerSpec) PickUser() *MemoryUser {
 	s.RLock()
 	defer s.RUnlock()
 

+ 8 - 9
common/protocol/server_spec_test.go

@@ -6,7 +6,6 @@ import (
 
 	"v2ray.com/core/common/net"
 	. "v2ray.com/core/common/protocol"
-	"v2ray.com/core/common/serial"
 	"v2ray.com/core/common/uuid"
 	"v2ray.com/core/proxy/vmess"
 	. "v2ray.com/ext/assert"
@@ -40,26 +39,26 @@ func TestUserInServerSpec(t *testing.T) {
 	uuid1 := uuid.New()
 	uuid2 := uuid.New()
 
-	spec := NewServerSpec(net.Destination{}, AlwaysValid(), &User{
+	spec := NewServerSpec(net.Destination{}, AlwaysValid(), &MemoryUser{
 		Email:   "test1@v2ray.com",
-		Account: serial.ToTypedMessage(&vmess.Account{Id: uuid1.String()}),
+		Account: &vmess.Account{Id: uuid1.String()},
 	})
-	assert(spec.HasUser(&User{
+	assert(spec.HasUser(&MemoryUser{
 		Email:   "test1@v2ray.com",
-		Account: serial.ToTypedMessage(&vmess.Account{Id: uuid2.String()}),
+		Account: &vmess.Account{Id: uuid2.String()},
 	}), IsFalse)
 
-	spec.AddUser(&User{Email: "test2@v2ray.com"})
-	assert(spec.HasUser(&User{
+	spec.AddUser(&MemoryUser{Email: "test2@v2ray.com"})
+	assert(spec.HasUser(&MemoryUser{
 		Email:   "test1@v2ray.com",
-		Account: serial.ToTypedMessage(&vmess.Account{Id: uuid1.String()}),
+		Account: &vmess.Account{Id: uuid1.String()},
 	}), IsTrue)
 }
 
 func TestPickUser(t *testing.T) {
 	assert := With(t)
 
-	spec := NewServerSpec(net.Destination{}, AlwaysValid(), &User{Email: "test1@v2ray.com"}, &User{Email: "test2@v2ray.com"}, &User{Email: "test3@v2ray.com"})
+	spec := NewServerSpec(net.Destination{}, AlwaysValid(), &MemoryUser{Email: "test1@v2ray.com"}, &MemoryUser{Email: "test2@v2ray.com"}, &MemoryUser{Email: "test3@v2ray.com"})
 	user := spec.PickUser()
 	assert(user.Email, HasSuffix, "@v2ray.com")
 }

+ 18 - 0
common/protocol/user.go

@@ -17,3 +17,21 @@ func (u *User) GetTypedAccount() (Account, error) {
 	}
 	return nil, newError("Unknown account type: ", u.Account.Type)
 }
+
+func (u *User) ToMemoryUser() (*MemoryUser, error) {
+	account, err := u.GetTypedAccount()
+	if err != nil {
+		return nil, err
+	}
+	return &MemoryUser{
+		Account: account,
+		Email:   u.Email,
+		Level:   u.Level,
+	}, nil
+}
+
+type MemoryUser struct {
+	Account Account
+	Email   string
+	Level   uint32
+}

+ 45 - 7
common/task/task.go

@@ -2,12 +2,26 @@ package task
 
 import (
 	"context"
+	"strings"
 
+	"v2ray.com/core/common"
 	"v2ray.com/core/common/signal/semaphore"
 )
 
 type Task func() error
 
+type MultiError []error
+
+func (e MultiError) Error() string {
+	var r strings.Builder
+	common.Must2(r.WriteString("multierr: "))
+	for _, err := range e {
+		common.Must2(r.WriteString(err.Error()))
+		common.Must2(r.WriteString(" | "))
+	}
+	return r.String()
+}
+
 type executionContext struct {
 	ctx       context.Context
 	tasks     []Task
@@ -59,20 +73,44 @@ func Parallel(tasks ...Task) ExecutionOption {
 	}
 }
 
+// Sequential runs all tasks sequentially, and returns the first error encountered.Sequential
+// Once a task returns an error, the following tasks will not run.
 func Sequential(tasks ...Task) ExecutionOption {
 	return func(c *executionContext) {
-		if len(tasks) == 0 {
+		switch len(tasks) {
+		case 0:
 			return
+		case 1:
+			c.tasks = append(c.tasks, tasks[0])
+		default:
+			c.tasks = append(c.tasks, func() error {
+				return execute(tasks...)
+			})
 		}
+	}
+}
 
-		if len(tasks) == 1 {
-			c.tasks = append(c.tasks, tasks[0])
+func SequentialAll(tasks ...Task) ExecutionOption {
+	return func(c *executionContext) {
+		switch len(tasks) {
+		case 0:
 			return
+		case 1:
+			c.tasks = append(c.tasks, tasks[0])
+		default:
+			c.tasks = append(c.tasks, func() error {
+				var merr MultiError
+				for _, task := range tasks {
+					if err := task(); err != nil {
+						merr = append(merr, err)
+					}
+				}
+				if len(merr) == 0 {
+					return nil
+				}
+				return merr
+			})
 		}
-
-		c.tasks = append(c.tasks, func() error {
-			return execute(tasks...)
-		})
 	}
 }
 

+ 1 - 1
proxy/proxy.go

@@ -38,7 +38,7 @@ type Dialer interface {
 // UserManager is the interface for Inbounds and Outbounds that can manage their users.
 type UserManager interface {
 	// AddUser adds a new user.
-	AddUser(context.Context, *protocol.User) error
+	AddUser(context.Context, *protocol.MemoryUser) error
 
 	// RemoveUser removes a user by email.
 	RemoveUser(context.Context, string) error

+ 8 - 5
proxy/shadowsocks/client.go

@@ -27,7 +27,11 @@ type Client struct {
 func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) {
 	serverList := protocol.NewServerList()
 	for _, rec := range config.Server {
-		serverList.AddServer(protocol.NewServerSpecFromPB(*rec))
+		s, err := protocol.NewServerSpecFromPB(*rec)
+		if err != nil {
+			return nil, newError("failed to parse server spec").Base(err)
+		}
+		serverList.AddServer(s)
 	}
 	if serverList.Size() == 0 {
 		return nil, newError("0 server")
@@ -81,11 +85,10 @@ func (c *Client) Process(ctx context.Context, link *core.Link, dialer proxy.Dial
 	}
 
 	user := server.PickUser()
-	rawAccount, err := user.GetTypedAccount()
-	if err != nil {
-		return newError("failed to get a valid user account").AtWarning().Base(err)
+	account, ok := user.Account.(*MemoryAccount)
+	if !ok {
+		return newError("user account is not valid")
 	}
-	account := rawAccount.(*MemoryAccount)
 	request.User = user
 
 	if account.OneTimeAuth == Account_Auto || account.OneTimeAuth == Account_Enabled {

+ 11 - 35
proxy/shadowsocks/protocol.go

@@ -27,12 +27,8 @@ var addrParser = protocol.NewAddressParser(
 )
 
 // ReadTCPSession reads a Shadowsocks TCP session from the given reader, returns its header and remaining parts.
-func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHeader, buf.Reader, error) {
-	rawAccount, err := user.GetTypedAccount()
-	if err != nil {
-		return nil, nil, newError("failed to parse account").Base(err).AtError()
-	}
-	account := rawAccount.(*MemoryAccount)
+func ReadTCPSession(user *protocol.MemoryUser, reader io.Reader) (*protocol.RequestHeader, buf.Reader, error) {
+	account := user.Account.(*MemoryAccount)
 
 	buffer := buf.New()
 	defer buffer.Release()
@@ -116,11 +112,7 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea
 // WriteTCPRequest writes Shadowsocks request into the given writer, and returns a writer for body.
 func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Writer, error) {
 	user := request.User
-	rawAccount, err := user.GetTypedAccount()
-	if err != nil {
-		return nil, newError("failed to parse account").Base(err).AtError()
-	}
-	account := rawAccount.(*MemoryAccount)
+	account := user.Account.(*MemoryAccount)
 
 	if account.Cipher.IsAEAD() {
 		request.Option.Clear(RequestOptionOneTimeAuth)
@@ -167,17 +159,13 @@ func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Wri
 	return chunkWriter, nil
 }
 
-func ReadTCPResponse(user *protocol.User, reader io.Reader) (buf.Reader, error) {
-	rawAccount, err := user.GetTypedAccount()
-	if err != nil {
-		return nil, newError("failed to parse account").Base(err).AtError()
-	}
-	account := rawAccount.(*MemoryAccount)
+func ReadTCPResponse(user *protocol.MemoryUser, reader io.Reader) (buf.Reader, error) {
+	account := user.Account.(*MemoryAccount)
 
 	var iv []byte
 	if account.Cipher.IVSize() > 0 {
 		iv = make([]byte, account.Cipher.IVSize())
-		if _, err = io.ReadFull(reader, iv); err != nil {
+		if _, err := io.ReadFull(reader, iv); err != nil {
 			return nil, newError("failed to read IV").Base(err)
 		}
 	}
@@ -187,11 +175,7 @@ func ReadTCPResponse(user *protocol.User, reader io.Reader) (buf.Reader, error)
 
 func WriteTCPResponse(request *protocol.RequestHeader, writer io.Writer) (buf.Writer, error) {
 	user := request.User
-	rawAccount, err := user.GetTypedAccount()
-	if err != nil {
-		return nil, newError("failed to parse account.").Base(err).AtError()
-	}
-	account := rawAccount.(*MemoryAccount)
+	account := user.Account.(*MemoryAccount)
 
 	var iv []byte
 	if account.Cipher.IVSize() > 0 {
@@ -207,11 +191,7 @@ func WriteTCPResponse(request *protocol.RequestHeader, writer io.Writer) (buf.Wr
 
 func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buffer, error) {
 	user := request.User
-	rawAccount, err := user.GetTypedAccount()
-	if err != nil {
-		return nil, newError("failed to parse account.").Base(err).AtError()
-	}
-	account := rawAccount.(*MemoryAccount)
+	account := user.Account.(*MemoryAccount)
 
 	buffer := buf.New()
 	ivLen := account.Cipher.IVSize()
@@ -239,12 +219,8 @@ func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buff
 	return buffer, nil
 }
 
-func DecodeUDPPacket(user *protocol.User, payload *buf.Buffer) (*protocol.RequestHeader, *buf.Buffer, error) {
-	rawAccount, err := user.GetTypedAccount()
-	if err != nil {
-		return nil, nil, newError("failed to parse account").Base(err).AtError()
-	}
-	account := rawAccount.(*MemoryAccount)
+func DecodeUDPPacket(user *protocol.MemoryUser, payload *buf.Buffer) (*protocol.RequestHeader, *buf.Buffer, error) {
+	account := user.Account.(*MemoryAccount)
 
 	var iv []byte
 	if !account.Cipher.IsAEAD() && account.Cipher.IVSize() > 0 {
@@ -306,7 +282,7 @@ func DecodeUDPPacket(user *protocol.User, payload *buf.Buffer) (*protocol.Reques
 
 type UDPReader struct {
 	Reader io.Reader
-	User   *protocol.User
+	User   *protocol.MemoryUser
 }
 
 func (v *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {

+ 16 - 10
proxy/shadowsocks/protocol_test.go

@@ -12,6 +12,12 @@ import (
 	. "v2ray.com/ext/assert"
 )
 
+func toAccount(a *Account) protocol.Account {
+	account, err := a.AsAccount()
+	common.Must(err)
+	return account
+}
+
 func TestUDPEncoding(t *testing.T) {
 	assert := With(t)
 
@@ -20,9 +26,9 @@ func TestUDPEncoding(t *testing.T) {
 		Command: protocol.RequestCommandUDP,
 		Address: net.LocalHostIP,
 		Port:    1234,
-		User: &protocol.User{
+		User: &protocol.MemoryUser{
 			Email: "love@v2ray.com",
-			Account: serial.ToTypedMessage(&Account{
+			Account: toAccount(&Account{
 				Password:   "shadowsocks-password",
 				CipherType: CipherType_AES_128_CFB,
 				Ota:        Account_Disabled,
@@ -57,9 +63,9 @@ func TestTCPRequest(t *testing.T) {
 				Address: net.LocalHostIP,
 				Option:  RequestOptionOneTimeAuth,
 				Port:    1234,
-				User: &protocol.User{
+				User: &protocol.MemoryUser{
 					Email: "love@v2ray.com",
-					Account: serial.ToTypedMessage(&Account{
+					Account: toAccount(&Account{
 						Password:   "tcp-password",
 						CipherType: CipherType_CHACHA20,
 					}),
@@ -74,9 +80,9 @@ func TestTCPRequest(t *testing.T) {
 				Address: net.LocalHostIPv6,
 				Option:  RequestOptionOneTimeAuth,
 				Port:    1234,
-				User: &protocol.User{
+				User: &protocol.MemoryUser{
 					Email: "love@v2ray.com",
-					Account: serial.ToTypedMessage(&Account{
+					Account: toAccount(&Account{
 						Password:   "password",
 						CipherType: CipherType_AES_256_CFB,
 					}),
@@ -91,9 +97,9 @@ func TestTCPRequest(t *testing.T) {
 				Address: net.DomainAddress("v2ray.com"),
 				Option:  RequestOptionOneTimeAuth,
 				Port:    1234,
-				User: &protocol.User{
+				User: &protocol.MemoryUser{
 					Email: "love@v2ray.com",
-					Account: serial.ToTypedMessage(&Account{
+					Account: toAccount(&Account{
 						Password:   "password",
 						CipherType: CipherType_CHACHA20_IETF,
 					}),
@@ -135,8 +141,8 @@ func TestTCPRequest(t *testing.T) {
 func TestUDPReaderWriter(t *testing.T) {
 	assert := With(t)
 
-	user := &protocol.User{
-		Account: serial.ToTypedMessage(&Account{
+	user := &protocol.MemoryUser{
+		Account: toAccount(&Account{
 			Password:   "test-password",
 			CipherType: CipherType_CHACHA20_IETF,
 		}),

+ 12 - 13
proxy/shadowsocks/server.go

@@ -20,10 +20,9 @@ import (
 )
 
 type Server struct {
-	config  ServerConfig
-	user    *protocol.User
-	account *MemoryAccount
-	v       *core.Instance
+	config ServerConfig
+	user   *protocol.MemoryUser
+	v      *core.Instance
 }
 
 // NewServer create a new Shadowsocks server.
@@ -32,17 +31,15 @@ func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
 		return nil, newError("user is not specified")
 	}
 
-	rawAccount, err := config.User.GetTypedAccount()
+	mUser, err := config.User.ToMemoryUser()
 	if err != nil {
-		return nil, newError("failed to get user account").Base(err)
+		return nil, newError("failed to parse user account").Base(err)
 	}
-	account := rawAccount.(*MemoryAccount)
 
 	s := &Server{
-		config:  *config,
-		user:    config.GetUser(),
-		account: account,
-		v:       core.MustFromContext(ctx),
+		config: *config,
+		user:   mUser,
+		v:      core.MustFromContext(ctx),
 	}
 
 	return s, nil
@@ -90,6 +87,8 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
 		conn.Write(data.Bytes())
 	})
 
+	account := s.user.Account.(*MemoryAccount)
+
 	reader := buf.NewReader(conn)
 	for {
 		mpayload, err := reader.ReadMultiBuffer()
@@ -113,13 +112,13 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
 				continue
 			}
 
-			if request.Option.Has(RequestOptionOneTimeAuth) && s.account.OneTimeAuth == Account_Disabled {
+			if request.Option.Has(RequestOptionOneTimeAuth) && account.OneTimeAuth == Account_Disabled {
 				newError("client payload enables OTA but server doesn't allow it").WriteToLog(session.ExportIDToError(ctx))
 				payload.Release()
 				continue
 			}
 
-			if !request.Option.Has(RequestOptionOneTimeAuth) && s.account.OneTimeAuth == Account_Enabled {
+			if !request.Option.Has(RequestOptionOneTimeAuth) && account.OneTimeAuth == Account_Enabled {
 				newError("client payload disables OTA but server forces it").WriteToLog(session.ExportIDToError(ctx))
 				payload.Release()
 				continue

+ 5 - 1
proxy/socks/client.go

@@ -28,7 +28,11 @@ type Client struct {
 func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) {
 	serverList := protocol.NewServerList()
 	for _, rec := range config.Server {
-		serverList.AddServer(protocol.NewServerSpecFromPB(*rec))
+		s, err := protocol.NewServerSpecFromPB(*rec)
+		if err != nil {
+			return nil, newError("failed to get server spec").Base(err)
+		}
+		serverList.AddServer(s)
 	}
 	if serverList.Size() == 0 {
 		return nil, newError("0 target server")

+ 1 - 5
proxy/socks/protocol.go

@@ -350,11 +350,7 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
 
 	common.Must2(b.WriteBytes(socks5Version, 0x01, authByte))
 	if authByte == authPassword {
-		rawAccount, err := request.User.GetTypedAccount()
-		if err != nil {
-			return nil, err
-		}
-		account := rawAccount.(*Account)
+		account := request.User.Account.(*Account)
 
 		common.Must2(b.WriteBytes(0x01, byte(len(account.Username))))
 		common.Must2(b.Write([]byte(account.Username)))

+ 3 - 6
proxy/vmess/encoding/client.go

@@ -58,11 +58,8 @@ func NewClientSession(idHash protocol.IDHash) *ClientSession {
 
 func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) error {
 	timestamp := protocol.NewTimestampGenerator(protocol.NowTime(), 30)()
-	account, err := header.User.GetTypedAccount()
-	if err != nil {
-		return newError("failed to get user account: ", err).AtError()
-	}
-	idHash := c.idHash(account.(*vmess.InternalAccount).AnyValidID().Bytes())
+	account := header.User.Account.(*vmess.InternalAccount)
+	idHash := c.idHash(account.AnyValidID().Bytes())
 	common.Must2(idHash.Write(timestamp.Bytes(nil)))
 	common.Must2(writer.Write(idHash.Sum(nil)))
 
@@ -97,7 +94,7 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ
 	timestampHash := md5.New()
 	common.Must2(timestampHash.Write(hashTimestamp(timestamp)))
 	iv := timestampHash.Sum(nil)
-	aesStream := crypto.NewAesEncryptionStream(account.(*vmess.InternalAccount).ID.CmdKey(), iv)
+	aesStream := crypto.NewAesEncryptionStream(account.ID.CmdKey(), iv)
 	aesStream.XORKeyStream(buffer.Bytes(), buffer.Bytes())
 	common.Must2(writer.Write(buffer.Bytes()))
 	return nil

+ 12 - 7
proxy/vmess/encoding/encoding_test.go

@@ -7,17 +7,22 @@ import (
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
-	"v2ray.com/core/common/serial"
 	"v2ray.com/core/common/uuid"
 	"v2ray.com/core/proxy/vmess"
 	. "v2ray.com/core/proxy/vmess/encoding"
 	. "v2ray.com/ext/assert"
 )
 
+func toAccount(a *vmess.Account) protocol.Account {
+	account, err := a.AsAccount()
+	common.Must(err)
+	return account
+}
+
 func TestRequestSerialization(t *testing.T) {
 	assert := With(t)
 
-	user := &protocol.User{
+	user := &protocol.MemoryUser{
 		Level: 0,
 		Email: "test@v2ray.com",
 	}
@@ -26,7 +31,7 @@ func TestRequestSerialization(t *testing.T) {
 		Id:      id.String(),
 		AlterId: 0,
 	}
-	user.Account = serial.ToTypedMessage(account)
+	user.Account = toAccount(account)
 
 	expectedRequest := &protocol.RequestHeader{
 		Version:  1,
@@ -70,7 +75,7 @@ func TestRequestSerialization(t *testing.T) {
 func TestInvalidRequest(t *testing.T) {
 	assert := With(t)
 
-	user := &protocol.User{
+	user := &protocol.MemoryUser{
 		Level: 0,
 		Email: "test@v2ray.com",
 	}
@@ -79,7 +84,7 @@ func TestInvalidRequest(t *testing.T) {
 		Id:      id.String(),
 		AlterId: 0,
 	}
-	user.Account = serial.ToTypedMessage(account)
+	user.Account = toAccount(account)
 
 	expectedRequest := &protocol.RequestHeader{
 		Version:  1,
@@ -112,7 +117,7 @@ func TestInvalidRequest(t *testing.T) {
 func TestMuxRequest(t *testing.T) {
 	assert := With(t)
 
-	user := &protocol.User{
+	user := &protocol.MemoryUser{
 		Level: 0,
 		Email: "test@v2ray.com",
 	}
@@ -121,7 +126,7 @@ func TestMuxRequest(t *testing.T) {
 		Id:      id.String(),
 		AlterId: 0,
 	}
-	user.Account = serial.ToTypedMessage(account)
+	user.Account = toAccount(account)
 
 	expectedRequest := &protocol.RequestHeader{
 		Version:  1,

+ 1 - 5
proxy/vmess/encoding/server.go

@@ -139,11 +139,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
 	timestampHash := md5.New()
 	common.Must2(timestampHash.Write(hashTimestamp(timestamp)))
 	iv := timestampHash.Sum(nil)
-	account, err := user.GetTypedAccount()
-	if err != nil {
-		return nil, newError("failed to get user account").Base(err)
-	}
-	vmessAccount := account.(*vmess.InternalAccount)
+	vmessAccount := user.Account.(*vmess.InternalAccount)
 
 	aesStream := crypto.NewAesDecryptionStream(vmessAccount.ID.CmdKey(), iv)
 	decryptor := crypto.NewCryptionReader(aesStream, reader)

+ 24 - 19
proxy/vmess/inbound/inbound.go

@@ -16,7 +16,6 @@ import (
 	"v2ray.com/core/common/log"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
-	"v2ray.com/core/common/serial"
 	"v2ray.com/core/common/session"
 	"v2ray.com/core/common/signal"
 	"v2ray.com/core/common/task"
@@ -29,20 +28,20 @@ import (
 
 type userByEmail struct {
 	sync.Mutex
-	cache           map[string]*protocol.User
+	cache           map[string]*protocol.MemoryUser
 	defaultLevel    uint32
 	defaultAlterIDs uint16
 }
 
 func newUserByEmail(config *DefaultConfig) *userByEmail {
 	return &userByEmail{
-		cache:           make(map[string]*protocol.User),
+		cache:           make(map[string]*protocol.MemoryUser),
 		defaultLevel:    config.Level,
 		defaultAlterIDs: uint16(config.AlterId),
 	}
 }
 
-func (v *userByEmail) addNoLock(u *protocol.User) bool {
+func (v *userByEmail) addNoLock(u *protocol.MemoryUser) bool {
 	email := strings.ToLower(u.Email)
 	user, found := v.cache[email]
 	if found {
@@ -52,14 +51,14 @@ func (v *userByEmail) addNoLock(u *protocol.User) bool {
 	return true
 }
 
-func (v *userByEmail) Add(u *protocol.User) bool {
+func (v *userByEmail) Add(u *protocol.MemoryUser) bool {
 	v.Lock()
 	defer v.Unlock()
 
 	return v.addNoLock(u)
 }
 
-func (v *userByEmail) Get(email string) (*protocol.User, bool) {
+func (v *userByEmail) Get(email string) (*protocol.MemoryUser, bool) {
 	email = strings.ToLower(email)
 
 	v.Lock()
@@ -68,14 +67,16 @@ func (v *userByEmail) Get(email string) (*protocol.User, bool) {
 	user, found := v.cache[email]
 	if !found {
 		id := uuid.New()
-		account := &vmess.Account{
+		rawAccount := &vmess.Account{
 			Id:      id.String(),
 			AlterId: uint32(v.defaultAlterIDs),
 		}
-		user = &protocol.User{
+		account, err := rawAccount.AsAccount()
+		common.Must(err)
+		user = &protocol.MemoryUser{
 			Level:   v.defaultLevel,
 			Email:   email,
-			Account: serial.ToTypedMessage(account),
+			Account: account,
 		}
 		v.cache[email] = user
 	}
@@ -120,7 +121,12 @@ func New(ctx context.Context, config *Config) (*Handler, error) {
 	}
 
 	for _, user := range config.User {
-		if err := handler.AddUser(ctx, user); err != nil {
+		mUser, err := user.ToMemoryUser()
+		if err != nil {
+			return nil, newError("failed to get VMess user").Base(err)
+		}
+
+		if err := handler.AddUser(ctx, mUser); err != nil {
 			return nil, newError("failed to initiate user").Base(err)
 		}
 	}
@@ -130,10 +136,9 @@ func New(ctx context.Context, config *Config) (*Handler, error) {
 
 // Close implements common.Closable.
 func (h *Handler) Close() error {
-	common.Close(h.clients)
-	common.Close(h.sessionHistory)
-	common.Close(h.usersByEmail)
-	return nil
+	return task.Run(
+		task.SequentialAll(
+			task.Close(h.clients), task.Close(h.sessionHistory), task.Close(h.usersByEmail)))()
 }
 
 // Network implements proxy.Inbound.Network().
@@ -143,7 +148,7 @@ func (*Handler) Network() net.NetworkList {
 	}
 }
 
-func (h *Handler) GetUser(email string) *protocol.User {
+func (h *Handler) GetUser(email string) *protocol.MemoryUser {
 	user, existing := h.usersByEmail.Get(email)
 	if !existing {
 		h.clients.Add(user)
@@ -151,7 +156,7 @@ func (h *Handler) GetUser(email string) *protocol.User {
 	return user
 }
 
-func (h *Handler) AddUser(ctx context.Context, user *protocol.User) error {
+func (h *Handler) AddUser(ctx context.Context, user *protocol.MemoryUser) error {
 	if len(user.Email) > 0 && !h.usersByEmail.Add(user) {
 		return newError("User ", user.Email, " already exists.")
 	}
@@ -325,11 +330,11 @@ func (h *Handler) generateCommand(ctx context.Context, request *protocol.Request
 				if user == nil {
 					return nil
 				}
-				account, _ := user.GetTypedAccount()
+				account := user.Account.(*vmess.InternalAccount)
 				return &protocol.CommandSwitchAccount{
 					Port:     port,
-					ID:       account.(*vmess.InternalAccount).ID.UUID(),
-					AlterIds: uint16(len(account.(*vmess.InternalAccount).AlterIDs)),
+					ID:       account.ID.UUID(),
+					AlterIds: uint16(len(account.AlterIDs)),
 					Level:    user.Level,
 					ValidMin: byte(availableMin),
 				}

+ 6 - 4
proxy/vmess/outbound/command.go

@@ -3,14 +3,14 @@ package outbound
 import (
 	"time"
 
+	"v2ray.com/core/common"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
-	"v2ray.com/core/common/serial"
 	"v2ray.com/core/proxy/vmess"
 )
 
 func (h *Handler) handleSwitchAccount(cmd *protocol.CommandSwitchAccount) {
-	account := &vmess.Account{
+	rawAccount := &vmess.Account{
 		Id:      cmd.ID.String(),
 		AlterId: uint32(cmd.AlterIds),
 		SecuritySettings: &protocol.SecurityConfig{
@@ -18,10 +18,12 @@ func (h *Handler) handleSwitchAccount(cmd *protocol.CommandSwitchAccount) {
 		},
 	}
 
-	user := &protocol.User{
+	account, err := rawAccount.AsAccount()
+	common.Must(err)
+	user := &protocol.MemoryUser{
 		Email:   "",
 		Level:   cmd.Level,
-		Account: serial.ToTypedMessage(account),
+		Account: account,
 	}
 	dest := net.TCPDestination(cmd.Host, cmd.Port)
 	until := time.Now().Add(time.Duration(cmd.ValidMin) * time.Minute)

+ 6 - 6
proxy/vmess/outbound/outbound.go

@@ -33,7 +33,11 @@ type Handler struct {
 func New(ctx context.Context, config *Config) (*Handler, error) {
 	serverList := protocol.NewServerList()
 	for _, rec := range config.Receiver {
-		serverList.AddServer(protocol.NewServerSpecFromPB(*rec))
+		s, err := protocol.NewServerSpecFromPB(*rec)
+		if err != nil {
+			return nil, newError("failed to parse server spec").Base(err)
+		}
+		serverList.AddServer(s)
 	}
 	handler := &Handler{
 		serverList:   serverList,
@@ -87,11 +91,7 @@ func (v *Handler) Process(ctx context.Context, link *core.Link, dialer proxy.Dia
 		Option:  protocol.RequestOptionChunkStream,
 	}
 
-	rawAccount, err := request.User.GetTypedAccount()
-	if err != nil {
-		return newError("failed to get user account").Base(err).AtWarning()
-	}
-	account := rawAccount.(*vmess.InternalAccount)
+	account := request.User.Account.(*vmess.InternalAccount)
 	request.Security = account.Security
 
 	if request.Security == protocol.SecurityType_AES128_GCM || request.Security == protocol.SecurityType_NONE || request.Security == protocol.SecurityType_CHACHA20_POLY1305 {

+ 7 - 13
proxy/vmess/vmess.go

@@ -23,8 +23,7 @@ const (
 )
 
 type user struct {
-	user    *protocol.User
-	account *InternalAccount
+	user    *protocol.MemoryUser
 	lastSec protocol.Timestamp
 }
 
@@ -80,8 +79,10 @@ func (v *TimedUserValidator) generateNewHashes(nowSec protocol.Timestamp, user *
 		}
 	}
 
-	genHashForID(user.account.ID)
-	for _, id := range user.account.AlterIDs {
+	account := user.user.Account.(*InternalAccount)
+
+	genHashForID(account.ID)
+	for _, id := range account.AlterIDs {
 		genHashForID(id)
 	}
 	user.lastSec = nowSec
@@ -111,21 +112,14 @@ func (v *TimedUserValidator) updateUserHash() {
 	}
 }
 
-func (v *TimedUserValidator) Add(u *protocol.User) error {
+func (v *TimedUserValidator) Add(u *protocol.MemoryUser) error {
 	v.Lock()
 	defer v.Unlock()
 
-	rawAccount, err := u.GetTypedAccount()
-	if err != nil {
-		return err
-	}
-	account := rawAccount.(*InternalAccount)
-
 	nowSec := time.Now().Unix()
 
 	uu := &user{
 		user:    u,
-		account: account,
 		lastSec: protocol.Timestamp(nowSec - cacheDurationSec),
 	}
 	v.users = append(v.users, uu)
@@ -134,7 +128,7 @@ func (v *TimedUserValidator) Add(u *protocol.User) error {
 	return nil
 }
 
-func (v *TimedUserValidator) Get(userHash []byte) (*protocol.User, protocol.Timestamp, bool) {
+func (v *TimedUserValidator) Get(userHash []byte) (*protocol.MemoryUser, protocol.Timestamp, bool) {
 	defer v.RUnlock()
 	v.RLock()