httpDialer.go 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. package transportcommon
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "errors"
  6. "fmt"
  7. "net"
  8. "net/http"
  9. "sync"
  10. "time"
  11. "golang.org/x/net/http2"
  12. "github.com/v2fly/v2ray-core/v5/transport/internet/security"
  13. )
  // DialerFunc opens a connection to addr, the destination address supplied by
// the HTTP transports' TLS dial hooks. The round tripper uses the returned
// connection directly as the TLS connection; when it implements
// security.ConnectionApplicationProtocol, the negotiated ALPN is read from it.
type DialerFunc func(dialCtx context.Context, address string) (net.Conn, error)
  15. func NewALPNAwareHTTPRoundTripper(ctx context.Context, dialer DialerFunc,
  16. backdropTransport http.RoundTripper,
  17. ) http.RoundTripper {
  18. return NewALPNAwareHTTPRoundTripperWithH2Pool(ctx, dialer, backdropTransport, 1)
  19. }
  // NewALPNAwareHTTPRoundTripperWithH2Pool creates an http.RoundTripper that dials
// remote HTTPS endpoints through dialer (an alternative TLS implementation) and
// picks HTTP/1.1 or HTTP/2 per destination based on the ALPN actually negotiated.
//
// Requests whose scheme is not "https" are forwarded to backdropTransport.
// An h2PoolSize of 2 or more spreads HTTP/2 requests round-robin over that many
// independent http2.Transport instances; smaller values use a single transport.
func NewALPNAwareHTTPRoundTripperWithH2Pool(ctx context.Context, dialer DialerFunc,
	backdropTransport http.RoundTripper,
	h2PoolSize int,
) http.RoundTripper {
	rtImpl := &alpnAwareHTTPRoundTripperImpl{
		connectWithH1:     map[string]bool{},
		backdropTransport: backdropTransport,
		pendingConn:       map[pendingConnKey]*unclaimedConnection{},
		dialer:            dialer,
		ctx:               ctx,
		// init() reads this to decide whether to build an h2TransportPool;
		// without it the requested pool size would be silently ignored.
		h2PoolSize: h2PoolSize,
	}
	rtImpl.init()
	return rtImpl
}
  36. type alpnAwareHTTPRoundTripperImpl struct {
  37. accessConnectWithH1 sync.Mutex
  38. connectWithH1 map[string]bool
  39. httpsH1Transport http.RoundTripper
  40. httpsH2Transport http.RoundTripper
  41. backdropTransport http.RoundTripper
  42. accessDialingConnection sync.Mutex
  43. pendingConn map[pendingConnKey]*unclaimedConnection
  44. ctx context.Context
  45. dialer DialerFunc
  46. h2PoolSize int
  47. }
  // pendingConnKey identifies a parked connection by its destination address
// and by whether HTTP/2 was negotiated on it. Both fields are comparable, so
// the struct is used directly as a map key.
type pendingConnKey struct {
	dest string
	isH2 bool
}
  52. var (
  53. errEAGAIN = errors.New("incorrect ALPN negotiated, try again with another ALPN")
  54. errEAGAINTooMany = errors.New("incorrect ALPN negotiated")
  55. errExpired = errors.New("connection have expired")
  56. )
  57. func (r *alpnAwareHTTPRoundTripperImpl) RoundTrip(req *http.Request) (*http.Response, error) {
  58. if req.URL.Scheme != "https" {
  59. return r.backdropTransport.RoundTrip(req)
  60. }
  61. for retryCount := 0; retryCount < 5; retryCount++ {
  62. effectivePort := req.URL.Port()
  63. if effectivePort == "" {
  64. effectivePort = "443"
  65. }
  66. if r.getShouldConnectWithH1(fmt.Sprintf("%v:%v", req.URL.Hostname(), effectivePort)) {
  67. resp, err := r.httpsH1Transport.RoundTrip(req)
  68. if errors.Is(err, errEAGAIN) {
  69. continue
  70. }
  71. return resp, err
  72. }
  73. resp, err := r.httpsH2Transport.RoundTrip(req)
  74. if errors.Is(err, errEAGAIN) {
  75. continue
  76. }
  77. return resp, err
  78. }
  79. return nil, errEAGAINTooMany
  80. }
  81. func (r *alpnAwareHTTPRoundTripperImpl) getShouldConnectWithH1(domainName string) bool {
  82. r.accessConnectWithH1.Lock()
  83. defer r.accessConnectWithH1.Unlock()
  84. if value, set := r.connectWithH1[domainName]; set {
  85. return value
  86. }
  87. return false
  88. }
  89. func (r *alpnAwareHTTPRoundTripperImpl) setShouldConnectWithH1(domainName string) {
  90. r.accessConnectWithH1.Lock()
  91. defer r.accessConnectWithH1.Unlock()
  92. r.connectWithH1[domainName] = true
  93. }
  94. func (r *alpnAwareHTTPRoundTripperImpl) clearShouldConnectWithH1(domainName string) {
  95. r.accessConnectWithH1.Lock()
  96. defer r.accessConnectWithH1.Unlock()
  97. r.connectWithH1[domainName] = false
  98. }
  99. func getPendingConnectionID(dest string, alpnIsH2 bool) pendingConnKey {
  100. return pendingConnKey{isH2: alpnIsH2, dest: dest}
  101. }
  // putConn parks conn for later reuse by the transport for the protocol it
