Darien Raymond 8 سال پیش
والد
کامیت
c9c2338f05
4فایلهای تغییر یافته به همراه34 افزوده شده و 49 حذف شده
  1. 24 42
      app/dispatcher/impl/default.go
  2. 5 4
      app/dispatcher/impl/sniffer.go
  3. 4 2
      transport/ray/direct.go
  4. 1 1
      transport/ray/ray.go

+ 24 - 42
app/dispatcher/impl/default.go

@@ -12,6 +12,7 @@ import (
 	"v2ray.com/core/app/proxyman"
 	"v2ray.com/core/app/router"
 	"v2ray.com/core/common"
+	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/proxy"
 	"v2ray.com/core/transport/ray"
@@ -21,11 +22,13 @@ var (
 	errSniffingTimeout = newError("timeout on sniffing")
 )
 
+// DefaultDispatcher is a default implementation of Dispatcher.
 type DefaultDispatcher struct {
 	ohm    proxyman.OutboundHandlerManager
 	router *router.Router
 }
 
+// NewDefaultDispatcher create a new DefaultDispatcher.
 func NewDefaultDispatcher(ctx context.Context, config *dispatcher.Config) (*DefaultDispatcher, error) {
 	space := app.SpaceFromContext(ctx)
 	if space == nil {
@@ -43,21 +46,17 @@ func NewDefaultDispatcher(ctx context.Context, config *dispatcher.Config) (*Defa
 	return d, nil
 }
 
-func (DefaultDispatcher) Start() error {
+func (*DefaultDispatcher) Start() error {
 	return nil
 }
 
-func (DefaultDispatcher) Close() {}
+func (*DefaultDispatcher) Close() {}
 
-func (DefaultDispatcher) Interface() interface{} {
+func (*DefaultDispatcher) Interface() interface{} {
 	return (*dispatcher.Interface)(nil)
 }
 
-type domainOrError struct {
-	domain string
-	err    error
-}
-
+// Dispatch implements Dispatcher.Interface.
 func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destination) (ray.InboundRay, error) {
 	if !destination.IsValid() {
 		panic("Dispatcher: Invalid destination.")
@@ -70,15 +69,13 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
 		go d.routedDispatch(ctx, outbound, destination)
 	} else {
 		go func() {
-			done := make(chan domainOrError)
-			go snifer(ctx, sniferList, outbound, done)
-			de := <-done
-			if de.err != nil {
-				log.Trace(newError("failed to snif").Base(de.err))
+			domain, err := snifer(ctx, sniferList, outbound)
+			if err != nil {
+				log.Trace(newError("failed to snif").Base(err))
 				return
 			}
-			log.Trace(newError("sniffed domain: ", de.domain))
-			destination.Address = net.ParseAddress(de.domain)
+			log.Trace(newError("sniffed domain: ", domain))
+			destination.Address = net.ParseAddress(domain)
 			ctx = proxy.ContextWithTarget(ctx, destination)
 			d.routedDispatch(ctx, outbound, destination)
 		}()
@@ -86,31 +83,24 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
 	return outbound, nil
 }
 
-func snifer(ctx context.Context, sniferList []proxyman.KnownProtocols, outbound ray.OutboundRay, done chan<- domainOrError) {
-	payload := make([]byte, 2048)
+func snifer(ctx context.Context, sniferList []proxyman.KnownProtocols, outbound ray.OutboundRay) (string, error) {
+	payload := buf.New()
+	defer payload.Release()
+
 	totalAttempt := 0
 	for {
 		select {
 		case <-ctx.Done():
-			done <- domainOrError{
-				domain: "",
-				err:    ctx.Err(),
-			}
-			return
+			return "", ctx.Err()
 		case <-time.After(time.Millisecond * 100):
 			totalAttempt++
 			if totalAttempt > 5 {
-				done <- domainOrError{
-					domain: "",
-					err:    errSniffingTimeout,
-				}
-				return
+				return "", errSniffingTimeout
 			}
-			mb := outbound.OutboundInput().Peek()
-			if mb.IsEmpty() {
+			outbound.OutboundInput().Peek(payload)
+			if payload.IsEmpty() {
 				continue
 			}
-			nBytes := mb.Copy(payload)
 			for _, protocol := range sniferList {
 				var f func([]byte) (string, error)
 				switch protocol {
@@ -122,21 +112,13 @@ func snifer(ctx context.Context, sniferList []proxyman.KnownProtocols, outbound
 					panic("Unsupported protocol")
 				}
 
-				domain, err := f(payload[:nBytes])
+				domain, err := f(payload.Bytes())
 				if err != ErrMoreData {
-					done <- domainOrError{
-						domain: domain,
-						err:    err,
-					}
-					return
+					return domain, err
 				}
 			}
-			if nBytes == 2048 {
-				done <- domainOrError{
-					domain: "",
-					err:    ErrInvalidData,
-				}
-				return
+			if payload.IsFull() {
+				return "", ErrInvalidData
 			}
 		}
 	}

+ 5 - 4
app/dispatcher/impl/sniffer.go

@@ -44,7 +44,8 @@ func SniffHTTP(b []byte) (string, error) {
 		key := strings.ToLower(string(parts[0]))
 		value := strings.ToLower(string(bytes.Trim(parts[1], " ")))
 		if key == "host" {
-			return value, nil
+			domain := strings.Split(value, ":")
+			return domain[0], nil
 		}
 	}
 	return "", ErrMoreData
@@ -60,11 +61,11 @@ func ReadClientHello(data []byte) (string, error) {
 	if len(data) < 42 {
 		return "", ErrMoreData
 	}
-	sessionIdLen := int(data[38])
-	if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
+	sessionIDLen := int(data[38])
+	if sessionIDLen > 32 || len(data) < 39+sessionIDLen {
 		return "", ErrInvalidData
 	}
-	data = data[39+sessionIdLen:]
+	data = data[39+sessionIDLen:]
 	if len(data) < 2 {
 		return "", ErrMoreData
 	}

+ 4 - 2
transport/ray/direct.go

@@ -75,11 +75,13 @@ func (s *Stream) getData() (buf.MultiBuffer, error) {
 	return nil, nil
 }
 
-func (s *Stream) Peek() buf.MultiBuffer {
+func (s *Stream) Peek(b *buf.Buffer) {
 	s.access.RLock()
 	defer s.access.RUnlock()
 
-	return s.data
+	b.Reset(func(data []byte) (int, error) {
+		return s.data.Copy(data), nil
+	})
 }
 
 func (s *Stream) Read() (buf.MultiBuffer, error) {

+ 1 - 1
transport/ray/ray.go

@@ -42,7 +42,7 @@ type InputStream interface {
 	buf.Reader
 	buf.TimeoutReader
 	RayStream
-	Peek() buf.MultiBuffer
+	Peek(*buf.Buffer)
 }
 
 type OutputStream interface {