Browse Source

refactor socks handshake

Darien Raymond 7 years ago
parent
commit
f2f67132a7
1 changed files with 141 additions and 107 deletions
  1. 141 107
      proxy/socks/protocol.go

+ 141 - 107
proxy/socks/protocol.go

@@ -1,13 +1,13 @@
 package socks
 
 import (
+	"encoding/binary"
 	"io"
 
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
-	"v2ray.com/core/common/serial"
 )
 
 const (
@@ -43,144 +43,177 @@ type ServerSession struct {
 	port   net.Port
 }
 
-func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol.RequestHeader, error) {
-	buffer := buf.New()
-	defer buffer.Release()
+func (s *ServerSession) handshake4(cmd byte, reader io.Reader, writer io.Writer) (*protocol.RequestHeader, error) {
+	if s.config.AuthType == AuthType_PASSWORD {
+		writeSocks4Response(writer, socks4RequestRejected, net.AnyIP, net.Port(0)) // nolint: errcheck
+		return nil, newError("socks 4 is not allowed when auth is required.")
+	}
 
-	request := new(protocol.RequestHeader)
+	var port net.Port
+	var address net.Address
 
-	if _, err := buffer.ReadFullFrom(reader, 2); err != nil {
-		return nil, newError("insufficient header").Base(err)
+	{
+		buffer := buf.StackNew()
+		if _, err := buffer.ReadFullFrom(reader, 6); err != nil {
+			buffer.Release()
+			return nil, newError("insufficient header").Base(err)
+		}
+		port = net.PortFromBytes(buffer.BytesRange(0, 2))
+		address = net.IPAddress(buffer.BytesRange(2, 6))
+		buffer.Release()
 	}
 
-	version := buffer.Byte(0)
-	if version == socks4Version {
-		if s.config.AuthType == AuthType_PASSWORD {
-			writeSocks4Response(writer, socks4RequestRejected, net.AnyIP, net.Port(0)) // nolint: errcheck
-			return nil, newError("socks 4 is not allowed when auth is required.")
+	if _, err := ReadUntilNull(reader); /* user id */ err != nil {
+		return nil, err
+	}
+	if address.IP()[0] == 0x00 {
+		domain, err := ReadUntilNull(reader)
+		if err != nil {
+			return nil, newError("failed to read domain for socks 4a").Base(err)
 		}
+		address = net.DomainAddress(domain)
+	}
 
-		if _, err := buffer.ReadFullFrom(reader, 6); err != nil {
-			return nil, newError("insufficient header").Base(err)
+	switch cmd {
+	case cmdTCPConnect:
+		request := &protocol.RequestHeader{
+			Command: protocol.RequestCommandTCP,
+			Address: address,
+			Port:    port,
+			Version: socks4Version,
 		}
-		port := net.PortFromBytes(buffer.BytesRange(2, 4))
-		address := net.IPAddress(buffer.BytesRange(4, 8))
-		if _, err := ReadUntilNull(reader); /* user id */ err != nil {
+		if err := writeSocks4Response(writer, socks4RequestGranted, net.AnyIP, net.Port(0)); err != nil {
 			return nil, err
 		}
-		if address.IP()[0] == 0x00 {
-			domain, err := ReadUntilNull(reader)
-			if err != nil {
-				return nil, newError("failed to read domain for socks 4a").Base(err)
-			}
-			address = net.DomainAddress(domain)
-		}
+		return request, nil
+	default:
+		writeSocks4Response(writer, socks4RequestRejected, net.AnyIP, net.Port(0)) // nolint: errcheck
+		return nil, newError("unsupported command: ", cmd)
+	}
+}
 
-		switch buffer.Byte(1) {
-		case cmdTCPConnect:
-			request.Command = protocol.RequestCommandTCP
-			request.Address = address
-			request.Port = port
-			request.Version = socks4Version
-			if err := writeSocks4Response(writer, socks4RequestGranted, net.AnyIP, net.Port(0)); err != nil {
-				return nil, err
-			}
-			return request, nil
-		default:
-			writeSocks4Response(writer, socks4RequestRejected, net.AnyIP, net.Port(0)) // nolint: errcheck
-			return nil, newError("unsupported command: ", buffer.Byte(1))
-		}
+func (s *ServerSession) auth5(nMethod byte, reader io.Reader, writer io.Writer) error {
+	buffer := buf.StackNew()
+	defer buffer.Release()
+
+	if _, err := buffer.ReadFullFrom(reader, int32(nMethod)); err != nil {
+		return newError("failed to read auth methods").Base(err)
 	}
 
-	if version == socks5Version {
-		nMethod := int32(buffer.Byte(1))
-		if _, err := buffer.ReadFullFrom(reader, nMethod); err != nil {
-			return nil, newError("failed to read auth methods").Base(err)
-		}
+	var expectedAuth byte = authNotRequired
+	if s.config.AuthType == AuthType_PASSWORD {
+		expectedAuth = authPassword
+	}
 
-		var expectedAuth byte = authNotRequired
-		if s.config.AuthType == AuthType_PASSWORD {
-			expectedAuth = authPassword
-		}
+	if !hasAuthMethod(expectedAuth, buffer.BytesRange(0, int32(nMethod))) {
+		writeSocks5AuthenticationResponse(writer, socks5Version, authNoMatchingMethod) // nolint: errcheck
+		return newError("no matching auth method")
+	}
 
-		if !hasAuthMethod(expectedAuth, buffer.BytesRange(2, 2+nMethod)) {
-			writeSocks5AuthenticationResponse(writer, socks5Version, authNoMatchingMethod) // nolint: errcheck
-			return nil, newError("no matching auth method")
+	if err := writeSocks5AuthenticationResponse(writer, socks5Version, expectedAuth); err != nil {
+		return newError("failed to write auth response").Base(err)
+	}
+
+	if expectedAuth == authPassword {
+		username, password, err := ReadUsernamePassword(reader)
+		if err != nil {
+			return newError("failed to read username and password for authentication").Base(err)
 		}
 
-		if err := writeSocks5AuthenticationResponse(writer, socks5Version, expectedAuth); err != nil {
-			return nil, newError("failed to write auth response").Base(err)
+		if !s.config.HasAccount(username, password) {
+			writeSocks5AuthenticationResponse(writer, 0x01, 0xFF) // nolint: errcheck
+			return newError("invalid username or password")
 		}
 
-		if expectedAuth == authPassword {
-			username, password, err := ReadUsernamePassword(reader)
-			if err != nil {
-				return nil, newError("failed to read username and password for authentication").Base(err)
-			}
+		if err := writeSocks5AuthenticationResponse(writer, 0x01, 0x00); err != nil {
+			return newError("failed to write auth response").Base(err)
+		}
+	}
 
-			if !s.config.HasAccount(username, password) {
-				writeSocks5AuthenticationResponse(writer, 0x01, 0xFF) // nolint: errcheck
-				return nil, newError("invalid username or password")
-			}
+	return nil
+}
 
-			if err := writeSocks5AuthenticationResponse(writer, 0x01, 0x00); err != nil {
-				return nil, newError("failed to write auth response").Base(err)
-			}
-		}
+func (s *ServerSession) handshake5(nMethod byte, reader io.Reader, writer io.Writer) (*protocol.RequestHeader, error) {
+	if err := s.auth5(nMethod, reader, writer); err != nil {
+		return nil, err
+	}
 
-		buffer.Clear()
+	var cmd byte
+	{
+		buffer := buf.StackNew()
 		if _, err := buffer.ReadFullFrom(reader, 3); err != nil {
+			buffer.Release()
 			return nil, newError("failed to read request").Base(err)
 		}
+		cmd = buffer.Byte(1)
+		buffer.Release()
+	}
 
-		cmd := buffer.Byte(1)
-		switch cmd {
-		case cmdTCPConnect, cmdTorResolve, cmdTorResolvePTR:
-			// We don't have a solution for Tor case now. Simply treat it as connect command.
-			request.Command = protocol.RequestCommandTCP
-		case cmdUDPPort:
-			if !s.config.UdpEnabled {
-				writeSocks5Response(writer, statusCmdNotSupport, net.AnyIP, net.Port(0)) // nolint: errcheck
-				return nil, newError("UDP is not enabled.")
-			}
-			request.Command = protocol.RequestCommandUDP
-		case cmdTCPBind:
-			writeSocks5Response(writer, statusCmdNotSupport, net.AnyIP, net.Port(0)) // nolint: errcheck
-			return nil, newError("TCP bind is not supported.")
-		default:
+	request := new(protocol.RequestHeader)
+	switch cmd {
+	case cmdTCPConnect, cmdTorResolve, cmdTorResolvePTR:
+		// We don't have a solution for Tor case now. Simply treat it as connect command.
+		request.Command = protocol.RequestCommandTCP
+	case cmdUDPPort:
+		if !s.config.UdpEnabled {
 			writeSocks5Response(writer, statusCmdNotSupport, net.AnyIP, net.Port(0)) // nolint: errcheck
-			return nil, newError("unknown command ", cmd)
+			return nil, newError("UDP is not enabled.")
 		}
+		request.Command = protocol.RequestCommandUDP
+	case cmdTCPBind:
+		writeSocks5Response(writer, statusCmdNotSupport, net.AnyIP, net.Port(0)) // nolint: errcheck
+		return nil, newError("TCP bind is not supported.")
+	default:
+		writeSocks5Response(writer, statusCmdNotSupport, net.AnyIP, net.Port(0)) // nolint: errcheck
+		return nil, newError("unknown command ", cmd)
+	}
 
-		buffer.Clear()
+	request.Version = socks5Version
 
-		request.Version = socks5Version
+	addr, port, err := addrParser.ReadAddressPort(nil, reader)
+	if err != nil {
+		return nil, newError("failed to read address").Base(err)
+	}
+	request.Address = addr
+	request.Port = port
 
-		addr, port, err := addrParser.ReadAddressPort(buffer, reader)
-		if err != nil {
-			return nil, newError("failed to read address").Base(err)
-		}
-		request.Address = addr
-		request.Port = port
-
-		responseAddress := net.AnyIP
-		responsePort := net.Port(1717)
-		if request.Command == protocol.RequestCommandUDP {
-			addr := s.config.Address.AsAddress()
-			if addr == nil {
-				addr = net.LocalHostIP
-			}
-			responseAddress = addr
-			responsePort = s.port
-		}
-		if err := writeSocks5Response(writer, statusSuccess, responseAddress, responsePort); err != nil {
-			return nil, err
+	responseAddress := net.AnyIP
+	responsePort := net.Port(1717)
+	if request.Command == protocol.RequestCommandUDP {
+		addr := s.config.Address.AsAddress()
+		if addr == nil {
+			addr = net.LocalHostIP
 		}
+		responseAddress = addr
+		responsePort = s.port
+	}
+	if err := writeSocks5Response(writer, statusSuccess, responseAddress, responsePort); err != nil {
+		return nil, err
+	}
 
-		return request, nil
+	return request, nil
+}
+
+// Handshake performs a Socks4/4a/5 handshake.
+func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol.RequestHeader, error) {
+	buffer := buf.StackNew()
+	if _, err := buffer.ReadFullFrom(reader, 2); err != nil {
+		buffer.Release()
+		return nil, newError("insufficient header").Base(err)
 	}
 
-	return nil, newError("unknown Socks version: ", version)
+	version := buffer.Byte(0)
+	cmd := buffer.Byte(1)
+	buffer.Release()
+
+	switch version {
+	case socks4Version:
+		return s.handshake4(cmd, reader, writer)
+	case socks5Version:
+		return s.handshake5(cmd, reader, writer)
+	default:
+		return nil, newError("unknown Socks version: ", version)
+	}
 }
 
 // ReadUsernamePassword reads Socks 5 username/password message from the given reader.
@@ -264,12 +297,13 @@ func writeSocks5Response(writer io.Writer, errCode byte, address net.Address, po
 }
 
 func writeSocks4Response(writer io.Writer, errCode byte, address net.Address, port net.Port) error {
-	buffer := buf.New()
+	buffer := buf.StackNew()
 	defer buffer.Release()
 
 	common.Must(buffer.WriteByte(0x00))
 	common.Must(buffer.WriteByte(errCode))
-	common.Must2(serial.WriteUint16(buffer, port.Value()))
+	portBytes := buffer.Extend(2)
+	binary.BigEndian.PutUint16(portBytes, port.Value())
 	common.Must2(buffer.Write(address.IP()))
 	return buf.WriteAllBytes(writer, buffer.Bytes())
 }