dialer.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. package hysteria2
  2. import (
  3. "context"
  4. "sync"
  5. "github.com/apernet/quic-go/quicvarint"
  6. hyClient "github.com/v2fly/hysteria/core/v2/client"
  7. hyProtocol "github.com/v2fly/hysteria/core/v2/international/protocol"
  8. "github.com/v2fly/v2ray-core/v5/common"
  9. "github.com/v2fly/v2ray-core/v5/common/net"
  10. "github.com/v2fly/v2ray-core/v5/common/session"
  11. "github.com/v2fly/v2ray-core/v5/transport/internet"
  12. "github.com/v2fly/v2ray-core/v5/transport/internet/tls"
  13. )
// dialerConf is the key type of RunningClient: a destination together with
// the stream settings used to dial it.
//
// The stream settings are embedded as a pointer, so two keys are equal only
// when they refer to the very same MemoryStreamConfig value; two distinct but
// identical configurations produce different keys (and thus separate clients).
type dialerConf struct {
	net.Destination
	*internet.MemoryStreamConfig
}
var (
	// RunningClient caches one Hysteria2 client per dialerConf key. It is
	// allocated in init and is read and written concurrently by dialers, so
	// every access must hold ClientMutex.
	RunningClient map[dialerConf](hyClient.Client)
	// ClientMutex guards RunningClient.
	ClientMutex sync.Mutex
	// MBps is the number of bytes per second in one megabit per second
	// (10^6 bits / 8 = 125000). Congestion settings given in Mbps are
	// multiplied by it to get the byte-per-second limits handed to the
	// Hysteria client. Despite the name, it is not a megabyte.
	MBps uint64 = 1000000 / 8 // bytes/s in 1 Mbit/s
)
  23. func GetClientTLSConfig(dest net.Destination, streamSettings *internet.MemoryStreamConfig) (*hyClient.TLSConfig, error) {
  24. config := tls.ConfigFromStreamSettings(streamSettings)
  25. if config == nil {
  26. return nil, newError(Hy2MustNeedTLS)
  27. }
  28. tlsConfig := config.GetTLSConfig(tls.WithDestination(dest))
  29. return &hyClient.TLSConfig{
  30. RootCAs: tlsConfig.RootCAs,
  31. ServerName: tlsConfig.ServerName,
  32. InsecureSkipVerify: tlsConfig.InsecureSkipVerify,
  33. VerifyPeerCertificate: tlsConfig.VerifyPeerCertificate,
  34. }, nil
  35. }
  36. func ResolveAddress(dest net.Destination) (net.Addr, error) {
  37. var destAddr *net.UDPAddr
  38. if dest.Address.Family().IsIP() {
  39. destAddr = &net.UDPAddr{
  40. IP: dest.Address.IP(),
  41. Port: int(dest.Port),
  42. }
  43. } else {
  44. addr, err := net.ResolveUDPAddr("udp", dest.NetAddr())
  45. if err != nil {
  46. return nil, err
  47. }
  48. destAddr = addr
  49. }
  50. return destAddr, nil
  51. }
