dialer.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. package hysteria2
  2. import (
  3. "context"
  4. "sync"
  5. hyClient "github.com/apernet/hysteria/core/v2/client"
  6. hyProtocol "github.com/apernet/hysteria/core/v2/international/protocol"
  7. "github.com/apernet/quic-go/quicvarint"
  8. "github.com/v2fly/v2ray-core/v5/common"
  9. "github.com/v2fly/v2ray-core/v5/common/net"
  10. "github.com/v2fly/v2ray-core/v5/common/session"
  11. "github.com/v2fly/v2ray-core/v5/transport/internet"
  12. "github.com/v2fly/v2ray-core/v5/transport/internet/tls"
  13. )
  14. type dialerConf struct {
  15. net.Destination
  16. *internet.MemoryStreamConfig
  17. }
  18. var RunningClient map[dialerConf](hyClient.Client)
  19. var ClientMutex sync.Mutex
  20. var MBps uint64 = 1000000 / 8 // MByte
  21. func GetClientTLSConfig(dest net.Destination, streamSettings *internet.MemoryStreamConfig) (*hyClient.TLSConfig, error) {
  22. config := tls.ConfigFromStreamSettings(streamSettings)
  23. if config == nil {
  24. return nil, newError(Hy2MustNeedTLS)
  25. }
  26. tlsConfig := config.GetTLSConfig(tls.WithDestination(dest))
  27. return &hyClient.TLSConfig{
  28. RootCAs: tlsConfig.RootCAs,
  29. ServerName: tlsConfig.ServerName,
  30. InsecureSkipVerify: tlsConfig.InsecureSkipVerify,
  31. VerifyPeerCertificate: tlsConfig.VerifyPeerCertificate,
  32. }, nil
  33. }
  34. func ResolveAddress(dest net.Destination) (net.Addr, error) {
  35. var destAddr *net.UDPAddr
  36. if dest.Address.Family().IsIP() {
  37. destAddr = &net.UDPAddr{
  38. IP: dest.Address.IP(),
  39. Port: int(dest.Port),
  40. }
  41. } else {
  42. addr, err := net.ResolveUDPAddr("udp", dest.NetAddr())
  43. if err != nil {
  44. return nil, err
  45. }
  46. destAddr = addr
  47. }
  48. return destAddr, nil
  49. }
  50. type connFactory struct {
  51. hyClient.ConnFactory
  52. NewFunc func(addr net.Addr) (net.PacketConn, error)
  53. }
  54. func (f *connFactory) New(addr net.Addr) (net.PacketConn, error) {
  55. return f.NewFunc(addr)
  56. }
  57. func NewHyClient(dest net.Destination, streamSettings *internet.MemoryStreamConfig) (hyClient.Client, error) {
  58. tlsConfig, err := GetClientTLSConfig(dest, streamSettings)
  59. if err != nil {
  60. return nil, err
  61. }
  62. serverAddr, err := ResolveAddress(dest)
  63. if err != nil {
  64. return nil, err
  65. }
  66. config := streamSettings.ProtocolSettings.(*Config)
  67. client, _, err := hyClient.NewClient(&hyClient.Config{
  68. Auth: config.GetPassword(),
  69. TLSConfig: *tlsConfig,
  70. ServerAddr: serverAddr,
  71. ConnFactory: &connFactory{
  72. NewFunc: func(addr net.Addr) (net.PacketConn, error) {
  73. rawConn, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{
  74. IP: []byte{0, 0, 0, 0},
  75. Port: 0,
  76. }, streamSettings.SocketSettings)
  77. if err != nil {
  78. return nil, err
  79. }
  80. return rawConn.(*net.UDPConn), nil
  81. },
  82. },
  83. BandwidthConfig: hyClient.BandwidthConfig{MaxTx: config.Congestion.GetUpMbps() * MBps, MaxRx: config.GetCongestion().GetDownMbps() * MBps},
  84. })
  85. if err != nil {
  86. return nil, err
  87. }
  88. return client, nil
  89. }
  90. func CloseHyClient(dest net.Destination, streamSettings *internet.MemoryStreamConfig) error {
  91. ClientMutex.Lock()
  92. defer ClientMutex.Unlock()
  93. client, found := RunningClient[dialerConf{dest, streamSettings}]
  94. if found {
  95. delete(RunningClient, dialerConf{dest, streamSettings})
  96. return client.Close()
  97. }
  98. return nil
  99. }
  100. func GetHyClient(dest net.Destination, streamSettings *internet.MemoryStreamConfig) (hyClient.Client, error) {
  101. var err error
  102. var client hyClient.Client
  103. ClientMutex.Lock()
  104. client, found := RunningClient[dialerConf{dest, streamSettings}]
  105. ClientMutex.Unlock()
  106. if !found || !CheckHyClientHealthy(client) {
  107. if found {
  108. // retry
  109. CloseHyClient(dest, streamSettings)
  110. }
  111. client, err = NewHyClient(dest, streamSettings)
  112. if err != nil {
  113. return nil, err
  114. }
  115. ClientMutex.Lock()
  116. RunningClient[dialerConf{dest, streamSettings}] = client
  117. ClientMutex.Unlock()
  118. }
  119. return client, nil
  120. }
  121. func CheckHyClientHealthy(client hyClient.Client) bool {
  122. quicConn := client.GetQuicConn()
  123. if quicConn == nil {
  124. return false
  125. }
  126. select {
  127. case <-quicConn.Context().Done():
  128. return false
  129. default:
  130. }
  131. return true
  132. }
  133. func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
  134. config := streamSettings.ProtocolSettings.(*Config)
  135. client, err := GetHyClient(dest, streamSettings)
  136. if err != nil {
  137. CloseHyClient(dest, streamSettings)
  138. return nil, err
  139. }
  140. quicConn := client.GetQuicConn()
  141. conn := &HyConn{
  142. local: quicConn.LocalAddr(),
  143. remote: quicConn.RemoteAddr(),
  144. }
  145. outbound := session.OutboundFromContext(ctx)
  146. network := net.Network_TCP
  147. if outbound != nil {
  148. network = outbound.Target.Network
  149. }
  150. if network == net.Network_UDP && config.GetUseUdpExtension() { // only hysteria2 can use udpExtension
  151. conn.IsUDPExtension = true
  152. conn.IsServer = false
  153. conn.ClientUDPSession, err = client.UDP()
  154. if err != nil {
  155. CloseHyClient(dest, streamSettings)
  156. return nil, err
  157. }
  158. return conn, nil
  159. }
  160. conn.stream, err = client.OpenStream()
  161. if err != nil {
  162. CloseHyClient(dest, streamSettings)
  163. return nil, err
  164. }
  165. // write TCP frame type
  166. frameSize := int(quicvarint.Len(hyProtocol.FrameTypeTCPRequest))
  167. buf := make([]byte, frameSize)
  168. hyProtocol.VarintPut(buf, hyProtocol.FrameTypeTCPRequest)
  169. _, err = conn.stream.Write(buf)
  170. if err != nil {
  171. CloseHyClient(dest, streamSettings)
  172. return nil, err
  173. }
  174. return conn, nil
  175. }
  176. func init() {
  177. RunningClient = make(map[dialerConf]hyClient.Client)
  178. common.Must(internet.RegisterTransportDialer(protocolName, Dial))
  179. }