// actually negotiated. The parked connection is closed by its timer if it is
// still unclaimed after one minute. It replaces any entry already stored under
// the same key. Callers must hold accessDialingConnection, because pendingConn
// has no lock of its own.
func (r *alpnAwareHTTPRoundTripperImpl) putConn(addr string, alpnIsH2 bool, conn net.Conn) {
	key := getPendingConnectionID(addr, alpnIsH2)
	parked := NewUnclaimedConnection(conn, time.Minute)
	r.pendingConn[key] = parked
}
  // getConn removes and returns the connection parked for (addr, alpnIsH2).
// It returns nil when nothing is parked or the parked connection has already
// expired. Callers must hold accessDialingConnection.
func (r *alpnAwareHTTPRoundTripperImpl) getConn(addr string, alpnIsH2 bool) net.Conn {
	key := getPendingConnectionID(addr, alpnIsH2)
	parked, found := r.pendingConn[key]
	if !found {
		return nil
	}
	// Drop the entry whether or not the claim succeeds; an expired entry is useless.
	delete(r.pendingConn, key)
	claimed, err := parked.claimConnection()
	if err != nil {
		return nil
	}
	return claimed
}
  // dialOrGetTLSWithExpectedALPN returns a connection to addr on which the
// negotiated ALPN matches expectedH2 (true = HTTP/2, false = HTTP/1.1).
//
// It returns errEAGAIN, asking RoundTrip to retry with the other transport,
// in two cases:
//   - the destination is already recorded as speaking the other protocol;
//   - a fresh dial negotiated the other protocol. The connection is then
//     parked for that protocol's transport and the record is updated.
//
// Connections that cannot report their ALPN are treated as HTTP/1.1.
func (r *alpnAwareHTTPRoundTripperImpl) dialOrGetTLSWithExpectedALPN(ctx context.Context, addr string, expectedH2 bool) (net.Conn, error) {
	// All dials through this round tripper are serialised; this also guards pendingConn.
	r.accessDialingConnection.Lock()
	defer r.accessDialingConnection.Unlock()

	if r.getShouldConnectWithH1(addr) == expectedH2 {
		return nil, errEAGAIN
	}

	// Reuse a connection parked by an earlier mismatched dial, so that a
	// probe connection is not closed without ever carrying data.
	if gconn := r.getConn(addr, expectedH2); gconn != nil {
		return gconn, nil
	}

	conn, err := r.dialTLS(ctx, addr)
	if err != nil {
		return nil, err
	}

	protocol := ""
	if connALPNGetter, ok := conn.(security.ConnectionApplicationProtocol); ok {
		connectionALPN, alpnErr := connALPNGetter.GetConnectionApplicationProtocol()
		if alpnErr != nil {
			// Nothing else references the freshly dialed connection; close it so
			// it does not leak. The ALPN error is the one worth reporting.
			_ = conn.Close()
			return nil, newError("failed to get connection ALPN").Base(alpnErr).AtWarning()
		}
		protocol = connectionALPN
	}

	protocolIsH2 := protocol == http2.NextProtoTLS
	if protocolIsH2 == expectedH2 {
		return conn, nil
	}

	// Wrong protocol for the calling transport: park the connection for the
	// other transport and record what addr actually speaks.
	r.putConn(addr, protocolIsH2, conn)
	if protocolIsH2 {
		r.clearShouldConnectWithH1(addr)
	} else {
		r.setShouldConnectWithH1(addr)
	}
	return nil, errEAGAIN
}
  // dialTLS opens a new connection to addr through r.dialer.
