소스 검색

use AddressParser in mux

Darien Raymond 7 년 전
부모
커밋
c43a5e7d85
4개의 변경된 파일51개의 추가작업 그리고 77개의 파일을 삭제
  1. 44 70
      app/proxyman/mux/frame.go
  2. 1 1
      app/proxyman/mux/reader.go
  3. 3 3
      app/proxyman/mux/writer.go
  4. 3 3
      proxy/vmess/encoding/encoding.go

+ 44 - 70
app/proxyman/mux/frame.go

@@ -28,6 +28,13 @@ const (
 	TargetNetworkUDP TargetNetwork = 0x02
 )
 
+var addrParser = protocol.NewAddressParser(
+	protocol.AddressFamilyByte(byte(protocol.AddressTypeIPv4), net.AddressFamilyIPv4),
+	protocol.AddressFamilyByte(byte(protocol.AddressTypeDomain), net.AddressFamilyDomain),
+	protocol.AddressFamilyByte(byte(protocol.AddressTypeIPv6), net.AddressFamilyIPv6),
+	protocol.PortThenAddress(),
+)
+
 /*
 Frame format
 2 bytes - length
@@ -48,88 +55,55 @@ type FrameMetadata struct {
 	SessionStatus SessionStatus
 }
 
-func (f FrameMetadata) AsSupplier() buf.Supplier {
-	return func(b []byte) (int, error) {
-		lengthBytes := b
-		b = serial.Uint16ToBytes(uint16(0), b[:0]) // place holder for length
-
-		b = serial.Uint16ToBytes(f.SessionID, b)
-		b = append(b, byte(f.SessionStatus), byte(f.Option))
-		length := 4
-
-		if f.SessionStatus == SessionStatusNew {
-			switch f.Target.Network {
-			case net.Network_TCP:
-				b = append(b, byte(TargetNetworkTCP))
-			case net.Network_UDP:
-				b = append(b, byte(TargetNetworkUDP))
-			}
-			length++
-
-			b = serial.Uint16ToBytes(f.Target.Port.Value(), b)
-			length += 2
-
-			addr := f.Target.Address
-			switch addr.Family() {
-			case net.AddressFamilyIPv4:
-				b = append(b, byte(protocol.AddressTypeIPv4))
-				b = append(b, addr.IP()...)
-				length += 5
-			case net.AddressFamilyIPv6:
-				b = append(b, byte(protocol.AddressTypeIPv6))
-				b = append(b, addr.IP()...)
-				length += 17
-			case net.AddressFamilyDomain:
-				domain := addr.Domain()
-				if protocol.IsDomainTooLong(domain) {
-					return 0, newError("domain name too long: ", domain)
-				}
-				nDomain := len(domain)
-				b = append(b, byte(protocol.AddressTypeDomain), byte(nDomain))
-				b = append(b, domain...)
-				length += nDomain + 2
-			}
+func (f FrameMetadata) WriteTo(b *buf.Buffer) error {
+	lenBytes := b.Bytes()
+	b.AppendBytes(0x00, 0x00)
+
+	len0 := b.Len()
+	if err := b.AppendSupplier(serial.WriteUint16(f.SessionID)); err != nil {
+		return err
+	}
+
+	b.AppendBytes(byte(f.SessionStatus), byte(f.Option))
+
+	if f.SessionStatus == SessionStatusNew {
+		switch f.Target.Network {
+		case net.Network_TCP:
+			b.AppendBytes(byte(TargetNetworkTCP))
+		case net.Network_UDP:
+			b.AppendBytes(byte(TargetNetworkUDP))
 		}
 
-		serial.Uint16ToBytes(uint16(length), lengthBytes[:0])
-		return length + 2, nil
+		if err := addrParser.WriteAddressPort(b, f.Target.Address, f.Target.Port); err != nil {
+			return err
+		}
 	}
+
+	len1 := b.Len()
+	serial.Uint16ToBytes(uint16(len1-len0), lenBytes)
+	return nil
 }
 
-func ReadFrameFrom(b []byte) (*FrameMetadata, error) {
-	if len(b) < 4 {
-		return nil, newError("insufficient buffer: ", len(b))
+func ReadFrameFrom(b *buf.Buffer) (*FrameMetadata, error) {
+	if b.Len() < 4 {
+		return nil, newError("insufficient buffer: ", b.Len())
 	}
 
 	f := &FrameMetadata{
-		SessionID:     serial.BytesToUint16(b[:2]),
-		SessionStatus: SessionStatus(b[2]),
-		Option:        bitmask.Byte(b[3]),
+		SessionID:     serial.BytesToUint16(b.BytesTo(2)),
+		SessionStatus: SessionStatus(b.Byte(2)),
+		Option:        bitmask.Byte(b.Byte(3)),
 	}
 
-	b = b[4:]
-
 	if f.SessionStatus == SessionStatusNew {
-		network := TargetNetwork(b[0])
-		port := net.PortFromBytes(b[1:3])
-		addrType := protocol.AddressType(b[3])
-		b = b[4:]
-
-		var addr net.Address
-		switch addrType {
-		case protocol.AddressTypeIPv4:
-			addr = net.IPAddress(b[0:4])
-			b = b[4:]
-		case protocol.AddressTypeIPv6:
-			addr = net.IPAddress(b[0:16])
-			b = b[16:]
-		case protocol.AddressTypeDomain:
-			nDomain := int(b[0])
-			addr = net.DomainAddress(string(b[1 : 1+nDomain]))
-			b = b[nDomain+1:]
-		default:
-			return nil, newError("unknown address type: ", addrType)
+		network := TargetNetwork(b.Byte(4))
+		b.SliceFrom(5)
+
+		addr, port, err := addrParser.ReadAddressPort(nil, b)
+		if err != nil {
+			return nil, newError("failed to parse address and port").Base(err)
 		}
+
 		switch network {
 		case TargetNetworkTCP:
 			f.Target = net.TCPDestination(addr, port)

+ 1 - 1
app/proxyman/mux/reader.go

@@ -23,7 +23,7 @@ func ReadMetadata(reader io.Reader) (*FrameMetadata, error) {
 	if err := b.Reset(buf.ReadFullFrom(reader, int(metaLen))); err != nil {
 		return nil, err
 	}
-	return ReadFrameFrom(b.Bytes())
+	return ReadFrameFrom(b)
 }
 
 // PacketReader is an io.Reader that reads whole chunk of Mux frames every time.

+ 3 - 3
app/proxyman/mux/writer.go

@@ -53,7 +53,7 @@ func (w *Writer) getNextFrameMeta() FrameMetadata {
 func (w *Writer) writeMetaOnly() error {
 	meta := w.getNextFrameMeta()
 	b := buf.New()
-	if err := b.Reset(meta.AsSupplier()); err != nil {
+	if err := meta.WriteTo(b); err != nil {
 		return err
 	}
 	return w.writer.WriteMultiBuffer(buf.NewMultiBufferValue(b))
@@ -64,7 +64,7 @@ func (w *Writer) writeData(mb buf.MultiBuffer) error {
 	meta.Option.Set(OptionData)
 
 	frame := buf.New()
-	if err := frame.Reset(meta.AsSupplier()); err != nil {
+	if err := meta.WriteTo(frame); err != nil {
 		return err
 	}
 	if err := frame.AppendSupplier(serial.WriteUint16(uint16(mb.Len()))); err != nil {
@@ -107,7 +107,7 @@ func (w *Writer) Close() error {
 	}
 
 	frame := buf.New()
-	common.Must(frame.Reset(meta.AsSupplier()))
+	common.Must(meta.WriteTo(frame))
 
 	w.writer.WriteMultiBuffer(buf.NewMultiBufferValue(frame))
 	return nil

+ 3 - 3
proxy/vmess/encoding/encoding.go

@@ -12,8 +12,8 @@ const (
 )
 
 var addrParser = protocol.NewAddressParser(
-	protocol.AddressFamilyByte(0x01, net.AddressFamilyIPv4),
-	protocol.AddressFamilyByte(0x02, net.AddressFamilyDomain),
-	protocol.AddressFamilyByte(0x03, net.AddressFamilyIPv6),
+	protocol.AddressFamilyByte(byte(protocol.AddressTypeIPv4), net.AddressFamilyIPv4),
+	protocol.AddressFamilyByte(byte(protocol.AddressTypeDomain), net.AddressFamilyDomain),
+	protocol.AddressFamilyByte(byte(protocol.AddressTypeIPv6), net.AddressFamilyIPv6),
 	protocol.PortThenAddress(),
 )