// connFactory adapts a plain function to hyClient.ConnFactory: its New method
// simply calls NewFunc.
//
// NOTE(review): the embedded hyClient.ConnFactory is left nil by NewHyClient;
// it appears to exist only so connFactory satisfies the interface. Calling
// any interface method other than New would panic on the nil embedding —
// confirm the interface declares only New.
type connFactory struct {
	hyClient.ConnFactory
	NewFunc func(addr net.Addr) (net.PacketConn, error)
}
  56. func (f *connFactory) New(addr net.Addr) (net.PacketConn, error) {
  57. return f.NewFunc(addr)
  58. }
  59. func NewHyClient(dest net.Destination, streamSettings *internet.MemoryStreamConfig) (hyClient.Client, error) {
  60. tlsConfig, err := GetClientTLSConfig(dest, streamSettings)
  61. if err != nil {
  62. return nil, err
  63. }
  64. serverAddr, err := ResolveAddress(dest)
  65. if err != nil {
  66. return nil, err
  67. }
  68. config := streamSettings.ProtocolSettings.(*Config)
  69. client, _, err := hyClient.NewClient(&hyClient.Config{
  70. Auth: config.GetPassword(),
  71. TLSConfig: *tlsConfig,
  72. ServerAddr: serverAddr,
  73. ConnFactory: &connFactory{
  74. NewFunc: func(addr net.Addr) (net.PacketConn, error) {
  75. rawConn, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{
  76. IP: []byte{0, 0, 0, 0},
  77. Port: 0,
  78. }, streamSettings.SocketSettings)
  79. if err != nil {
  80. return nil, err
  81. }
  82. return rawConn.(*net.UDPConn), nil
  83. },
  84. },
  85. BandwidthConfig: hyClient.BandwidthConfig{MaxTx: config.Congestion.GetUpMbps() * MBps, MaxRx: config.GetCongestion().GetDownMbps() * MBps},
  86. })
  87. if err != nil {
  88. return nil, err
  89. }
  90. return client, nil
  91. }
  92. func CloseHyClient(dest net.Destination, streamSettings *internet.MemoryStreamConfig) error {
  93. ClientMutex.Lock()
  94. defer ClientMutex.Unlock()
  95. client, found := RunningClient[dialerConf{dest, streamSettings}]
  96. if found {
  97. delete(RunningClient, dialerConf{dest, streamSettings})
  98. return client.Close()
  99. }
  100. return nil
  101. }
  102. func GetHyClient(dest net.Destination, streamSettings *internet.MemoryStreamConfig) (hyClient.Client, error) {
  103. var err error
  104. var client hyClient.Client
  105. ClientMutex.Lock()
  106. client, found := RunningClient[dialerConf{dest, streamSettings}]
  107. ClientMutex.Unlock()
  108. if !found || !CheckHyClientHealthy(client) {
  109. if found {
  110. // retry
  111. CloseHyClient(dest, streamSettings)
  112. }
  113. client, err = NewHyClient(dest, streamSettings)
  114. if err != nil {
  115. return nil, err
  116. }
  117. ClientMutex.Lock()
  118. RunningClient[dialerConf{dest, streamSettings}] = client
  119. ClientMutex.Unlock()
  120. }
  121. return client, nil
  122. }
  123. func CheckHyClientHealthy(client hyClient.Client) bool {
  124. quicConn := client.GetQuicConn()
  125. if quicConn == nil {
  126. return false
  127. }
  128. select {
  129. case <-quicConn.Context().Done():
  130. return false
  131. default:
  132. }
  133. return true
  134. }
  135. func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
  136. config := streamSettings.ProtocolSettings.(*Config)
  137. client, err := GetHyClient(dest, streamSettings)
  138. if err != nil {
  139. CloseHyClient(dest, streamSettings)
  140. return nil, err
  141. }
  142. quicConn := client.GetQuicConn()
  143. conn := &HyConn{
  144. local: quicConn.LocalAddr(),
  145. remote: quicConn.RemoteAddr(),
  146. }
  147. outbound := session.OutboundFromContext(ctx)
  148. network := net.Network_TCP
  149. if outbound != nil {
  150. network = outbound.Target.Network
  151. }
  152. if network == net.Network_UDP && config.GetUseUdpExtension() { // only hysteria2 can use udpExtension
  153. conn.IsUDPExtension = true
  154. conn.IsServer = false
  155. conn.ClientUDPSession, err = client.UDP()
  156. if err != nil {
  157. CloseHyClient(dest, streamSettings)
  158. return nil, err
  159. }
  160. return conn, nil
  161. }
  162. conn.stream, err = client.OpenStream()
  163. if err != nil {
  164. CloseHyClient(dest, streamSettings)
  165. return nil, err
  166. }
  167. // write TCP frame type
  168. frameSize := quicvarint.Len(hyProtocol.FrameTypeTCPRequest)
  169. buf := make([]byte, frameSize)
  170. hyProtocol.VarintPut(buf, hyProtocol.FrameTypeTCPRequest)
  171. _, err = conn.stream.Write(buf)
  172. if err != nil {
  173. CloseHyClient(dest, streamSettings)
  174. return nil, err
  175. }
  176. return conn, nil
  177. }
  178. func init() {
  179. RunningClient = make(map[dialerConf]hyClient.Client)
  180. common.Must(internet.RegisterTransportDialer(protocolName, Dial))
  181. }