ws_test.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. package websocket_test
  2. import (
  3. "bytes"
  4. "context"
  5. "runtime"
  6. "testing"
  7. "time"
  8. "v2ray.com/core/common"
  9. "v2ray.com/core/common/net"
  10. "v2ray.com/core/common/protocol/tls/cert"
  11. "v2ray.com/core/transport/internet"
  12. "v2ray.com/core/transport/internet/tls"
  13. . "v2ray.com/core/transport/internet/websocket"
  14. . "v2ray.com/ext/assert"
  15. )
  16. func Test_listenWSAndDial(t *testing.T) {
  17. assert := With(t)
  18. listen, err := ListenWS(context.Background(), net.LocalHostIP, 13146, &internet.MemoryStreamConfig{
  19. ProtocolName: "websocket",
  20. ProtocolSettings: &Config{
  21. Path: "ws",
  22. },
  23. }, func(conn internet.Connection) {
  24. go func(c internet.Connection) {
  25. defer c.Close()
  26. var b [1024]byte
  27. n, err := c.Read(b[:])
  28. //common.Must(err)
  29. if err != nil {
  30. return
  31. }
  32. assert(bytes.HasPrefix(b[:n], []byte("Test connection")), IsTrue)
  33. _, err = c.Write([]byte("Response"))
  34. common.Must(err)
  35. }(conn)
  36. })
  37. common.Must(err)
  38. ctx := context.Background()
  39. streamSettings := &internet.MemoryStreamConfig{
  40. ProtocolName: "websocket",
  41. ProtocolSettings: &Config{Path: "ws"},
  42. }
  43. conn, err := Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13146), streamSettings)
  44. common.Must(err)
  45. _, err = conn.Write([]byte("Test connection 1"))
  46. common.Must(err)
  47. var b [1024]byte
  48. n, err := conn.Read(b[:])
  49. common.Must(err)
  50. assert(string(b[:n]), Equals, "Response")
  51. assert(conn.Close(), IsNil)
  52. <-time.After(time.Second * 5)
  53. conn, err = Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13146), streamSettings)
  54. common.Must(err)
  55. _, err = conn.Write([]byte("Test connection 2"))
  56. common.Must(err)
  57. n, err = conn.Read(b[:])
  58. common.Must(err)
  59. assert(string(b[:n]), Equals, "Response")
  60. assert(conn.Close(), IsNil)
  61. assert(listen.Close(), IsNil)
  62. }
  63. func TestDialWithRemoteAddr(t *testing.T) {
  64. assert := With(t)
  65. listen, err := ListenWS(context.Background(), net.LocalHostIP, 13148, &internet.MemoryStreamConfig{
  66. ProtocolName: "websocket",
  67. ProtocolSettings: &Config{
  68. Path: "ws",
  69. },
  70. }, func(conn internet.Connection) {
  71. go func(c internet.Connection) {
  72. defer c.Close()
  73. assert(c.RemoteAddr().String(), HasPrefix, "1.1.1.1")
  74. var b [1024]byte
  75. n, err := c.Read(b[:])
  76. //common.Must(err)
  77. if err != nil {
  78. return
  79. }
  80. assert(bytes.HasPrefix(b[:n], []byte("Test connection")), IsTrue)
  81. _, err = c.Write([]byte("Response"))
  82. common.Must(err)
  83. }(conn)
  84. })
  85. common.Must(err)
  86. conn, err := Dial(context.Background(), net.TCPDestination(net.DomainAddress("localhost"), 13148), &internet.MemoryStreamConfig{
  87. ProtocolName: "websocket",
  88. ProtocolSettings: &Config{Path: "ws", Header: []*Header{{Key: "X-Forwarded-For", Value: "1.1.1.1"}}},
  89. })
  90. common.Must(err)
  91. _, err = conn.Write([]byte("Test connection 1"))
  92. common.Must(err)
  93. var b [1024]byte
  94. n, err := conn.Read(b[:])
  95. common.Must(err)
  96. assert(string(b[:n]), Equals, "Response")
  97. assert(listen.Close(), IsNil)
  98. }
  99. func Test_listenWSAndDial_TLS(t *testing.T) {
  100. if runtime.GOARCH == "arm64" {
  101. return
  102. }
  103. assert := With(t)
  104. start := time.Now()
  105. streamSettings := &internet.MemoryStreamConfig{
  106. ProtocolName: "websocket",
  107. ProtocolSettings: &Config{
  108. Path: "wss",
  109. },
  110. SecurityType: "tls",
  111. SecuritySettings: &tls.Config{
  112. AllowInsecure: true,
  113. Certificate: []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("localhost")))},
  114. },
  115. }
  116. listen, err := ListenWS(context.Background(), net.LocalHostIP, 13143, streamSettings, func(conn internet.Connection) {
  117. go func() {
  118. _ = conn.Close()
  119. }()
  120. })
  121. common.Must(err)
  122. defer listen.Close()
  123. conn, err := Dial(context.Background(), net.TCPDestination(net.DomainAddress("localhost"), 13143), streamSettings)
  124. common.Must(err)
  125. _ = conn.Close()
  126. end := time.Now()
  127. assert(end.Before(start.Add(time.Second*5)), IsTrue)
  128. }