Selaa lähdekoodia

more protocol implementation

v2ray 9 vuotta sitten
vanhempi
commit
cba26dabe8

+ 2 - 2
common/protocol/encoding.go

@@ -10,7 +10,7 @@ type RequestEncoder interface {
 }
 
 type RequestDecoder interface {
-	DecodeRequestHeader(io.Reader) *RequestHeader
+	DecodeRequestHeader(io.Reader) (*RequestHeader, error)
 	DecodeRequestBody(io.Reader) io.Reader
 }
 
@@ -20,6 +20,6 @@ type ResponseEncoder interface {
 }
 
 type ResponseDecoder interface {
-	DecodeResponseHeader(io.Reader) *ResponseHeader
+	DecodeResponseHeader(io.Reader) (*ResponseHeader, error)
 	DecodeResponseBody(io.Reader) io.Reader
 }

+ 10 - 0
common/protocol/errors.go

@@ -0,0 +1,10 @@
+package protocol
+
+import (
+	"errors"
+)
+
+var (
+	ErrorInvalidUser    = errors.New("Invalid user.")
+	ErrorInvalidVersion = errors.New("Invalid version.")
+)

+ 11 - 0
common/protocol/headers.go

@@ -2,6 +2,8 @@ package protocol
 
 import (
 	v2net "github.com/v2ray/v2ray-core/common/net"
+	"github.com/v2ray/v2ray-core/common/serial"
+	"github.com/v2ray/v2ray-core/common/uuid"
 )
 
 type RequestCommand byte
@@ -31,3 +33,12 @@ type ResponseCommand interface{}
 type ResponseHeader struct {
 	Command ResponseCommand
 }
+
+type CommandSwitchAccount struct {
+	Host     v2net.Address
+	Port     v2net.Port
+	ID       *uuid.UUID
+	AlterIds serial.Uint16Literal
+	Level    UserLevel
+	ValidMin byte
+}

+ 46 - 0
common/protocol/raw/client.go

