Browse Source

Some fixes about Hysteria 2 (#3147)

* hysteria2: remove unused code

* hysteria2: don't ignore some errors

* hysteria2: properly implement TCP request padding

* hysteria2: fix dialer reuse
dyhkwong 1 year ago
parent
commit
84adf2bdb2

+ 2 - 2
proxy/hysteria2/client.go

@@ -76,6 +76,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
 	}
 	newError("tunneling request to ", destination, " via ", server.Destination().NetAddr()).WriteToLog(session.ExportIDToError(ctx))
 
+	defer conn.Close()
+
 	iConn := conn
 	if statConn, ok := conn.(*internet.StatCouterConnection); ok {
 		iConn = statConn.Connection // will not count the UDP traffic.
@@ -87,8 +89,6 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
 		return newError(hyTransport.CanNotUseUdpExtension)
 	}
 
-	defer conn.Close()
-
 	user := server.PickUser()
 	userLevel := uint32(0)
 	if user != nil {

+ 1 - 28
proxy/hysteria2/config.go

@@ -1,19 +1,11 @@
 package hysteria2
 
 import (
-	"crypto/sha256"
-	"encoding/hex"
-	"fmt"
-
-	"github.com/v2fly/v2ray-core/v5/common"
 	"github.com/v2fly/v2ray-core/v5/common/protocol"
 )
 
 // MemoryAccount is an account type converted from Account.
-type MemoryAccount struct {
-	Password string
-	Key      []byte
-}
+type MemoryAccount struct{}
 
 // AsAccount implements protocol.AsAccount.
 func (a *Account) AsAccount() (protocol.Account, error) {
@@ -22,24 +14,5 @@ func (a *Account) AsAccount() (protocol.Account, error) {
 
 // Equals implements protocol.Account.Equals().
 func (a *MemoryAccount) Equals(another protocol.Account) bool {
-	if account, ok := another.(*MemoryAccount); ok {
-		return a.Password == account.Password
-	}
 	return false
 }
-
-func hexSha224(password string) []byte {
-	buf := make([]byte, 56)
-	hash := sha256.New224()
-	common.Must2(hash.Write([]byte(password)))
-	hex.Encode(buf, hash.Sum(nil))
-	return buf
-}
-
-func hexString(data []byte) string {
-	str := ""
-	for _, v := range data {
-		str += fmt.Sprintf("%02x", v)
-	}
-	return str
-}

+ 18 - 34
proxy/hysteria2/protocol.go

@@ -2,7 +2,7 @@ package hysteria2
 
 import (
 	"io"
-	gonet "net"
+	"math/rand"
 
 	hyProtocol "github.com/apernet/hysteria/core/v2/international/protocol"
 	"github.com/apernet/quic-go/quicvarint"
@@ -12,6 +12,10 @@ import (
 	hyTransport "github.com/v2fly/v2ray-core/v5/transport/internet/hysteria2"
 )
 
+const (
+	paddingChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
+)
+
 // ConnWriter is TCP Connection Writer Wrapper
 type ConnWriter struct {
 	io.Writer
@@ -61,17 +65,17 @@ func QuicLen(s int) int {
 func (c *ConnWriter) writeTCPHeader() error {
 	c.TCPHeaderSent = true
 
-	// TODO: the padding length here should be randomized
-
-	padding := "Jimmy Was Here"
-	paddingLen := len(padding)
+	paddingLen := 64 + rand.Intn(512-64)
+	padding := make([]byte, paddingLen)
+	for i := range padding {
+		padding[i] = paddingChars[rand.Intn(len(paddingChars))]
+	}
 	addressAndPort := c.Target.NetAddr()
 	addressLen := len(addressAndPort)
-	size := QuicLen(addressLen) + addressLen + QuicLen(paddingLen) + paddingLen
-
-	if size > hyProtocol.MaxAddressLength+hyProtocol.MaxPaddingLength {
-		return newError("invalid header length")
+	if addressLen > hyProtocol.MaxAddressLength {
+		return newError("address length too large: ", addressLen)
 	}
+	size := QuicLen(addressLen) + addressLen + QuicLen(paddingLen) + paddingLen
 
 	buf := make([]byte, size)
 	i := hyProtocol.VarintPut(buf, uint64(addressLen))
@@ -120,7 +124,7 @@ func (w *PacketWriter) WriteMultiBufferWithMetadata(mb buf.MultiBuffer, dest net
 	return nil
 }
 
-func (w *PacketWriter) WriteTo(payload []byte, addr gonet.Addr) (int, error) {
+func (w *PacketWriter) WriteTo(payload []byte, addr net.Addr) (int, error) {
 	dest := net.DestinationFromAddr(addr)
 
 	return w.writePacket(payload, dest)
@@ -145,7 +149,10 @@ func (c *ConnReader) Read(p []byte) (int, error) {
 func (c *ConnReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
 	b := buf.New()
 	_, err := b.ReadFrom(c)
-	return buf.MultiBuffer{b}, err
+	if err != nil {
+		return nil, err
+	}
+	return buf.MultiBuffer{b}, nil
 }
 
 // PacketPayload combines udp payload and destination
@@ -178,26 +185,3 @@ func (r *PacketReader) ReadMultiBufferWithMetadata() (*PacketPayload, error) {
 	b := buf.FromBytes(data)
 	return &PacketPayload{Target: *dest, Buffer: buf.MultiBuffer{b}}, nil
 }
-
-type PacketConnectionReader struct {
-	reader  *PacketReader
-	payload *PacketPayload
-}
-
-func (r *PacketConnectionReader) ReadFrom(p []byte) (n int, addr gonet.Addr, err error) {
-	if r.payload == nil || r.payload.Buffer.IsEmpty() {
-		r.payload, err = r.reader.ReadMultiBufferWithMetadata()
-		if err != nil {
-			return
-		}
-	}
-
-	addr = &gonet.UDPAddr{
-		IP:   r.payload.Target.Address.IP(),
-		Port: int(r.payload.Target.Port),
-	}
-
-	r.payload.Buffer, n = buf.SplitFirstBytes(r.payload.Buffer, p)
-
-	return
-}

+ 2 - 11
proxy/hysteria2/server.go

@@ -13,7 +13,6 @@ import (
 	"github.com/v2fly/v2ray-core/v5/common/errors"
 	"github.com/v2fly/v2ray-core/v5/common/log"
 	"github.com/v2fly/v2ray-core/v5/common/net"
-	"github.com/v2fly/v2ray-core/v5/common/net/packetaddr"
 	udp_proto "github.com/v2fly/v2ray-core/v5/common/protocol/udp"
 	"github.com/v2fly/v2ray-core/v5/common/session"
 	"github.com/v2fly/v2ray-core/v5/common/signal"
@@ -33,8 +32,7 @@ func init() {
 
 // Server is an inbound connection handler that handles messages in protocol.
 type Server struct {
-	policyManager  policy.Manager
-	packetEncoding packetaddr.PacketAddrType
+	policyManager policy.Manager
 }
 
 // NewServer creates a new inbound handler.
@@ -171,12 +169,6 @@ func (s *Server) handleConnection(ctx context.Context, sessionPolicy policy.Sess
 
 func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error { // {{{
 	udpDispatcherConstructor := udp.NewSplitDispatcher
-	switch s.packetEncoding {
-	case packetaddr.PacketAddrType_None:
-	case packetaddr.PacketAddrType_Packet:
-		packetAddrDispatcherFactory := udp.NewPacketAddrDispatcherCreator(ctx)
-		udpDispatcherConstructor = packetAddrDispatcherFactory.NewPacketAddrDispatcher
-	}
 
 	udpServer := udpDispatcherConstructor(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
 		if err := clientWriter.WriteMultiBufferWithMetadata(buf.MultiBuffer{packet.Payload}, packet.Source); err != nil {
@@ -185,7 +177,6 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade
 	})
 
 	inbound := session.InboundFromContext(ctx)
-	// user := inbound.User
 
 	for {
 		select {
@@ -213,4 +204,4 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade
 			}
 		}
 	}
-} // }}}
+}

+ 0 - 9
transport/internet/hysteria2/conn.go

@@ -8,7 +8,6 @@ import (
 	hyServer "github.com/apernet/hysteria/core/v2/server"
 	"github.com/apernet/quic-go"
 
-	"github.com/v2fly/v2ray-core/v5/common/buf"
 	"github.com/v2fly/v2ray-core/v5/common/net"
 )
 
@@ -20,7 +19,6 @@ type HyConn struct {
 	IsServer         bool
 	ClientUDPSession hyClient.HyUDPConn
 	ServerUDPSession *hyServer.UdpSessionEntry
-	Target           net.Destination
 
 	stream quic.Stream
 	local  net.Addr
@@ -36,13 +34,6 @@ func (c *HyConn) Read(b []byte) (int, error) {
 	return c.stream.Read(b)
 }
 
-func (c *HyConn) WriteMultiBuffer(mb buf.MultiBuffer) error {
-	mb = buf.Compact(mb)
-	mb, err := buf.WriteMultiBuffer(c, mb)
-	buf.ReleaseMulti(mb)
-	return err
-}
-
 func (c *HyConn) Write(b []byte) (int, error) {
 	if c.IsUDPExtension {
 		dest, _ := net.ParseDestination("udp:v2fly.org:6666")

+ 30 - 22
transport/internet/hysteria2/dialer.go

@@ -15,7 +15,12 @@ import (
 	"github.com/v2fly/v2ray-core/v5/transport/internet/tls"
 )
 
-var RunningClient map[net.Addr](hyClient.Client)
+type dialerConf struct {
+	net.Destination
+	*internet.MemoryStreamConfig
+}
+
+var RunningClient map[dialerConf](hyClient.Client)
 var ClientMutex sync.Mutex
 var MBps uint64 = 1000000 / 8 // MByte
 
@@ -61,12 +66,17 @@ func (f *connFactory) New(addr net.Addr) (net.PacketConn, error) {
 	return f.NewFunc(addr)
 }
 
-func NewHyClient(serverAddr net.Addr, streamSettings *internet.MemoryStreamConfig) (hyClient.Client, error) {
+func NewHyClient(dest net.Destination, streamSettings *internet.MemoryStreamConfig) (hyClient.Client, error) {
 	tlsConfig, err := GetClientTLSConfig(streamSettings)
 	if err != nil {
 		return nil, err
 	}
 
+	serverAddr, err := ResolveAddress(dest)
+	if err != nil {
+		return nil, err
+	}
+
 	config := streamSettings.ProtocolSettings.(*Config)
 	client, _, err := hyClient.NewClient(&hyClient.Config{
 		Auth:       config.GetPassword(),
@@ -93,36 +103,36 @@ func NewHyClient(serverAddr net.Addr, streamSettings *internet.MemoryStreamConfi
 	return client, nil
 }
 
-func CloseHyClient(serverAddr net.Addr) error {
+func CloseHyClient(dest net.Destination, streamSettings *internet.MemoryStreamConfig) error {
 	ClientMutex.Lock()
 	defer ClientMutex.Unlock()
 
-	client, found := RunningClient[serverAddr]
+	client, found := RunningClient[dialerConf{dest, streamSettings}]
 	if found {
-		delete(RunningClient, serverAddr)
+		delete(RunningClient, dialerConf{dest, streamSettings})
 		return client.Close()
 	}
 	return nil
 }
 
-func GetHyClient(serverAddr net.Addr, streamSettings *internet.MemoryStreamConfig) (hyClient.Client, error) {
+func GetHyClient(dest net.Destination, streamSettings *internet.MemoryStreamConfig) (hyClient.Client, error) {
 	var err error
 	var client hyClient.Client
 
 	ClientMutex.Lock()
-	client, found := RunningClient[serverAddr]
+	client, found := RunningClient[dialerConf{dest, streamSettings}]
 	ClientMutex.Unlock()
 	if !found || !CheckHyClientHealthy(client) {
 		if found {
 			// retry
-			CloseHyClient(serverAddr)
+			CloseHyClient(dest, streamSettings)
 		}
-		client, err = NewHyClient(serverAddr, streamSettings)
+		client, err = NewHyClient(dest, streamSettings)
 		if err != nil {
 			return nil, err
 		}
 		ClientMutex.Lock()
-		RunningClient[serverAddr] = client
+		RunningClient[dialerConf{dest, streamSettings}] = client
 		ClientMutex.Unlock()
 	}
 	return client, nil
@@ -144,14 +154,9 @@ func CheckHyClientHealthy(client hyClient.Client) bool {
 func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
 	config := streamSettings.ProtocolSettings.(*Config)
 
-	serverAddr, err := ResolveAddress(dest)
+	client, err := GetHyClient(dest, streamSettings)
 	if err != nil {
-		return nil, err
-	}
-
-	client, err := GetHyClient(serverAddr, streamSettings)
-	if err != nil {
-		CloseHyClient(serverAddr)
+		CloseHyClient(dest, streamSettings)
 		return nil, err
 	}
 
@@ -165,7 +170,6 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 	network := net.Network_TCP
 	if outbound != nil {
 		network = outbound.Target.Network
-		conn.Target = outbound.Target
 	}
 
 	if network == net.Network_UDP && config.GetUseUdpExtension() { // only hysteria2 can use udpExtension
@@ -173,7 +177,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 		conn.IsServer = false
 		conn.ClientUDPSession, err = client.UDP()
 		if err != nil {
-			CloseHyClient(serverAddr)
+			CloseHyClient(dest, streamSettings)
 			return nil, err
 		}
 		return conn, nil
@@ -181,7 +185,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 
 	conn.stream, err = client.OpenStream()
 	if err != nil {
-		CloseHyClient(serverAddr)
+		CloseHyClient(dest, streamSettings)
 		return nil, err
 	}
 
@@ -189,11 +193,15 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 	frameSize := int(quicvarint.Len(hyProtocol.FrameTypeTCPRequest))
 	buf := make([]byte, frameSize)
 	hyProtocol.VarintPut(buf, hyProtocol.FrameTypeTCPRequest)
-	conn.stream.Write(buf)
+	_, err = conn.stream.Write(buf)
+	if err != nil {
+		CloseHyClient(dest, streamSettings)
+		return nil, err
+	}
 	return conn, nil
 }
 
 func init() {
-	RunningClient = make(map[net.Addr]hyClient.Client)
+	RunningClient = make(map[dialerConf]hyClient.Client)
 	common.Must(internet.RegisterTransportDialer(protocolName, Dial))
 }