// connection_test.go (2.0 KB)
  1. package kcp_test
  2. import (
  3. "crypto/rand"
  4. "io"
  5. "net"
  6. "testing"
  7. "time"
  8. v2net "v2ray.com/core/common/net"
  9. "v2ray.com/core/testing/assert"
  10. "v2ray.com/core/transport/internet"
  11. "v2ray.com/core/transport/internet/authenticators/srtp"
  12. . "v2ray.com/core/transport/internet/kcp"
  13. )
  // NoOpWriteCloser is an io.WriteCloser stub that accepts every write and
  // discards the data. The tests in this file pass it to NewConnection as the
  // output of a connection whose outgoing traffic is not inspected.
  type NoOpWriteCloser struct{}

  // Compile-time check that *NoOpWriteCloser satisfies io.WriteCloser.
  var _ io.WriteCloser = (*NoOpWriteCloser)(nil)

  // Write reports that all of b was written without storing or forwarding it.
  // It never returns an error.
  func (*NoOpWriteCloser) Write(b []byte) (int, error) {
  	return len(b), nil
  }

  // Close does nothing and always returns nil.
  func (*NoOpWriteCloser) Close() error {
  	return nil
  }
  21. func TestConnectionReadTimeout(t *testing.T) {
  22. assert := assert.On(t)
  23. conn := NewConnection(1, &NoOpWriteCloser{}, nil, nil, NewSimpleAuthenticator(), &Config{})
  24. conn.SetReadDeadline(time.Now().Add(time.Second))
  25. b := make([]byte, 1024)
  26. nBytes, err := conn.Read(b)
  27. assert.Int(nBytes).Equals(0)
  28. assert.Error(err).IsNotNil()
  29. }
  30. func TestConnectionReadWrite(t *testing.T) {
  31. assert := assert.On(t)
  32. upReader, upWriter := io.Pipe()
  33. downReader, downWriter := io.Pipe()
  34. auth := internet.NewAuthenticatorChain(srtp.SRTPFactory{}.Create(nil), NewSimpleAuthenticator())
  35. connClient := NewConnection(1, upWriter, &net.UDPAddr{IP: v2net.LocalHostIP.IP(), Port: 1}, &net.UDPAddr{IP: v2net.LocalHostIP.IP(), Port: 2}, auth, &Config{})
  36. connClient.FetchInputFrom(downReader)
  37. connServer := NewConnection(1, downWriter, &net.UDPAddr{IP: v2net.LocalHostIP.IP(), Port: 2}, &net.UDPAddr{IP: v2net.LocalHostIP.IP(), Port: 1}, auth, &Config{})
  38. connServer.FetchInputFrom(upReader)
  39. totalWritten := 1024 * 1024
  40. clientSend := make([]byte, totalWritten)
  41. rand.Read(clientSend)
  42. go func() {
  43. nBytes, err := connClient.Write(clientSend)
  44. assert.Int(nBytes).Equals(totalWritten)
  45. assert.Error(err).IsNil()
  46. }()
  47. serverReceived := make([]byte, totalWritten)
  48. totalRead := 0
  49. for totalRead < totalWritten {
  50. nBytes, err := connServer.Read(serverReceived[totalRead:])
  51. assert.Error(err).IsNil()
  52. totalRead += nBytes
  53. }
  54. assert.Bytes(serverReceived).Equals(clientSend)
  55. connClient.Close()
  56. connServer.Close()
  57. for connClient.State() != StateTerminated || connServer.State() != StateTerminated {
  58. time.Sleep(time.Second)
  59. }
  60. }