Browse Source

Refactor socks request handling

V2Ray 10 years ago
parent
commit
ddad439614
2 changed files with 123 additions and 93 deletions
  1. 9 4
      common/net/transport.go
  2. 114 89
      proxy/socks/socks.go

+ 9 - 4
common/net/transport.go

@@ -8,13 +8,18 @@ const (
 	bufferSize = 4 * 1024
 	bufferSize = 4 * 1024
 )
 )
 
 
+func ReadFrom(reader io.Reader) ([]byte, error) {
+	buffer := make([]byte, bufferSize)
+	nBytes, err := reader.Read(buffer)
+	return buffer[:nBytes], err
+}
+
 // ReaderToChan dumps all content from a given reader to a chan by constantly reading it until EOF.
 // ReaderToChan dumps all content from a given reader to a chan by constantly reading it until EOF.
 func ReaderToChan(stream chan<- []byte, reader io.Reader) error {
 func ReaderToChan(stream chan<- []byte, reader io.Reader) error {
 	for {
 	for {
-		buffer := make([]byte, bufferSize)
-		nBytes, err := reader.Read(buffer)
-		if nBytes > 0 {
-			stream <- buffer[:nBytes]
+		data, err := ReadFrom(reader)
+		if len(data) > 0 {
+			stream <- data
 		}
 		}
 		if err != nil {
 		if err != nil {
 			return err
 			return err

+ 114 - 89
proxy/socks/socks.go

@@ -41,6 +41,9 @@ func (server *SocksServer) Listen(port uint16) error {
 	}
 	}
 	server.accepting = true
 	server.accepting = true
 	go server.AcceptConnections(listener)
 	go server.AcceptConnections(listener)
+	if server.config.UDPEnabled {
+		server.ListenUDP(port)
+	}
 	return nil
 	return nil
 }
 }
 
 
@@ -66,120 +69,142 @@ func (server *SocksServer) HandleConnection(connection net.Conn) error {
 		return err
 		return err
 	}
 	}
 
 
-	var dest v2net.Destination
-
-	// TODO refactor this part
-	if errors.HasCode(err, 1000) {
-		result := protocol.Socks4RequestGranted
-		if auth4.Command == protocol.CmdBind {
-			result = protocol.Socks4RequestRejected
-		}
-		socks4Response := protocol.NewSocks4AuthenticationResponse(result, auth4.Port, auth4.IP[:])
-		connection.Write(socks4Response.ToBytes(nil))
-
-		if result == protocol.Socks4RequestRejected {
-			return errors.NewInvalidOperationError("Socks4 command " + strconv.Itoa(int(auth4.Command)))
-		}
-
-		dest = v2net.NewTCPDestination(v2net.IPAddress(auth4.IP[:], auth4.Port))
+	if err != nil && errors.HasCode(err, 1000) {
+		return server.handleSocks4(reader, connection, auth4)
 	} else {
 	} else {
-		expectedAuthMethod := protocol.AuthNotRequired
-		if server.config.IsPassword() {
-			expectedAuthMethod = protocol.AuthUserPass
-		}
+		return server.handleSocks5(reader, connection, auth)
+	}
+}
 
 
-		if !auth.HasAuthMethod(expectedAuthMethod) {
-			authResponse := protocol.NewAuthenticationResponse(protocol.AuthNoMatchingMethod)
-			err = protocol.WriteAuthentication(connection, authResponse)
-			if err != nil {
-				log.Error("Socks failed to write authentication: %v", err)
-				return err
-			}
-			log.Warning("Socks client doesn't support allowed any auth methods.")
-			return errors.NewInvalidOperationError("Unsupported auth methods.")
-		}
+func (server *SocksServer) handleSocks5(reader io.Reader, writer io.Writer, auth protocol.Socks5AuthenticationRequest) error {
+	expectedAuthMethod := protocol.AuthNotRequired
+	if server.config.IsPassword() {
+		expectedAuthMethod = protocol.AuthUserPass
+	}
 
 
-		authResponse := protocol.NewAuthenticationResponse(expectedAuthMethod)
-		err = protocol.WriteAuthentication(connection, authResponse)
+	if !auth.HasAuthMethod(expectedAuthMethod) {
+		authResponse := protocol.NewAuthenticationResponse(protocol.AuthNoMatchingMethod)
+		err := protocol.WriteAuthentication(writer, authResponse)
 		if err != nil {
 		if err != nil {
 			log.Error("Socks failed to write authentication: %v", err)
 			log.Error("Socks failed to write authentication: %v", err)
 			return err
 			return err
 		}
 		}
-		if server.config.IsPassword() {
-			upRequest, err := protocol.ReadUserPassRequest(reader)
-			if err != nil {
-				log.Error("Socks failed to read username and password: %v", err)
-				return err
-			}
-			status := byte(0)
-			if !upRequest.IsValid(server.config.Username, server.config.Password) {
-				status = byte(0xFF)
-			}
-			upResponse := protocol.NewSocks5UserPassResponse(status)
-			err = protocol.WriteUserPassResponse(connection, upResponse)
-			if err != nil {
-				log.Error("Socks failed to write user pass response: %v", err)
-				return err
-			}
-			if status != byte(0) {
-				err = errors.NewAuthenticationError(upRequest.AuthDetail())
-				log.Warning(err.Error())
-				return err
-			}
-		}
+		log.Warning("Socks client doesn't support allowed any auth methods.")
+		return errors.NewInvalidOperationError("Unsupported auth methods.")
+	}
 
 
-		request, err := protocol.ReadRequest(reader)
+	authResponse := protocol.NewAuthenticationResponse(expectedAuthMethod)
+	err := protocol.WriteAuthentication(writer, authResponse)
+	if err != nil {
+		log.Error("Socks failed to write authentication: %v", err)
+		return err
+	}
+	if server.config.IsPassword() {
+		upRequest, err := protocol.ReadUserPassRequest(reader)
 		if err != nil {
 		if err != nil {
-			log.Error("Socks failed to read request: %v", err)
+			log.Error("Socks failed to read username and password: %v", err)
 			return err
 			return err
 		}
 		}
+		status := byte(0)
+		if !upRequest.IsValid(server.config.Username, server.config.Password) {
+			status = byte(0xFF)
+		}
+		upResponse := protocol.NewSocks5UserPassResponse(status)
+		err = protocol.WriteUserPassResponse(writer, upResponse)
+		if err != nil {
+			log.Error("Socks failed to write user pass response: %v", err)
+			return err
+		}
+		if status != byte(0) {
+			err = errors.NewAuthenticationError(upRequest.AuthDetail())
+			log.Warning(err.Error())
+			return err
+		}
+	}
 
 
-		response := protocol.NewSocks5Response()
+	request, err := protocol.ReadRequest(reader)
+	if err != nil {
+		log.Error("Socks failed to read request: %v", err)
+		return err
+	}
 
 
-		if request.Command == protocol.CmdBind || request.Command == protocol.CmdUdpAssociate {
-			response := protocol.NewSocks5Response()
-			response.Error = protocol.ErrorCommandNotSupported
-			err = protocol.WriteResponse(connection, response)
-			if err != nil {
-				log.Error("Socks failed to write response: %v", err)
-				return err
-			}
-			log.Warning("Unsupported socks command %d", request.Command)
-			return errors.NewInvalidOperationError("Socks command " + strconv.Itoa(int(request.Command)))
-		}
+	response := protocol.NewSocks5Response()
 
 
-		response.Error = protocol.ErrorSuccess
-		response.Port = request.Port
-		response.AddrType = request.AddrType
-		switch response.AddrType {
-		case protocol.AddrTypeIPv4:
-			copy(response.IPv4[:], request.IPv4[:])
-		case protocol.AddrTypeIPv6:
-			copy(response.IPv6[:], request.IPv6[:])
-		case protocol.AddrTypeDomain:
-			response.Domain = request.Domain
-		}
-		err = protocol.WriteResponse(connection, response)
+	if request.Command == protocol.CmdBind || (!server.config.UDPEnabled && request.Command == protocol.CmdUdpAssociate) {
+		response := protocol.NewSocks5Response()
+		response.Error = protocol.ErrorCommandNotSupported
+		err = protocol.WriteResponse(writer, response)
 		if err != nil {
 		if err != nil {
 			log.Error("Socks failed to write response: %v", err)
 			log.Error("Socks failed to write response: %v", err)
 			return err
 			return err
 		}
 		}
+		log.Warning("Unsupported socks command %d", request.Command)
+		return errors.NewInvalidOperationError("Socks command " + strconv.Itoa(int(request.Command)))
+	}
+
+	response.Error = protocol.ErrorSuccess
+	response.Port = request.Port
+	response.AddrType = request.AddrType
+	switch response.AddrType {
+	case protocol.AddrTypeIPv4:
+		copy(response.IPv4[:], request.IPv4[:])
+	case protocol.AddrTypeIPv6:
+		copy(response.IPv6[:], request.IPv6[:])
+	case protocol.AddrTypeDomain:
+		response.Domain = request.Domain
+	}
+	err = protocol.WriteResponse(writer, response)
+	if err != nil {
+		log.Error("Socks failed to write response: %v", err)
+		return err
+	}
+
+	dest := request.Destination()
+	data, err := v2net.ReadFrom(reader)
+	if err != nil {
+		return err
+	}
 
 
-		dest = request.Destination()
+	packet := v2net.NewPacket(dest, data, true)
+	server.transport(reader, writer, packet)
+	return nil
+}
+
+func (server *SocksServer) handleSocks4(reader io.Reader, writer io.Writer, auth protocol.Socks4AuthenticationRequest) error {
+	result := protocol.Socks4RequestGranted
+	if auth.Command == protocol.CmdBind {
+		result = protocol.Socks4RequestRejected
 	}
 	}
+	socks4Response := protocol.NewSocks4AuthenticationResponse(result, auth.Port, auth.IP[:])
+	writer.Write(socks4Response.ToBytes(nil))
 
 
-	ray := server.vPoint.DispatchToOutbound(v2net.NewPacket(dest, nil, true))
+	if result == protocol.Socks4RequestRejected {
+		return errors.NewInvalidOperationError("Socks4 command " + strconv.Itoa(int(auth.Command)))
+	}
+
+	dest := v2net.NewTCPDestination(v2net.IPAddress(auth.IP[:], auth.Port))
+	data, err := v2net.ReadFrom(reader)
+	if err != nil {
+		return err
+	}
+
+	packet := v2net.NewPacket(dest, data, true)
+	server.transport(reader, writer, packet)
+	return nil
+}
+
+func (server *SocksServer) transport(reader io.Reader, writer io.Writer, firstPacket v2net.Packet) {
+	ray := server.vPoint.DispatchToOutbound(firstPacket)
 	input := ray.InboundInput()
 	input := ray.InboundInput()
 	output := ray.InboundOutput()
 	output := ray.InboundOutput()
-	var readFinish, writeFinish sync.Mutex
-	readFinish.Lock()
-	writeFinish.Lock()
 
 
-	go dumpInput(reader, input, &readFinish)
-	go dumpOutput(connection, output, &writeFinish)
-	writeFinish.Lock()
+	var inputFinish, outputFinish sync.Mutex
+	inputFinish.Lock()
+	outputFinish.Lock()
 
 
-	return nil
+	go dumpInput(reader, input, &inputFinish)
+	go dumpOutput(writer, output, &outputFinish)
+	outputFinish.Lock()
 }
 }
 
 
 func dumpInput(reader io.Reader, input chan<- []byte, finish *sync.Mutex) {
 func dumpInput(reader io.Reader, input chan<- []byte, finish *sync.Mutex) {