Bläddra i källkod

support mtproto conn type 0xee. fixes #1297

Darien Raymond 7 år sedan
förälder
incheckning
2e94561584
4 ändrade filer med 62 tillägg och 11 borttagningar
  1. 40 7
      proxy/mtproto/auth.go
  2. 1 1
      proxy/mtproto/auth_test.go
  3. 2 1
      proxy/mtproto/client.go
  4. 19 2
      proxy/mtproto/server.go

+ 40 - 7
proxy/mtproto/auth.go

@@ -1,6 +1,7 @@
 package mtproto
 
 import (
+	"context"
 	"crypto/rand"
 	"crypto/sha256"
 	"io"
@@ -13,6 +14,35 @@ const (
 	HeaderSize = 64
 )
 
+type SessionContext struct {
+	ConnectionType [4]byte
+	DataCenterID   uint16
+}
+
+func DefaultSessionContext() SessionContext {
+	return SessionContext{
+		ConnectionType: [4]byte{0xef, 0xef, 0xef, 0xef},
+		DataCenterID:   0,
+	}
+}
+
+type contextKey int32
+
+const (
+	sessionContextKey contextKey = iota
+)
+
+func ContextWithSessionContext(ctx context.Context, c SessionContext) context.Context {
+	return context.WithValue(ctx, sessionContextKey, c)
+}
+
+func SessionContextFromContext(ctx context.Context) SessionContext {
+	if c := ctx.Value(sessionContextKey); c != nil {
+		return c.(SessionContext)
+	}
+	return DefaultSessionContext()
+}
+
 type Authentication struct {
 	Header        [HeaderSize]byte
 	DecodingKey   [32]byte
@@ -29,12 +59,18 @@ func (a *Authentication) DataCenterID() uint16 {
 	return uint16(x) - 1
 }
 
+func (a *Authentication) ConnectionType() [4]byte {
+	var x [4]byte
+	copy(x[:], a.Header[56:60])
+	return x
+}
+
 func (a *Authentication) ApplySecret(b []byte) {
 	a.DecodingKey = sha256.Sum256(append(a.DecodingKey[:], b...))
 	a.EncodingKey = sha256.Sum256(append(a.EncodingKey[:], b...))
 }
 
-func generateRandomBytes(random []byte) {
+func generateRandomBytes(random []byte, connType [4]byte) {
 	for {
 		common.Must2(rand.Read(random))
 
@@ -51,19 +87,16 @@ func generateRandomBytes(random []byte) {
 			continue
 		}
 
-		random[56] = 0xef
-		random[57] = 0xef
-		random[58] = 0xef
-		random[59] = 0xef
+		copy(random[56:60], connType[:])
 
 		return
 	}
 }
 
-func NewAuthentication() *Authentication {
+func NewAuthentication(sc SessionContext) *Authentication {
 	auth := getAuthenticationObject()
 	random := auth.Header[:]
-	generateRandomBytes(random)
+	generateRandomBytes(random, sc.ConnectionType)
 	copy(auth.EncodingKey[:], random[8:])
 	copy(auth.EncodingNonce[:], random[8+32:])
 	keyivInverse := Inverse(random[8 : 8+32+16])

+ 1 - 1
proxy/mtproto/auth_test.go

@@ -32,7 +32,7 @@ func TestInverse(t *testing.T) {
 func TestAuthenticationReadWrite(t *testing.T) {
 	assert := With(t)
 
-	a := NewAuthentication()
+	a := NewAuthentication(DefaultSessionContext())
 	b := bytes.NewReader(a.Header[:])
 	a2, err := ReadAuthentication(b)
 	assert(err, IsNil)

+ 2 - 1
proxy/mtproto/client.go

@@ -36,7 +36,8 @@ func (c *Client) Process(ctx context.Context, link *core.Link, dialer proxy.Dial
 	}
 	defer conn.Close() // nolint: errcheck
 
-	auth := NewAuthentication()
+	sc := SessionContextFromContext(ctx)
+	auth := NewAuthentication(sc)
 	defer putAuthenticationObject(auth)
 
 	request := func() error {

+ 19 - 2
proxy/mtproto/server.go

@@ -64,6 +64,16 @@ func (s *Server) Network() net.NetworkList {
 	}
 }
 
+func isValidConnectionType(c [4]byte) bool {
+	if compare.BytesAll(c[:], 0xef) {
+		return true
+	}
+	if compare.BytesAll(c[:], 0xee) {
+		return true
+	}
+	return false
+}
+
 func (s *Server) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher core.Dispatcher) error {
 	sPolicy := s.policy.ForLevel(s.user.Level)
 
@@ -85,8 +95,9 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
 	decryptor := crypto.NewAesCTRStream(auth.DecodingKey[:], auth.DecodingNonce[:])
 	decryptor.XORKeyStream(auth.Header[:], auth.Header[:])
 
-	if !compare.BytesAll(auth.Header[56:60], 0xef) {
-		return newError("invalid connection type: ", auth.Header[56:60])
+	ct := auth.ConnectionType()
+	if !isValidConnectionType(ct) {
+		return newError("invalid connection type: ", ct)
 	}
 
 	dcID := auth.DataCenterID()
@@ -104,6 +115,12 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
 	timer := signal.CancelAfterInactivity(ctx, cancel, sPolicy.Timeouts.ConnectionIdle)
 	ctx = core.ContextWithBufferPolicy(ctx, sPolicy.Buffer)
 
+	sc := SessionContext{
+		ConnectionType: ct,
+		DataCenterID:   dcID,
+	}
+	ctx = ContextWithSessionContext(ctx, sc)
+
 	link, err := dispatcher.Dispatch(ctx, dest)
 	if err != nil {
 		return newError("failed to dispatch request to: ", dest).Base(err)