Parcourir la source

detect actual address type for domain address type

Darien Raymond il y a 8 ans
Parent
commit
6363c33790
3 fichiers modifiés avec 32 ajouts et 54 suppressions
  1. 25 44
      common/errors/errors.go
  2. 3 3
      proxy/socks/protocol.go
  3. 4 7
      proxy/socks/server.go

+ 25 - 44
common/errors/errors.go

@@ -24,33 +24,49 @@ type Error struct {
 }
 
 // Error implements error.Error().
-func (v *Error) Error() string {
+func (v Error) Error() string {
 	return v.message
 }
 
 // Inner implements hasInnerError.Inner()
-func (v *Error) Inner() error {
+func (v Error) Inner() error {
 	if v.inner == nil {
 		return nil
 	}
 	return v.inner
 }
 
-func (v *Error) ActionRequired() bool {
+func (v Error) ActionRequired() bool {
 	return v.actionRequired
 }
 
+func (v Error) RequireUserAction() Error {
+	v.actionRequired = true
+	return v
+}
+
+func (v Error) Message(msg ...interface{}) Error {
+	return Error{
+		inner:   v,
+		message: serial.Concat(msg...),
+	}
+}
+
+func (v Error) Format(format string, values ...interface{}) Error {
+	return v.Message(fmt.Sprintf(format, values...))
+}
+
 // New returns a new error object with message formed from given arguments.
-func New(msg ...interface{}) error {
-	return &Error{
+func New(msg ...interface{}) Error {
+	return Error{
 		message: serial.Concat(msg...),
 	}
 }
 
-// Base returns an ErrorBuilder based on the given error.
-func Base(err error) ErrorBuilder {
-	return ErrorBuilder{
-		error: err,
+// Base returns an Error based on the given error.
+func Base(err error) Error {
+	return Error{
+		inner: err,
 	}
 }
 
@@ -86,38 +102,3 @@ func IsActionRequired(err error) bool {
 	}
 	return false
 }
-
-type ErrorBuilder struct {
-	error
-	actionRequired bool
-}
-
-func (v ErrorBuilder) RequireUserAction() ErrorBuilder {
-	v.actionRequired = true
-	return v
-}
-
-// Message returns an error object with given message and base error.
-func (v ErrorBuilder) Message(msg ...interface{}) error {
-	if v.error == nil {
-		return nil
-	}
-
-	return &Error{
-		message:        serial.Concat(msg...) + " > " + v.error.Error(),
-		inner:          v.error,
-		actionRequired: v.actionRequired,
-	}
-}
-
-// Format returns an errors object with given message format and base error.
-func (v ErrorBuilder) Format(format string, values ...interface{}) error {
-	if v.error == nil {
-		return nil
-	}
-	return &Error{
-		message:        fmt.Sprintf(format, values...) + " > " + v.error.Error(),
-		inner:          v.error,
-		actionRequired: v.actionRequired,
-	}
-}

+ 3 - 3
proxy/socks/protocol.go

@@ -159,7 +159,7 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol
 			if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, domainLength)); err != nil {
 				return nil, err
 			}
-			request.Address = v2net.DomainAddress(string(buffer.BytesFrom(-domainLength)))
+			request.Address = v2net.ParseAddress(string(buffer.BytesFrom(-domainLength)))
 		default:
 			return nil, errors.New("Socks|Server: Unknown address type: ", addrType)
 		}
@@ -400,10 +400,10 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
 	}
 
 	if b.Byte(0) != socks5Version {
-		return nil, errors.New("Socks|Client: Unexpected server version: ", b.Byte(0))
+		return nil, errors.New("Socks|Client: Unexpected server version: ", b.Byte(0)).RequireUserAction()
 	}
 	if b.Byte(1) != authByte {
-		return nil, errors.New("Socks|Client: auth method not supported.")
+		return nil, errors.New("Socks|Client: auth method not supported.").RequireUserAction()
 	}
 
 	if authByte == authPassword {

+ 4 - 7
proxy/socks/server.go

@@ -29,7 +29,7 @@ type Server struct {
 func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
 	space := app.SpaceFromContext(ctx)
 	if space == nil {
-		return nil, errors.New("Socks|Server: No space in context.")
+		return nil, errors.New("Socks|Server: No space in context.").RequireUserAction()
 	}
 	s := &Server{
 		config: config,
@@ -130,8 +130,7 @@ func (v *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
 
 		v2reader := buf.NewReader(reader)
 		if err := buf.PipeUntilEOF(timer, v2reader, input); err != nil {
-			log.Info("Socks|Server: Failed to transport all TCP request: ", err)
-			return err
+			return errors.Base(err).Message("Socks|Server: Failed to transport all TCP request.")
 		}
 		return nil
 	})
@@ -139,17 +138,15 @@ func (v *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
 	responseDone := signal.ExecuteAsync(func() error {
 		v2writer := buf.NewWriter(writer)
 		if err := buf.PipeUntilEOF(timer, output, v2writer); err != nil {
-			log.Info("Socks|Server: Failed to transport all TCP response: ", err)
-			return err
+			return errors.Base(err).Message("Socks|Server: Failed to transport all TCP response.")
 		}
 		return nil
 	})
 
 	if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
-		log.Info("Socks|Server: Connection ends with ", err)
 		input.CloseError()
 		output.CloseError()
-		return err
+		return errors.Base(err).Message("Socks|Server: Connection ends.")
 	}
 
 	runtime.KeepAlive(timer)