Browse Source

merge user info inbound metadata

Darien Raymond 7 years ago
parent
commit
595f3d685e

+ 6 - 1
app/dispatcher/default.go

@@ -133,7 +133,12 @@ func (d *DefaultDispatcher) getLink(ctx context.Context) (*vio.Link, *vio.Link)
 		Writer: downlinkWriter,
 		Writer: downlinkWriter,
 	}
 	}
 
 
-	user := protocol.UserFromContext(ctx)
+	sessionInbound := session.InboundFromContext(ctx)
+	var user *protocol.MemoryUser
+	if sessionInbound != nil {
+		user = sessionInbound.User
+	}
+
 	if user != nil && len(user.Email) > 0 {
 	if user != nil && len(user.Email) > 0 {
 		p := d.policy.ForLevel(user.Level)
 		p := d.policy.ForLevel(user.Level)
 		if p.Stats.UserUplink {
 		if p.Stats.UserUplink {

+ 6 - 2
app/router/condition.go

@@ -8,7 +8,6 @@ import (
 
 
 	"v2ray.com/core/app/dispatcher"
 	"v2ray.com/core/app/dispatcher"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/net"
-	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/common/strmatcher"
 	"v2ray.com/core/common/strmatcher"
 	"v2ray.com/core/proxy"
 	"v2ray.com/core/proxy"
 )
 )
@@ -282,7 +281,12 @@ func NewUserMatcher(users []string) *UserMatcher {
 }
 }
 
 
 func (v *UserMatcher) Apply(ctx context.Context) bool {
 func (v *UserMatcher) Apply(ctx context.Context) bool {
-	user := protocol.UserFromContext(ctx)
+	inbound := session.InboundFromContext(ctx)
+	if inbound == nil {
+		return false
+	}
+
+	user := inbound.User
 	if user == nil {
 	if user == nil {
 		return false
 		return false
 	}
 	}

+ 6 - 2
app/router/condition_test.go

@@ -27,6 +27,10 @@ func withOutbound(outbound *session.Outbound) context.Context {
 	return session.ContextWithOutbound(context.Background(), outbound)
 	return session.ContextWithOutbound(context.Background(), outbound)
 }
 }
 
 
+func withInbound(inbound *session.Inbound) context.Context {
+	return session.ContextWithInbound(context.Background(), inbound)
+}
+
 func TestRoutingRule(t *testing.T) {
 func TestRoutingRule(t *testing.T) {
 	assert := With(t)
 	assert := With(t)
 
 
@@ -131,11 +135,11 @@ func TestRoutingRule(t *testing.T) {
 			},
 			},
 			test: []ruleTest{
 			test: []ruleTest{
 				{
 				{
-					input:  protocol.ContextWithUser(context.Background(), &protocol.MemoryUser{Email: "admin@v2ray.com"}),
+					input:  withInbound(&session.Inbound{User: &protocol.MemoryUser{Email: "admin@v2ray.com"}}),
 					output: true,
 					output: true,
 				},
 				},
 				{
 				{
-					input:  protocol.ContextWithUser(context.Background(), &protocol.MemoryUser{Email: "love@v2ray.com"}),
+					input:  withInbound(&session.Inbound{User: &protocol.MemoryUser{Email: "love@v2ray.com"}}),
 					output: false,
 					output: false,
 				},
 				},
 				{
 				{

+ 1 - 16
common/protocol/context.go

@@ -7,24 +7,9 @@ import (
 type key int
 type key int
 
 
 const (
 const (
-	userKey key = iota
-	requestKey
+	requestKey key = iota
 )
 )
 
 
-// ContextWithUser returns a context combined with a User.
-func ContextWithUser(ctx context.Context, user *MemoryUser) context.Context {
-	return context.WithValue(ctx, userKey, user)
-}
-
-// UserFromContext extracts a User from the given context, if any.
-func UserFromContext(ctx context.Context) *MemoryUser {
-	v := ctx.Value(userKey)
-	if v == nil {
-		return nil
-	}
-	return v.(*MemoryUser)
-}
-
 func ContextWithRequestHeader(ctx context.Context, request *RequestHeader) context.Context {
 func ContextWithRequestHeader(ctx context.Context, request *RequestHeader) context.Context {
 	return context.WithValue(ctx, requestKey, request)
 	return context.WithValue(ctx, requestKey, request)
 }
 }

+ 2 - 0
common/protocol/user.go

@@ -30,7 +30,9 @@ func (u *User) ToMemoryUser() (*MemoryUser, error) {
 	}, nil
 	}, nil
 }
 }
 
 
+// MemoryUser is a parsed form of User, to reduce number of parsing of Account proto.
 type MemoryUser struct {
 type MemoryUser struct {
+	// Account is the parsed account of the protocol.
 	Account Account
 	Account Account
 	Email   string
 	Email   string
 	Level   uint32
 	Level   uint32

+ 3 - 0
common/session/session.go

@@ -7,6 +7,7 @@ import (
 
 
 	"v2ray.com/core/common/errors"
 	"v2ray.com/core/common/errors"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/net"
+	"v2ray.com/core/common/protocol"
 )
 )
 
 
 // ID of a session.
 // ID of a session.
@@ -34,6 +35,8 @@ type Inbound struct {
 	Source  net.Destination
 	Source  net.Destination
 	Gateway net.Destination
 	Gateway net.Destination
 	Tag     string
 	Tag     string
+	// User is the user that authencates for the inbound. May be nil if the protocol allows anounymous traffic.
+	User *protocol.MemoryUser
 }
 }
 
 
 type Outbound struct {
 type Outbound struct {

+ 12 - 4
proxy/shadowsocks/server.go

@@ -89,6 +89,11 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
 	})
 	})
 
 
 	account := s.user.Account.(*MemoryAccount)
 	account := s.user.Account.(*MemoryAccount)
+	inbound := session.InboundFromContext(ctx)
+	if inbound == nil {
+		panic("no inbound metadata")
+	}
+	inbound.User = s.user
 
 
 	reader := buf.NewReader(conn)
 	reader := buf.NewReader(conn)
 	for {
 	for {
@@ -126,7 +131,7 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
 			}
 			}
 
 
 			dest := request.Destination()
 			dest := request.Destination()
-			if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.IsValid() {
+			if inbound.Source.IsValid() {
 				log.Record(&log.AccessMessage{
 				log.Record(&log.AccessMessage{
 					From:   inbound.Source,
 					From:   inbound.Source,
 					To:     dest,
 					To:     dest,
@@ -136,7 +141,6 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
 			}
 			}
 			newError("tunnelling request to ", dest).WriteToLog(session.ExportIDToError(ctx))
 			newError("tunnelling request to ", dest).WriteToLog(session.ExportIDToError(ctx))
 
 
-			ctx = protocol.ContextWithUser(ctx, request.User)
 			ctx = protocol.ContextWithRequestHeader(ctx, request)
 			ctx = protocol.ContextWithRequestHeader(ctx, request)
 			udpServer.Dispatch(ctx, dest, data)
 			udpServer.Dispatch(ctx, dest, data)
 		}
 		}
@@ -162,6 +166,12 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
 	}
 	}
 	conn.SetReadDeadline(time.Time{})
 	conn.SetReadDeadline(time.Time{})
 
 
+	inbound := session.InboundFromContext(ctx)
+	if inbound == nil {
+		panic("no inbound metadata")
+	}
+	inbound.User = s.user
+
 	dest := request.Destination()
 	dest := request.Destination()
 	log.Record(&log.AccessMessage{
 	log.Record(&log.AccessMessage{
 		From:   conn.RemoteAddr(),
 		From:   conn.RemoteAddr(),
@@ -171,8 +181,6 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
 	})
 	})
 	newError("tunnelling request to ", dest).WriteToLog(session.ExportIDToError(ctx))
 	newError("tunnelling request to ", dest).WriteToLog(session.ExportIDToError(ctx))
 
 
-	ctx = protocol.ContextWithUser(ctx, request.User)
-
 	ctx, cancel := context.WithCancel(ctx)
 	ctx, cancel := context.WithCancel(ctx)
 	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
 	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
 
 

+ 6 - 1
proxy/vmess/inbound/inbound.go

@@ -264,8 +264,13 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i
 		newError("unable to set back read deadline").Base(err).WriteToLog(session.ExportIDToError(ctx))
 		newError("unable to set back read deadline").Base(err).WriteToLog(session.ExportIDToError(ctx))
 	}
 	}
 
 
+	inbound := session.InboundFromContext(ctx)
+	if inbound == nil {
+		panic("no inbound metadata")
+	}
+	inbound.User = request.User
+
 	sessionPolicy = h.policyManager.ForLevel(request.User.Level)
 	sessionPolicy = h.policyManager.ForLevel(request.User.Level)
-	ctx = protocol.ContextWithUser(ctx, request.User)
 
 
 	ctx, cancel := context.WithCancel(ctx)
 	ctx, cancel := context.WithCancel(ctx)
 	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
 	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)