dialer.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. package websocket
  2. import (
  3. "context"
  4. "net"
  5. "github.com/gorilla/websocket"
  6. "v2ray.com/core/app/log"
  7. "v2ray.com/core/common"
  8. "v2ray.com/core/common/errors"
  9. v2net "v2ray.com/core/common/net"
  10. "v2ray.com/core/transport/internet"
  11. "v2ray.com/core/transport/internet/internal"
  12. v2tls "v2ray.com/core/transport/internet/tls"
  13. )
  14. var (
  15. globalCache = internal.NewConnectionPool()
  16. )
  17. func Dial(ctx context.Context, dest v2net.Destination) (internet.Connection, error) {
  18. log.Info("WebSocket|Dialer: Creating connection to ", dest)
  19. src := internet.DialerSourceFromContext(ctx)
  20. wsSettings := internet.TransportSettingsFromContext(ctx).(*Config)
  21. id := internal.NewConnectionID(src, dest)
  22. var conn net.Conn
  23. if dest.Network == v2net.Network_TCP && wsSettings.IsConnectionReuse() {
  24. conn = globalCache.Get(id)
  25. }
  26. if conn == nil {
  27. var err error
  28. conn, err = dialWebsocket(ctx, dest)
  29. if err != nil {
  30. return nil, errors.New("dial failed").Path("WebSocket", "Dialer")
  31. }
  32. }
  33. return internal.NewConnection(id, conn, globalCache, internal.ReuseConnection(wsSettings.IsConnectionReuse())), nil
  34. }
  35. func init() {
  36. common.Must(internet.RegisterTransportDialer(internet.TransportProtocol_WebSocket, Dial))
  37. }
  38. func dialWebsocket(ctx context.Context, dest v2net.Destination) (net.Conn, error) {
  39. src := internet.DialerSourceFromContext(ctx)
  40. wsSettings := internet.TransportSettingsFromContext(ctx).(*Config)
  41. commonDial := func(network, addr string) (net.Conn, error) {
  42. return internet.DialSystem(ctx, src, dest)
  43. }
  44. dialer := websocket.Dialer{
  45. NetDial: commonDial,
  46. ReadBufferSize: 32 * 1024,
  47. WriteBufferSize: 32 * 1024,
  48. }
  49. protocol := "ws"
  50. if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
  51. tlsConfig, ok := securitySettings.(*v2tls.Config)
  52. if ok {
  53. protocol = "wss"
  54. dialer.TLSClientConfig = tlsConfig.GetTLSConfig()
  55. if dest.Address.Family().IsDomain() {
  56. dialer.TLSClientConfig.ServerName = dest.Address.Domain()
  57. }
  58. }
  59. }
  60. host := dest.NetAddr()
  61. if (protocol == "ws" && dest.Port == 80) || (protocol == "wss" && dest.Port == 443) {
  62. host = dest.Address.String()
  63. }
  64. uri := protocol + "://" + host + wsSettings.GetNormailzedPath()
  65. conn, resp, err := dialer.Dial(uri, nil)
  66. if err != nil {
  67. var reason string
  68. if resp != nil {
  69. reason = resp.Status
  70. }
  71. return nil, errors.New("failed to dial to (", uri, "): ", reason).Base(err).Path("WebSocket", "Dialer")
  72. }
  73. return &connection{
  74. wsc: conn,
  75. }, nil
  76. }