auth.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. package crypto
  2. import (
  3. "crypto/cipher"
  4. "errors"
  5. "io"
  6. "v2ray.com/core/common/buf"
  7. "v2ray.com/core/common/serial"
  8. )
  // ErrAuthenticationFailed is the exported sentinel for a failed
  // authentication. Nothing in the visible part of this file returns it.
  var ErrAuthenticationFailed = errors.New("Authentication failed.")

  // Package-private sentinels.
  var (
  	// errInsufficientBuffer is returned by AuthenticationReader.NextChunk when
  	// the buffered bytes do not yet contain a complete frame; EnsureChunk
  	// reacts to it by reading more data.
  	errInsufficientBuffer = errors.New("Insufficient buffer.")

  	// errInvalidNonce is returned by AEADAuthenticator when the generated
  	// nonce does not have the length required by the underlying AEAD.
  	errInvalidNonce = errors.New("Invalid nonce.")
  )
  // BytesGenerator yields a byte slice each time Next is called.
  //
  // AEADAuthenticator uses one generator for nonces and another for additional
  // authenticated data, calling Next once per Open or Seal.
  type BytesGenerator interface {
  	// Next returns the bytes to use for the current operation.
  	Next() []byte
  }
  // NoOpBytesGenerator is a generator whose Next never yields any bytes.
  type NoOpBytesGenerator struct {
  	// buffer only provides backing storage for the empty slice handed out by Next.
  	buffer [1]byte
  }

  // Next returns a zero-length, non-nil slice.
  func (g NoOpBytesGenerator) Next() []byte {
  	empty := g.buffer[:0]
  	return empty
  }
  // StaticBytesGenerator hands out the same bytes on every call.
  type StaticBytesGenerator struct {
  	// Content is returned as-is (not copied) by Next.
  	Content []byte
  }

  // Next returns Content itself, so callers share its backing array.
  func (g StaticBytesGenerator) Next() []byte {
  	fixed := g.Content
  	return fixed
  }
  // Authenticator seals and opens individual chunks of an authenticated stream.
  type Authenticator interface {
  	// NonceSize reports the nonce length the authenticator works with.
  	NonceSize() int

  	// Overhead reports how many bytes sealing adds to a chunk.
  	// AuthenticationReader treats a frame whose size equals Overhead as the
  	// end of the stream and a smaller one as invalid.
  	Overhead() int

  	// Open authenticates and decrypts cipherText. AEADAuthenticator forwards
  	// dst to cipher.AEAD.Open, so the result is presumably appended to dst.
  	Open(dst, cipherText []byte) ([]byte, error)

  	// Seal encrypts and authenticates plainText. AEADAuthenticator forwards
  	// dst to cipher.AEAD.Seal, so the result is presumably appended to dst.
  	Seal(dst, plainText []byte) ([]byte, error)
  }
  35. type AEADAuthenticator struct {
  36. cipher.AEAD
  37. NonceGenerator BytesGenerator
  38. AdditionalDataGenerator BytesGenerator
  39. }
  40. func (v *AEADAuthenticator) Open(dst, cipherText []byte) ([]byte, error) {
  41. iv := v.NonceGenerator.Next()
  42. if len(iv) != v.AEAD.NonceSize() {
  43. return nil, errInvalidNonce
  44. }
  45. additionalData := v.AdditionalDataGenerator.Next()
  46. return v.AEAD.Open(dst, iv, cipherText, additionalData)
  47. }
  48. func (v *AEADAuthenticator) Seal(dst, plainText []byte) ([]byte, error) {
  49. iv := v.NonceGenerator.Next()
  50. if len(iv) != v.AEAD.NonceSize() {
  51. return nil, errInvalidNonce
  52. }
  53. additionalData := v.AdditionalDataGenerator.Next()
  54. return v.AEAD.Seal(dst, iv, plainText, additionalData), nil
  55. }
  56. type AuthenticationReader struct {
  57. auth Authenticator
  58. buffer *buf.Buffer
  59. reader io.Reader
  60. chunk []byte
  61. aggressive bool
  62. }
  63. func NewAuthenticationReader(auth Authenticator, reader io.Reader, aggressive bool) *AuthenticationReader {
  64. return &AuthenticationReader{
  65. auth: auth,
  66. buffer: buf.NewLocal(32 * 1024),
  67. reader: reader,
  68. aggressive: aggressive,
  69. }
  70. }
  71. func (v *AuthenticationReader) NextChunk() error {
  72. if v.buffer.Len() < 2 {
  73. return errInsufficientBuffer
  74. }
  75. size := int(serial.BytesToUint16(v.buffer.BytesTo(2)))
  76. if size > v.buffer.Len()-2 {
  77. return errInsufficientBuffer
  78. }
  79. if size == v.auth.Overhead() {
  80. return io.EOF
  81. }
  82. if size < v.auth.Overhead() {
  83. return errors.New("AuthenticationReader: invalid packet size.")
  84. }
  85. cipherChunk := v.buffer.BytesRange(2, size+2)
  86. plainChunk, err := v.auth.Open(cipherChunk[:0], cipherChunk)
  87. if err != nil {
  88. return err
  89. }
  90. v.chunk = plainChunk
  91. v.buffer.SliceFrom(size + 2)
  92. return nil
  93. }
  94. func (v *AuthenticationReader) CopyChunk(b []byte) int {
  95. if len(v.chunk) == 0 {
  96. return 0
  97. }
  98. nBytes := copy(b, v.chunk)
  99. if nBytes == len(v.chunk) {
  100. v.chunk = nil
  101. } else {
  102. v.chunk = v.chunk[nBytes:]
  103. }
  104. return nBytes
  105. }
  106. func (v *AuthenticationReader) EnsureChunk() error {
  107. for {
  108. err := v.NextChunk()
  109. if err == nil {
  110. return nil
  111. }
  112. if err == errInsufficientBuffer {
  113. if v.buffer.IsEmpty() {
  114. v.buffer.Clear()
  115. } else {
  116. leftover := v.buffer.Bytes()
  117. v.buffer.Reset(func(b []byte) (int, error) {
  118. return copy(b, leftover), nil
  119. })
  120. }
  121. err = v.buffer.AppendSupplier(buf.ReadFrom(v.reader))
  122. if err == nil {
  123. continue
  124. }
  125. }
  126. return err
  127. }
  128. }
  129. func (v *AuthenticationReader) Read(b []byte) (int, error) {
  130. if len(v.chunk) > 0 {
  131. nBytes := v.CopyChunk(b)
  132. return nBytes, nil
  133. }
  134. err := v.EnsureChunk()
  135. if err != nil {
  136. return 0, err
  137. }
  138. totalBytes := v.CopyChunk(b)
  139. for v.aggressive && totalBytes < len(b) {
  140. if err := v.NextChunk(); err != nil {
  141. break
  142. }
  143. totalBytes += v.CopyChunk(b[totalBytes:])
  144. }
  145. return totalBytes, nil
  146. }
  147. type AuthenticationWriter struct {
  148. auth Authenticator
  149. buffer []byte
  150. writer io.Writer
  151. ivGen BytesGenerator
  152. extraGen BytesGenerator
  153. }
  154. func NewAuthenticationWriter(auth Authenticator, writer io.Writer) *AuthenticationWriter {
  155. return &AuthenticationWriter{
  156. auth: auth,
  157. buffer: make([]byte, 32*1024),
  158. writer: writer,
  159. }
  160. }
  161. func (v *AuthenticationWriter) Write(b []byte) (int, error) {
  162. cipherChunk, err := v.auth.Seal(v.buffer[2:2], b)
  163. if err != nil {
  164. return 0, err
  165. }
  166. serial.Uint16ToBytes(uint16(len(cipherChunk)), v.buffer[:0])
  167. _, err = v.writer.Write(v.buffer[:2+len(cipherChunk)])
  168. return len(b), err
  169. }