httpDialer.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. package transportcommon
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "errors"
  6. "fmt"
  7. "net"
  8. "net/http"
  9. "sync"
  10. "time"
  11. "github.com/v2fly/v2ray-core/v5/transport/internet/security"
  12. "golang.org/x/net/http2"
  13. )
  14. type DialerFunc func(ctx context.Context, addr string) (net.Conn, error)
  15. // NewALPNAwareHTTPRoundTripper creates an instance of RoundTripper that dial to remote HTTPS endpoint with
  16. // an alternative version of TLS implementation.
  17. func NewALPNAwareHTTPRoundTripper(ctx context.Context, dialer DialerFunc,
  18. backdropTransport http.RoundTripper) http.RoundTripper {
  19. rtImpl := &alpnAwareHTTPRoundTripperImpl{
  20. connectWithH1: map[string]bool{},
  21. backdropTransport: backdropTransport,
  22. pendingConn: map[pendingConnKey]*unclaimedConnection{},
  23. dialer: dialer,
  24. ctx: ctx,
  25. }
  26. rtImpl.init()
  27. return rtImpl
  28. }
  29. type alpnAwareHTTPRoundTripperImpl struct {
  30. accessConnectWithH1 sync.Mutex
  31. connectWithH1 map[string]bool
  32. httpsH1Transport http.RoundTripper
  33. httpsH2Transport http.RoundTripper
  34. backdropTransport http.RoundTripper
  35. accessDialingConnection sync.Mutex
  36. pendingConn map[pendingConnKey]*unclaimedConnection
  37. ctx context.Context
  38. dialer DialerFunc
  39. }
  40. type pendingConnKey struct {
  41. isH2 bool
  42. dest string
  43. }
  44. var errEAGAIN = errors.New("incorrect ALPN negotiated, try again with another ALPN")
  45. var errEAGAINTooMany = errors.New("incorrect ALPN negotiated")
  46. var errExpired = errors.New("connection have expired")
  47. func (r *alpnAwareHTTPRoundTripperImpl) RoundTrip(req *http.Request) (*http.Response, error) {
  48. if req.URL.Scheme != "https" {
  49. return r.backdropTransport.RoundTrip(req)
  50. }
  51. for retryCount := 0; retryCount < 5; retryCount++ {
  52. effectivePort := req.URL.Port()
  53. if effectivePort == "" {
  54. effectivePort = "443"
  55. }
  56. if r.getShouldConnectWithH1(fmt.Sprintf("%v:%v", req.URL.Hostname(), effectivePort)) {
  57. resp, err := r.httpsH1Transport.RoundTrip(req)
  58. if errors.Is(err, errEAGAIN) {
  59. continue
  60. }
  61. return resp, err
  62. }
  63. resp, err := r.httpsH2Transport.RoundTrip(req)
  64. if errors.Is(err, errEAGAIN) {
  65. continue
  66. }
  67. return resp, err
  68. }
  69. return nil, errEAGAINTooMany
  70. }
  71. func (r *alpnAwareHTTPRoundTripperImpl) getShouldConnectWithH1(domainName string) bool {
  72. r.accessConnectWithH1.Lock()
  73. defer r.accessConnectWithH1.Unlock()
  74. if value, set := r.connectWithH1[domainName]; set {
  75. return value
  76. }
  77. return false
  78. }
  79. func (r *alpnAwareHTTPRoundTripperImpl) setShouldConnectWithH1(domainName string) {
  80. r.accessConnectWithH1.Lock()
  81. defer r.accessConnectWithH1.Unlock()
  82. r.connectWithH1[domainName] = true
  83. }
  84. func (r *alpnAwareHTTPRoundTripperImpl) clearShouldConnectWithH1(domainName string) {
  85. r.accessConnectWithH1.Lock()
  86. defer r.accessConnectWithH1.Unlock()
  87. r.connectWithH1[domainName] = false
  88. }
  89. func getPendingConnectionID(dest string, alpnIsH2 bool) pendingConnKey {
  90. return pendingConnKey{isH2: alpnIsH2, dest: dest}
  91. }
  92. func (r *alpnAwareHTTPRoundTripperImpl) putConn(addr string, alpnIsH2 bool, conn net.Conn) {
  93. connId := getPendingConnectionID(addr, alpnIsH2)
  94. r.pendingConn[connId] = NewUnclaimedConnection(conn, time.Minute)
  95. }
  96. func (r *alpnAwareHTTPRoundTripperImpl) getConn(addr string, alpnIsH2 bool) net.Conn {
  97. connId := getPendingConnectionID(addr, alpnIsH2)
  98. if conn, ok := r.pendingConn[connId]; ok {
  99. delete(r.pendingConn, connId)
  100. if claimedConnection, err := conn.claimConnection(); err == nil {
  101. return claimedConnection
  102. }
  103. }
  104. return nil
  105. }
  106. func (r *alpnAwareHTTPRoundTripperImpl) dialOrGetTLSWithExpectedALPN(ctx context.Context, addr string, expectedH2 bool) (net.Conn, error) {
  107. r.accessDialingConnection.Lock()
  108. defer r.accessDialingConnection.Unlock()
  109. if r.getShouldConnectWithH1(addr) == expectedH2 {
  110. return nil, errEAGAIN
  111. }
  112. //Get a cached connection if possible to reduce preflight connection closed without sending data
  113. if gconn := r.getConn(addr, expectedH2); gconn != nil {
  114. return gconn, nil
  115. }
  116. conn, err := r.dialTLS(ctx, addr)
  117. if err != nil {
  118. return nil, err
  119. }
  120. protocol := ""
  121. if connAPLNGetter, ok := conn.(security.ConnectionApplicationProtocol); ok {
  122. connectionALPN, err := connAPLNGetter.GetConnectionApplicationProtocol()
  123. if err != nil {
  124. return nil, newError("failed to get connection ALPN").Base(err).AtWarning()
  125. }
  126. protocol = connectionALPN
  127. }
  128. protocolIsH2 := protocol == http2.NextProtoTLS
  129. if protocolIsH2 == expectedH2 {
  130. return conn, err
  131. }
  132. r.putConn(addr, protocolIsH2, conn)
  133. if protocolIsH2 {
  134. r.clearShouldConnectWithH1(addr)
  135. } else {
  136. r.setShouldConnectWithH1(addr)
  137. }
  138. return nil, errEAGAIN
  139. }
  140. func (r *alpnAwareHTTPRoundTripperImpl) dialTLS(ctx context.Context, addr string) (net.Conn, error) {
  141. return r.dialer(r.ctx, addr)
  142. }
  143. func (r *alpnAwareHTTPRoundTripperImpl) init() {
  144. r.httpsH2Transport = &http2.Transport{
  145. DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
  146. return r.dialOrGetTLSWithExpectedALPN(context.Background(), addr, true)
  147. },
  148. }
  149. r.httpsH1Transport = &http.Transport{
  150. DialTLSContext: func(ctx context.Context, network string, addr string) (net.Conn, error) {
  151. return r.dialOrGetTLSWithExpectedALPN(ctx, addr, false)
  152. },
  153. }
  154. }
  155. func NewUnclaimedConnection(conn net.Conn, expireTime time.Duration) *unclaimedConnection {
  156. c := &unclaimedConnection{
  157. Conn: conn,
  158. }
  159. time.AfterFunc(expireTime, c.tick)
  160. return c
  161. }
  162. type unclaimedConnection struct {
  163. net.Conn
  164. claimed bool
  165. access sync.Mutex
  166. }
  167. func (c *unclaimedConnection) claimConnection() (net.Conn, error) {
  168. c.access.Lock()
  169. defer c.access.Unlock()
  170. if !c.claimed {
  171. c.claimed = true
  172. return c.Conn, nil
  173. }
  174. return nil, errExpired
  175. }
  176. func (c *unclaimedConnection) tick() {
  177. c.access.Lock()
  178. defer c.access.Unlock()
  179. if !c.claimed {
  180. c.claimed = true
  181. c.Conn.Close()
  182. c.Conn = nil
  183. }
  184. }