auth.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. package crypto
  2. import (
  3. "crypto/cipher"
  4. "errors"
  5. "io"
  6. "v2ray.com/core/common"
  7. "v2ray.com/core/common/buf"
  8. "v2ray.com/core/common/serial"
  9. )
  10. var (
  11. ErrAuthenticationFailed = errors.New("Authentication failed.")
  12. errInsufficientBuffer = errors.New("Insufficient buffer.")
  13. errInvalidNonce = errors.New("Invalid nonce.")
  14. )
  15. type BytesGenerator interface {
  16. Next() []byte
  17. }
  18. type NoOpBytesGenerator struct {
  19. buffer [1]byte
  20. }
  21. func (v NoOpBytesGenerator) Next() []byte {
  22. return v.buffer[:0]
  23. }
  24. type StaticBytesGenerator struct {
  25. Content []byte
  26. }
  27. func (v StaticBytesGenerator) Next() []byte {
  28. return v.Content
  29. }
  30. type Authenticator interface {
  31. NonceSize() int
  32. Overhead() int
  33. Open(dst, cipherText []byte) ([]byte, error)
  34. Seal(dst, plainText []byte) ([]byte, error)
  35. }
  36. type AEADAuthenticator struct {
  37. cipher.AEAD
  38. NonceGenerator BytesGenerator
  39. AdditionalDataGenerator BytesGenerator
  40. }
  41. func (v *AEADAuthenticator) Open(dst, cipherText []byte) ([]byte, error) {
  42. iv := v.NonceGenerator.Next()
  43. if len(iv) != v.AEAD.NonceSize() {
  44. return nil, errInvalidNonce
  45. }
  46. additionalData := v.AdditionalDataGenerator.Next()
  47. return v.AEAD.Open(dst, iv, cipherText, additionalData)
  48. }
  49. func (v *AEADAuthenticator) Seal(dst, plainText []byte) ([]byte, error) {
  50. iv := v.NonceGenerator.Next()
  51. if len(iv) != v.AEAD.NonceSize() {
  52. return nil, errInvalidNonce
  53. }
  54. additionalData := v.AdditionalDataGenerator.Next()
  55. return v.AEAD.Seal(dst, iv, plainText, additionalData), nil
  56. }
  57. type AuthenticationReader struct {
  58. auth Authenticator
  59. buffer *buf.Buffer
  60. reader io.Reader
  61. chunk []byte
  62. aggressive bool
  63. }
  64. func NewAuthenticationReader(auth Authenticator, reader io.Reader, aggressive bool) *AuthenticationReader {
  65. return &AuthenticationReader{
  66. auth: auth,
  67. buffer: buf.NewLocal(32 * 1024),
  68. reader: reader,
  69. aggressive: aggressive,
  70. }
  71. }
  72. func (v *AuthenticationReader) NextChunk() error {
  73. if v.buffer.Len() < 2 {
  74. return errInsufficientBuffer
  75. }
  76. size := int(serial.BytesToUint16(v.buffer.BytesTo(2)))
  77. if size > v.buffer.Len()-2 {
  78. return errInsufficientBuffer
  79. }
  80. if size == v.auth.Overhead() {
  81. return io.EOF
  82. }
  83. if size < v.auth.Overhead() {
  84. return errors.New("AuthenticationReader: invalid packet size.")
  85. }
  86. cipherChunk := v.buffer.BytesRange(2, size+2)
  87. plainChunk, err := v.auth.Open(cipherChunk[:0], cipherChunk)
  88. if err != nil {
  89. return err
  90. }
  91. v.chunk = plainChunk
  92. v.buffer.SliceFrom(size + 2)
  93. return nil
  94. }
  95. func (v *AuthenticationReader) CopyChunk(b []byte) int {
  96. if len(v.chunk) == 0 {
  97. return 0
  98. }
  99. nBytes := copy(b, v.chunk)
  100. if nBytes == len(v.chunk) {
  101. v.chunk = nil
  102. } else {
  103. v.chunk = v.chunk[nBytes:]
  104. }
  105. return nBytes
  106. }
  107. func (v *AuthenticationReader) EnsureChunk() error {
  108. for {
  109. err := v.NextChunk()
  110. if err == nil {
  111. return nil
  112. }
  113. if err == errInsufficientBuffer {
  114. if v.buffer.IsEmpty() {
  115. v.buffer.Clear()
  116. } else {
  117. leftover := v.buffer.Bytes()
  118. common.Must(v.buffer.Reset(func(b []byte) (int, error) {
  119. return copy(b, leftover), nil
  120. }))
  121. }
  122. err = v.buffer.AppendSupplier(buf.ReadFrom(v.reader))
  123. if err == nil {
  124. continue
  125. }
  126. }
  127. return err
  128. }
  129. }
  130. func (v *AuthenticationReader) Read(b []byte) (int, error) {
  131. if len(v.chunk) > 0 {
  132. nBytes := v.CopyChunk(b)
  133. return nBytes, nil
  134. }
  135. err := v.EnsureChunk()
  136. if err != nil {
  137. return 0, err
  138. }
  139. totalBytes := v.CopyChunk(b)
  140. for v.aggressive && totalBytes < len(b) {
  141. if err := v.NextChunk(); err != nil {
  142. break
  143. }
  144. totalBytes += v.CopyChunk(b[totalBytes:])
  145. }
  146. return totalBytes, nil
  147. }
  148. type AuthenticationWriter struct {
  149. auth Authenticator
  150. buffer []byte
  151. writer io.Writer
  152. }
  153. func NewAuthenticationWriter(auth Authenticator, writer io.Writer) *AuthenticationWriter {
  154. return &AuthenticationWriter{
  155. auth: auth,
  156. buffer: make([]byte, 32*1024),
  157. writer: writer,
  158. }
  159. }
  160. func (v *AuthenticationWriter) Write(b []byte) (int, error) {
  161. cipherChunk, err := v.auth.Seal(v.buffer[2:2], b)
  162. if err != nil {
  163. return 0, err
  164. }
  165. serial.Uint16ToBytes(uint16(len(cipherChunk)), v.buffer[:0])
  166. _, err = v.writer.Write(v.buffer[:2+len(cipherChunk)])
  167. return len(b), err
  168. }