dialer.go 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. // +build !confonly
  2. package websocket
  3. import (
  4. "context"
  5. "time"
  6. "github.com/gorilla/websocket"
  7. "github.com/v2fly/v2ray-core/v4/common"
  8. "github.com/v2fly/v2ray-core/v4/common/net"
  9. "github.com/v2fly/v2ray-core/v4/common/session"
  10. "github.com/v2fly/v2ray-core/v4/transport/internet"
  11. "github.com/v2fly/v2ray-core/v4/transport/internet/tls"
  12. )
  13. // Dial dials a WebSocket connection to the given destination.
  14. func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
  15. newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx))
  16. conn, err := dialWebsocket(ctx, dest, streamSettings)
  17. if err != nil {
  18. return nil, newError("failed to dial WebSocket").Base(err)
  19. }
  20. return internet.Connection(conn), nil
  21. }
  22. func init() {
  23. common.Must(internet.RegisterTransportDialer(protocolName, Dial))
  24. }
  25. func dialWebsocket(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (net.Conn, error) {
  26. wsSettings := streamSettings.ProtocolSettings.(*Config)
  27. dialer := &websocket.Dialer{
  28. NetDial: func(network, addr string) (net.Conn, error) {
  29. return internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
  30. },
  31. ReadBufferSize: 4 * 1024,
  32. WriteBufferSize: 4 * 1024,
  33. HandshakeTimeout: time.Second * 8,
  34. }
  35. protocol := "ws"
  36. if config := tls.ConfigFromStreamSettings(streamSettings); config != nil {
  37. protocol = "wss"
  38. dialer.TLSClientConfig = config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("http/1.1"))
  39. }
  40. host := dest.NetAddr()
  41. if (protocol == "ws" && dest.Port == 80) || (protocol == "wss" && dest.Port == 443) {
  42. host = dest.Address.String()
  43. }
  44. uri := protocol + "://" + host + wsSettings.GetNormalizedPath()
  45. conn, resp, err := dialer.Dial(uri, wsSettings.GetRequestHeader())
  46. if err != nil {
  47. var reason string
  48. if resp != nil {
  49. reason = resp.Status
  50. }
  51. return nil, newError("failed to dial to (", uri, "): ", reason).Base(err)
  52. }
  53. return newConnection(conn, conn.RemoteAddr()), nil
  54. }