auth_test.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. package crypto_test
  2. import (
  3. "bytes"
  4. "crypto/aes"
  5. "crypto/cipher"
  6. "crypto/rand"
  7. "io"
  8. "testing"
  9. "github.com/google/go-cmp/cmp"
  10. "v2ray.com/core/common"
  11. "v2ray.com/core/common/buf"
  12. . "v2ray.com/core/common/crypto"
  13. "v2ray.com/core/common/protocol"
  14. )
  15. func TestAuthenticationReaderWriter(t *testing.T) {
  16. key := make([]byte, 16)
  17. rand.Read(key)
  18. block, err := aes.NewCipher(key)
  19. common.Must(err)
  20. aead, err := cipher.NewGCM(block)
  21. common.Must(err)
  22. const payloadSize = 1024 * 80
  23. rawPayload := make([]byte, payloadSize)
  24. rand.Read(rawPayload)
  25. payload := buf.MergeBytes(nil, rawPayload)
  26. if r := cmp.Diff(payload.Bytes(), rawPayload); r != "" {
  27. t.Error(r)
  28. }
  29. cache := bytes.NewBuffer(nil)
  30. iv := make([]byte, 12)
  31. rand.Read(iv)
  32. writer := NewAuthenticationWriter(&AEADAuthenticator{
  33. AEAD: aead,
  34. NonceGenerator: GenerateStaticBytes(iv),
  35. AdditionalDataGenerator: GenerateEmptyBytes(),
  36. }, PlainChunkSizeParser{}, cache, protocol.TransferTypeStream, nil)
  37. common.Must(writer.WriteMultiBuffer(payload))
  38. if cache.Len() <= 1024*80 {
  39. t.Error("cache len: ", cache.Len())
  40. }
  41. common.Must(writer.WriteMultiBuffer(buf.MultiBuffer{}))
  42. reader := NewAuthenticationReader(&AEADAuthenticator{
  43. AEAD: aead,
  44. NonceGenerator: GenerateStaticBytes(iv),
  45. AdditionalDataGenerator: GenerateEmptyBytes(),
  46. }, PlainChunkSizeParser{}, cache, protocol.TransferTypeStream, nil)
  47. var mb buf.MultiBuffer
  48. for mb.Len() < payloadSize {
  49. mb2, err := reader.ReadMultiBuffer()
  50. common.Must(err)
  51. mb, _ = buf.MergeMulti(mb, mb2)
  52. }
  53. if mb.Len() != payloadSize {
  54. t.Error("mb len: ", mb.Len())
  55. }
  56. mbContent := make([]byte, payloadSize)
  57. buf.SplitBytes(mb, mbContent)
  58. if r := cmp.Diff(mbContent, rawPayload); r != "" {
  59. t.Error(r)
  60. }
  61. _, err = reader.ReadMultiBuffer()
  62. if err != io.EOF {
  63. t.Error("error: ", err)
  64. }
  65. }
  66. func TestAuthenticationReaderWriterPacket(t *testing.T) {
  67. key := make([]byte, 16)
  68. common.Must2(rand.Read(key))
  69. block, err := aes.NewCipher(key)
  70. common.Must(err)
  71. aead, err := cipher.NewGCM(block)
  72. common.Must(err)
  73. cache := buf.New()
  74. iv := make([]byte, 12)
  75. rand.Read(iv)
  76. writer := NewAuthenticationWriter(&AEADAuthenticator{
  77. AEAD: aead,
  78. NonceGenerator: GenerateStaticBytes(iv),
  79. AdditionalDataGenerator: GenerateEmptyBytes(),
  80. }, PlainChunkSizeParser{}, cache, protocol.TransferTypePacket, nil)
  81. var payload buf.MultiBuffer
  82. pb1 := buf.New()
  83. pb1.Write([]byte("abcd"))
  84. payload = append(payload, pb1)
  85. pb2 := buf.New()
  86. pb2.Write([]byte("efgh"))
  87. payload = append(payload, pb2)
  88. common.Must(writer.WriteMultiBuffer(payload))
  89. if cache.Len() == 0 {
  90. t.Error("cache len: ", cache.Len())
  91. }
  92. common.Must(writer.WriteMultiBuffer(buf.MultiBuffer{}))
  93. reader := NewAuthenticationReader(&AEADAuthenticator{
  94. AEAD: aead,
  95. NonceGenerator: GenerateStaticBytes(iv),
  96. AdditionalDataGenerator: GenerateEmptyBytes(),
  97. }, PlainChunkSizeParser{}, cache, protocol.TransferTypePacket, nil)
  98. mb, err := reader.ReadMultiBuffer()
  99. common.Must(err)
  100. mb, b1 := buf.SplitFirst(mb)
  101. if b1.String() != "abcd" {
  102. t.Error("b1: ", b1.String())
  103. }
  104. mb, b2 := buf.SplitFirst(mb)
  105. if b2.String() != "efgh" {
  106. t.Error("b2: ", b2.String())
  107. }
  108. if !mb.IsEmpty() {
  109. t.Error("not empty")
  110. }
  111. _, err = reader.ReadMultiBuffer()
  112. if err != io.EOF {
  113. t.Error("error: ", err)
  114. }
  115. }