//
// The per-call context is ignored: the dial always runs under r.ctx, the
// context captured at construction. NOTE(review): cancelling an individual
// request's context therefore does not abort an in-flight dial; confirm that
// this is intended (r.ctx presumably carries values the dialer needs).
func (r *alpnAwareHTTPRoundTripperImpl) dialTLS(_ context.Context, addr string) (net.Conn, error) {
	rootCtx := r.ctx
	return r.dialer(rootCtx, addr)
}
  // init builds the HTTPS transports. Each one dials through
// dialOrGetTLSWithExpectedALPN, so a connection is only handed to a transport
// when its negotiated ALPN matches. With h2PoolSize of 2 or more, HTTP/2
// traffic is spread over an h2TransportPool of that size.
func (r *alpnAwareHTTPRoundTripperImpl) init() {
	newH2Transport := func() *http2.Transport {
		return &http2.Transport{
			DialTLS: func(_, addr string, _ *tls.Config) (net.Conn, error) {
				return r.dialOrGetTLSWithExpectedALPN(context.Background(), addr, true)
			},
		}
	}

	if r.h2PoolSize < 2 {
		r.httpsH2Transport = newH2Transport()
	} else {
		r.httpsH2Transport = newH2TransportPool(int64(r.h2PoolSize), newH2Transport)
	}

	r.httpsH1Transport = &http.Transport{
		DialTLSContext: func(dialCtx context.Context, _ string, addr string) (net.Conn, error) {
			return r.dialOrGetTLSWithExpectedALPN(dialCtx, addr, false)
		},
	}
}
  176. func NewUnclaimedConnection(conn net.Conn, expireTime time.Duration) *unclaimedConnection {
  177. c := &unclaimedConnection{
  178. Conn: conn,
  179. }
  180. time.AfterFunc(expireTime, c.tick)
  181. return c
  182. }
  183. type unclaimedConnection struct {
  184. net.Conn
  185. claimed bool
  186. access sync.Mutex
  187. }
  188. func (c *unclaimedConnection) claimConnection() (net.Conn, error) {
  189. c.access.Lock()
  190. defer c.access.Unlock()
  191. if !c.claimed {
  192. c.claimed = true
  193. return c.Conn, nil
  194. }
  195. return nil, errExpired
  196. }
  197. func (c *unclaimedConnection) tick() {
  198. c.access.Lock()
  199. defer c.access.Unlock()
  200. if !c.claimed {
  201. c.claimed = true
  202. c.Conn.Close()
  203. c.Conn = nil
  204. }
  205. }
  206. type h2TransportFactory func() *http2.Transport
  207. func newH2TransportPool(size int64, h2factory h2TransportFactory) *h2TransportPool {
  208. return &h2TransportPool{
  209. pool: make([]*http2.Transport, size),
  210. size: size,
  211. h2factory: h2factory,
  212. }
  213. }
  214. type h2TransportPool struct {
  215. pool []*http2.Transport
  216. h2factory h2TransportFactory
  217. usageCount int64
  218. size int64
  219. }
  220. func (h *h2TransportPool) RoundTrip(request *http.Request) (*http.Response, error) {
  221. currentSlot := h.usageCount % h.size
  222. h.usageCount++
  223. if h.pool[currentSlot] == nil {
  224. h.pool[currentSlot] = h.h2factory()
  225. }
  226. return h.pool[currentSlot].RoundTrip(request)
  227. }