connection_test.go 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. package kcp_test
  2. import (
  3. "crypto/rand"
  4. "io"
  5. "net"
  6. "testing"
  7. "time"
  8. v2net "v2ray.com/core/common/net"
  9. "v2ray.com/core/testing/assert"
  10. "v2ray.com/core/transport/internet"
  11. "v2ray.com/core/transport/internet/authenticators/srtp"
  12. . "v2ray.com/core/transport/internet/kcp"
  13. )
  // NoOpWriteCloser is an io.WriteCloser that discards everything written to it.
// Close is a no-op. It is useful where a test needs a writer but does not care
// what is written.
type NoOpWriteCloser struct{}

// Compile-time check that NoOpWriteCloser satisfies io.WriteCloser.
var _ io.WriteCloser = (*NoOpWriteCloser)(nil)

// Write discards b and reports all len(b) bytes as written, with a nil error.
func (*NoOpWriteCloser) Write(b []byte) (int, error) {
	return len(b), nil
}

// Close does nothing and always returns nil.
func (*NoOpWriteCloser) Close() error {
	return nil
}
  21. func TestConnectionReadTimeout(t *testing.T) {
  22. assert := assert.On(t)
  23. conn := NewConnection(1, &NoOpWriteCloser{}, nil, nil, NewSimpleAuthenticator(), &Config{})
  24. conn.SetReadDeadline(time.Now().Add(time.Second))
  25. b := make([]byte, 1024)
  26. nBytes, err := conn.Read(b)
  27. assert.Int(nBytes).Equals(0)
  28. assert.Error(err).IsNotNil()
  29. conn.Terminate()
  30. }
  31. func TestConnectionReadWrite(t *testing.T) {
  32. assert := assert.On(t)
  33. upReader, upWriter := io.Pipe()
  34. downReader, downWriter := io.Pipe()
  35. auth := internet.NewAuthenticatorChain(srtp.SRTPFactory{}.Create(nil), NewSimpleAuthenticator())
  36. connClient := NewConnection(1, upWriter, &net.UDPAddr{IP: v2net.LocalHostIP.IP(), Port: 1}, &net.UDPAddr{IP: v2net.LocalHostIP.IP(), Port: 2}, auth, &Config{})
  37. connClient.FetchInputFrom(downReader)
  38. connServer := NewConnection(1, downWriter, &net.UDPAddr{IP: v2net.LocalHostIP.IP(), Port: 2}, &net.UDPAddr{IP: v2net.LocalHostIP.IP(), Port: 1}, auth, &Config{})
  39. connServer.FetchInputFrom(upReader)
  40. totalWritten := 1024 * 1024
  41. clientSend := make([]byte, totalWritten)
  42. rand.Read(clientSend)
  43. go func() {
  44. nBytes, err := connClient.Write(clientSend)
  45. assert.Int(nBytes).Equals(totalWritten)
  46. assert.Error(err).IsNil()
  47. }()
  48. serverReceived := make([]byte, totalWritten)
  49. totalRead := 0
  50. for totalRead < totalWritten {
  51. nBytes, err := connServer.Read(serverReceived[totalRead:])
  52. assert.Error(err).IsNil()
  53. totalRead += nBytes
  54. }
  55. assert.Bytes(serverReceived).Equals(clientSend)
  56. connClient.Close()
  57. connServer.Close()
  58. for connClient.State() != StateTerminated || connServer.State() != StateTerminated {
  59. time.Sleep(time.Second)
  60. }
  61. }