@@ -8,7 +8,9 @@ import (
 
 	"github.com/v2ray/v2ray-core/common/alloc"
 	"github.com/v2ray/v2ray-core/common/crypto"
+	"github.com/v2ray/v2ray-core/common/log"
 	"github.com/v2ray/v2ray-core/common/protocol"
+	"github.com/v2ray/v2ray-core/transport"
 )
 
 func hashTimestamp(t protocol.Timestamp) []byte {
@@ -27,6 +29,7 @@ type ClientSession struct {
 	responseHeader  byte
 	responseBodyKey []byte
 	responseBodyIV  []byte
+	responseReader  io.Reader
 	idHash          protocol.IDHash
 }
 
@@ -38,6 +41,10 @@ func NewClientSession(idHash protocol.IDHash) *ClientSession {
 	session.requestBodyKey = randomBytes[:16]
 	session.requestBodyIV = randomBytes[16:32]
 	session.responseHeader = randomBytes[32]
+	responseBodyKey := md5.Sum(session.requestBodyKey)
+	responseBodyIV := md5.Sum(session.requestBodyIV)
+	session.responseBodyKey = responseBodyKey[:]
+	session.responseBodyIV = responseBodyIV[:]
 	session.idHash = idHash
 
 	return session
@@ -97,3 +104,42 @@ func (this *ClientSession) EncodeRequestBody(writer io.Writer) io.Writer {
 	return crypto.NewCryptionWriter(aesStream, writer)
 }
 
+func (this *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.ResponseHeader, error) {
+	aesStream := crypto.NewAesDecryptionStream(this.responseBodyKey, this.responseBodyIV)
+	this.responseReader = crypto.NewCryptionReader(aesStream, reader)
+
+	buffer := alloc.NewSmallBuffer()
+	defer buffer.Release()
+
+	_, err := io.ReadFull(this.responseReader, buffer.Value[:4])
+	if err != nil {
+		log.Error("Raw: Failed to read response header: ", err)
+		return nil, err
+	}
+
+	if buffer.Value[0] != this.responseHeader {
+		log.Warning("Raw: Unexpected response header. Expecting %d, but actually %d", this.responseHeader, buffer.Value[0])
+		return nil, transport.ErrorCorruptedPacket
+	}
+
+	header := new(protocol.ResponseHeader)
+
+	if buffer.Value[2] != 0 {
+		cmdId := buffer.Value[2]
+		dataLen := int(buffer.Value[3])
+		_, err := io.ReadFull(this.responseReader, buffer.Value[:dataLen])
+		if err != nil {
+			log.Error("Raw: Failed to read response command: ", err)
+			return nil, err
+		}
+		data := buffer.Value[:dataLen]
+		command, err := UnmarshalCommand(cmdId, data)
+		header.Command = command
+	}
+
+	return header, nil
+}
+
+func (this *ClientSession) DecodeResponseBody(reader io.Reader) io.Reader {
+	return this.responseReader
+}

+ 115 - 0
common/protocol/raw/commands.go

@@ -0,0 +1,115 @@
+package raw
+
+import (
+	"errors"
+	"io"
+
+	v2net "github.com/v2ray/v2ray-core/common/net"
+	"github.com/v2ray/v2ray-core/common/protocol"
+	"github.com/v2ray/v2ray-core/common/serial"
+	"github.com/v2ray/v2ray-core/common/uuid"
+	"github.com/v2ray/v2ray-core/transport"
+)
+
+var (
+	ErrorCommandTypeMismatch = errors.New("Command type mismatch.")
+	ErrorUnknownCommand      = errors.New("Unknown command.")
+)
+
+func MarshalCommand(command interface{}, writer io.Writer) error {
+	var factory CommandFactory
+	switch command.(type) {
+	case *protocol.CommandSwitchAccount:
+		factory = new(CommandSwitchAccountFactory)
+	default:
+		return ErrorUnknownCommand
+	}
+	return factory.Marshal(command, writer)
+}
+
+func UnmarshalCommand(cmdId byte, data []byte) (protocol.ResponseCommand, error) {
+	var factory CommandFactory
+	switch cmdId {
+	case 1:
+		factory = new(CommandSwitchAccountFactory)
+	default:
+		return nil, ErrorUnknownCommand
+	}
+	return factory.Unmarshal(data)
+}
+
+type CommandFactory interface {
+	Marshal(command interface{}, writer io.Writer) error
+	Unmarshal(data []byte) (interface{}, error)
+}
+
+type CommandSwitchAccountFactory struct {
+}
+
+func (this *CommandSwitchAccountFactory) Marshal(command interface{}, writer io.Writer) error {
+	cmd, ok := command.(*protocol.CommandSwitchAccount)
+	if !ok {
+		return ErrorCommandTypeMismatch
+	}
+
+	hostStr := ""
+	if cmd.Host != nil {
+		hostStr = cmd.Host.String()
+	}
+	writer.Write([]byte{byte(len(hostStr))})
+
+	if len(hostStr) > 0 {
+		writer.Write([]byte(hostStr))
+	}
+
+	writer.Write(cmd.Port.Bytes())
+
+	idBytes := cmd.ID.Bytes()
+	writer.Write(idBytes)
+
+	writer.Write(cmd.AlterIds.Bytes())
+	writer.Write([]byte{byte(cmd.Level)})
+
+	writer.Write([]byte{cmd.ValidMin})
+	return nil
+}
+
+func (this *CommandSwitchAccountFactory) Unmarshal(data []byte) (interface{}, error) {
+	cmd := new(protocol.CommandSwitchAccount)
+	if len(data) == 0 {
+		return nil, transport.ErrorCorruptedPacket
+	}
+	lenHost := int(data[0])
+	if len(data) < lenHost+1 {
+		return nil, transport.ErrorCorruptedPacket
+	}
+	if lenHost > 0 {
+		cmd.Host = v2net.ParseAddress(string(data[1 : 1+lenHost]))
+	}
+	portStart := 1 + lenHost
+	if len(data) < portStart+2 {
+		return nil, transport.ErrorCorruptedPacket
+	}
+	cmd.Port = v2net.PortFromBytes(data[portStart : portStart+2])
+	idStart := portStart + 2
+	if len(data) < idStart+16 {
+		return nil, transport.ErrorCorruptedPacket
+	}
+	cmd.ID, _ = uuid.ParseBytes(data[idStart : idStart+16])
+	alterIdStart := idStart + 16
+	if len(data) < alterIdStart+2 {
+		return nil, transport.ErrorCorruptedPacket
+	}
+	cmd.AlterIds = serial.BytesLiteral(data[alterIdStart : alterIdStart+2]).Uint16()
+	levelStart := alterIdStart + 2
+	if len(data) < levelStart+1 {
+		return nil, transport.ErrorCorruptedPacket
+	}
+	cmd.Level = protocol.UserLevel(data[levelStart])
+	timeStart := levelStart + 1
+	if len(data) < timeStart {
+		return nil, transport.ErrorCorruptedPacket
+	}
+	cmd.ValidMin = data[timeStart]
+	return cmd, nil
+}

+ 42 - 0
common/protocol/raw/commands_test.go

@@ -0,0 +1,42 @@
+package raw_test
+
+import (
+	"bytes"
+	"testing"
+
+	netassert "github.com/v2ray/v2ray-core/common/net/testing/assert"
+	"github.com/v2ray/v2ray-core/common/protocol"
+	. "github.com/v2ray/v2ray-core/common/protocol/raw"
+	"github.com/v2ray/v2ray-core/common/uuid"
+	v2testing "github.com/v2ray/v2ray-core/testing"
+	"github.com/v2ray/v2ray-core/testing/assert"
+)
+
+func TestSwitchAccount(t *testing.T) {
+	v2testing.Current(t)
+
+	sa := &protocol.CommandSwitchAccount{
+		Port:     1234,
+		ID:       uuid.New(),
+		AlterIds: 1024,
+		Level:    128,
+		ValidMin: 16,
+	}
+
+	buffer := bytes.NewBuffer(make([]byte, 0, 1024))
+	err := MarshalCommand(sa, buffer)
+	assert.Error(err).IsNil()
+
+	cmd, err := UnmarshalCommand(1, buffer.Bytes())
+	assert.Error(err).IsNil()
+
+	sa2, ok := cmd.(*protocol.CommandSwitchAccount)
+	assert.Bool(ok).IsTrue()
+	assert.Pointer(sa.Host).IsNil()
+	assert.Pointer(sa2.Host).IsNil()
+	netassert.Port(sa.Port).Equals(sa2.Port)
+	assert.String(sa.ID).Equals(sa2.ID.String())
+	assert.Uint16(sa.AlterIds.Value()).Equals(sa2.AlterIds.Value())
+	assert.Byte(byte(sa.Level)).Equals(byte(sa2.Level))
+	assert.Byte(sa.ValidMin).Equals(sa2.ValidMin)
+}

+ 143 - 0
common/protocol/raw/server.go

@@ -0,0 +1,143 @@
+package raw
+
+import (
+	"crypto/md5"
+	"hash/fnv"
+	"io"
+
+	"github.com/v2ray/v2ray-core/common/alloc"
+	"github.com/v2ray/v2ray-core/common/crypto"
+	"github.com/v2ray/v2ray-core/common/log"
+	v2net "github.com/v2ray/v2ray-core/common/net"
+	"github.com/v2ray/v2ray-core/common/protocol"
+	"github.com/v2ray/v2ray-core/common/serial"
+	"github.com/v2ray/v2ray-core/transport"
+)
+
+type ServerSession struct {
+	userValidator   protocol.UserValidator
+	requestBodyKey  []byte
+	requestBodyIV   []byte
+	responseBodyKey []byte
+	responseBodyIV  []byte
+	responseHeader  byte
+	responseWriter  io.Writer
+}
+
+func (this *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.RequestHeader, error) {
+	buffer := alloc.NewSmallBuffer()
+	defer buffer.Release()
+
+	_, err := io.ReadFull(reader, buffer.Value[:protocol.IDBytesLen])
+	if err != nil {
+		log.Error("Raw: Failed to read request header: ", err)
+		return nil, err
+	}
+
+	user, timestamp, valid := this.userValidator.Get(buffer.Value[:protocol.IDBytesLen])
+	if !valid {
+		return nil, protocol.ErrorInvalidUser
+	}
+
+	timestampHash := md5.New()
+	timestampHash.Write(hashTimestamp(timestamp))
+	iv := timestampHash.Sum(nil)
+	aesStream := crypto.NewAesDecryptionStream(user.ID.CmdKey(), iv)
+	decryptor := crypto.NewCryptionReader(aesStream, reader)
+
+	nBytes, err := io.ReadFull(decryptor, buffer.Value[:41])
+	if err != nil {
+		log.Debug("Raw: Failed to read request header (", nBytes, " bytes): ", err)
+		return nil, err
+	}
+	bufferLen := nBytes
+
+	request := &protocol.RequestHeader{
+		User:    user,
+		Version: buffer.Value[0],
+	}
+
+	if request.Version != Version {
+		log.Warning("Raw: Invalid protocol version ", request.Version)
+		return nil, protocol.ErrorInvalidVersion
+	}
+
+	this.requestBodyIV = append([]byte(nil), buffer.Value[1:17]...)   // 16 bytes
+	this.requestBodyKey = append([]byte(nil), buffer.Value[17:33]...) // 16 bytes
+	this.responseHeader = buffer.Value[33]                            // 1 byte
+	request.Option = protocol.RequestOption(buffer.Value[34])         // 1 byte + 2 bytes reserved
+	request.Command = protocol.RequestCommand(buffer.Value[37])
+
+	request.Port = v2net.PortFromBytes(buffer.Value[38:40])
+
+	switch buffer.Value[40] {
+	case AddrTypeIPv4:
+		nBytes, err = io.ReadFull(decryptor, buffer.Value[41:45]) // 4 bytes
+		bufferLen += 4
+		if err != nil {
+			log.Debug("VMess: Failed to read target IPv4 (", nBytes, " bytes): ", err)
+			return nil, err
+		}
+		request.Address = v2net.IPAddress(buffer.Value[41:45])
+	case AddrTypeIPv6:
+		nBytes, err = io.ReadFull(decryptor, buffer.Value[41:57]) // 16 bytes
+		bufferLen += 16
+		if err != nil {
+			log.Debug("VMess: Failed to read target IPv6 (", nBytes, " bytes): ", nBytes, err)
+			return nil, err
+		}
+		request.Address = v2net.IPAddress(buffer.Value[41:57])
+	case AddrTypeDomain:
+		nBytes, err = io.ReadFull(decryptor, buffer.Value[41:42])
+		if err != nil {
+			log.Debug("VMess: Failed to read target domain (", nBytes, " bytes): ", nBytes, err)
+			return nil, err
+		}
+		domainLength := int(buffer.Value[41])
+		if domainLength == 0 {
+			return nil, transport.ErrorCorruptedPacket
+		}
+		nBytes, err = io.ReadFull(decryptor, buffer.Value[42:42+domainLength])
+		if err != nil {
+			log.Debug("VMess: Failed to read target domain (", nBytes, " bytes): ", nBytes, err)
+			return nil, err
+		}
+		bufferLen += 1 + domainLength
+		domainBytes := append([]byte(nil), buffer.Value[42:42+domainLength]...)
+		request.Address = v2net.DomainAddress(string(domainBytes))
+	}
+
+	nBytes, err = io.ReadFull(decryptor, buffer.Value[bufferLen:bufferLen+4])
+	if err != nil {
+		log.Debug("VMess: Failed to read checksum (", nBytes, " bytes): ", nBytes, err)
+		return nil, err
+	}
+
+	fnv1a := fnv.New32a()
+	fnv1a.Write(buffer.Value[:bufferLen])
+	actualHash := fnv1a.Sum32()
+	expectedHash := serial.BytesLiteral(buffer.Value[bufferLen : bufferLen+4]).Uint32Value()
+
+	if actualHash != expectedHash {
+		return nil, transport.ErrorCorruptedPacket
+	}
+
+	return request, nil
+}
+
+func (this *ServerSession) DecodeRequestBody(reader io.Reader) io.Reader {
+	aesStream := crypto.NewAesDecryptionStream(this.requestBodyKey, this.requestBodyIV)
+	return crypto.NewCryptionReader(aesStream, reader)
+}
+
+func (this *ServerSession) EncodeResponseHeader(header *protocol.ResponseHeader, writer io.Writer) {
+	responseBodyKey := md5.Sum(this.requestBodyKey)
+	responseBodyIV := md5.Sum(this.requestBodyIV)
+	this.responseBodyKey = responseBodyKey[:]
+	this.requestBodyIV = responseBodyIV[:]
+
+	aesStream := crypto.NewAesEncryptionStream(this.responseBodyKey, this.responseBodyIV)
+	encryptionWriter := crypto.NewCryptionWriter(aesStream, writer)
+	this.responseWriter = encryptionWriter
+
+}