auth.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. package crypto
  2. import (
  3. "crypto/cipher"
  4. "errors"
  5. "io"
  6. "v2ray.com/core/common"
  7. "v2ray.com/core/common/buf"
  8. "v2ray.com/core/common/serial"
  9. )
  // Sentinel errors used by the authenticated chunk stream. Messages follow Go
// convention (lowercase, no trailing punctuation) so they read well when wrapped.
var (
	// ErrAuthenticationFailed is the exported sentinel for authentication
	// failures. It is not referenced elsewhere in this file.
	ErrAuthenticationFailed = errors.New("authentication failed")
	// errInsufficientBuffer reports that the read buffer does not yet hold a
	// complete chunk; EnsureChunk reacts to it by reading more input.
	errInsufficientBuffer = errors.New("insufficient buffer")
	// errInvalidNonce reports that a nonce generator produced a nonce whose
	// length differs from the AEAD's NonceSize.
	errInvalidNonce = errors.New("invalid nonce")
	// errInvalidLength reports a chunk size that exceeds what the stream
	// buffers can hold.
	errInvalidLength = errors.New("invalid buffer size")
)
  // BytesGenerator yields a byte slice on each call to Next. AEADAuthenticator
// draws its nonces and its additional data from generators of this type.
type BytesGenerator interface {
	Next() []byte
}

// NoOpBytesGenerator always yields an empty, non-nil slice. It fits as an
// AdditionalDataGenerator when no additional data should be authenticated.
type NoOpBytesGenerator struct{}

// Next returns an empty, non-nil slice.
func (NoOpBytesGenerator) Next() []byte {
	// A zero-length literal needs no allocation. Slicing an array stored in the
	// (value) receiver would force the receiver copy onto the heap on every
	// call, and Next runs once per sealed or opened chunk.
	return []byte{}
}

// StaticBytesGenerator yields the same Content slice on every call to Next.
type StaticBytesGenerator struct {
	// Content is returned as-is by Next.
	Content []byte
}

// Next returns v.Content itself, not a copy; modifying the result modifies
// Content.
func (v StaticBytesGenerator) Next() []byte {
	return v.Content
}
  31. type Authenticator interface {
  32. NonceSize() int
  33. Overhead() int
  34. Open(dst, cipherText []byte) ([]byte, error)
  35. Seal(dst, plainText []byte) ([]byte, error)
  36. }
  37. type AEADAuthenticator struct {
  38. cipher.AEAD
  39. NonceGenerator BytesGenerator
  40. AdditionalDataGenerator BytesGenerator
  41. }
  42. func (v *AEADAuthenticator) Open(dst, cipherText []byte) ([]byte, error) {
  43. iv := v.NonceGenerator.Next()
  44. if len(iv) != v.AEAD.NonceSize() {
  45. return nil, errInvalidNonce
  46. }
  47. additionalData := v.AdditionalDataGenerator.Next()
  48. return v.AEAD.Open(dst, iv, cipherText, additionalData)
  49. }
  50. func (v *AEADAuthenticator) Seal(dst, plainText []byte) ([]byte, error) {
  51. iv := v.NonceGenerator.Next()
  52. if len(iv) != v.AEAD.NonceSize() {
  53. return nil, errInvalidNonce
  54. }
  55. additionalData := v.AdditionalDataGenerator.Next()
  56. return v.AEAD.Seal(dst, iv, plainText, additionalData), nil
  57. }
  58. type AuthenticationReader struct {
  59. auth Authenticator
  60. buffer *buf.Buffer
  61. reader io.Reader
  62. chunk []byte
  63. aggressive bool
  64. }
  65. const (
  66. readerBufferSize = 32 * 1024
  67. )
  68. func NewAuthenticationReader(auth Authenticator, reader io.Reader, aggressive bool) *AuthenticationReader {
  69. return &AuthenticationReader{
  70. auth: auth,
  71. buffer: buf.NewLocal(readerBufferSize),
  72. reader: reader,
  73. aggressive: aggressive,
  74. }
  75. }
  76. func (v *AuthenticationReader) NextChunk() error {
  77. if v.buffer.Len() < 2 {
  78. return errInsufficientBuffer
  79. }
  80. size := int(serial.BytesToUint16(v.buffer.BytesTo(2)))
  81. if size > v.buffer.Len()-2 {
  82. return errInsufficientBuffer
  83. }
  84. if size > readerBufferSize-2 {
  85. return errInvalidLength
  86. }
  87. if size == v.auth.Overhead() {
  88. return io.EOF
  89. }
  90. if size < v.auth.Overhead() {
  91. return errors.New("AuthenticationReader: invalid packet size.")
  92. }
  93. cipherChunk := v.buffer.BytesRange(2, size+2)
  94. plainChunk, err := v.auth.Open(cipherChunk[:0], cipherChunk)
  95. if err != nil {
  96. return err
  97. }
  98. v.chunk = plainChunk
  99. v.buffer.SliceFrom(size + 2)
  100. return nil
  101. }
  102. func (v *AuthenticationReader) CopyChunk(b []byte) int {
  103. if len(v.chunk) == 0 {
  104. return 0
  105. }
  106. nBytes := copy(b, v.chunk)
  107. if nBytes == len(v.chunk) {
  108. v.chunk = nil
  109. } else {
  110. v.chunk = v.chunk[nBytes:]
  111. }
  112. return nBytes
  113. }
  114. func (v *AuthenticationReader) EnsureChunk() error {
  115. for {
  116. err := v.NextChunk()
  117. if err == nil {
  118. return nil
  119. }
  120. if err == errInsufficientBuffer {
  121. if v.buffer.IsEmpty() {
  122. v.buffer.Clear()
  123. } else {
  124. leftover := v.buffer.Bytes()
  125. common.Must(v.buffer.Reset(func(b []byte) (int, error) {
  126. return copy(b, leftover), nil
  127. }))
  128. }
  129. err = v.buffer.AppendSupplier(buf.ReadFrom(v.reader))
  130. if err == nil {
  131. continue
  132. }
  133. }
  134. return err
  135. }
  136. }
  137. func (v *AuthenticationReader) Read(b []byte) (int, error) {
  138. if len(v.chunk) > 0 {
  139. nBytes := v.CopyChunk(b)
  140. return nBytes, nil
  141. }
  142. err := v.EnsureChunk()
  143. if err != nil {
  144. return 0, err
  145. }
  146. totalBytes := v.CopyChunk(b)
  147. for v.aggressive && totalBytes < len(b) {
  148. if err := v.NextChunk(); err != nil {
  149. break
  150. }
  151. totalBytes += v.CopyChunk(b[totalBytes:])
  152. }
  153. return totalBytes, nil
  154. }
  155. type AuthenticationWriter struct {
  156. auth Authenticator
  157. buffer []byte
  158. writer io.Writer
  159. }
  160. func NewAuthenticationWriter(auth Authenticator, writer io.Writer) *AuthenticationWriter {
  161. return &AuthenticationWriter{
  162. auth: auth,
  163. buffer: make([]byte, 32*1024),
  164. writer: writer,
  165. }
  166. }
  167. func (v *AuthenticationWriter) Write(b []byte) (int, error) {
  168. cipherChunk, err := v.auth.Seal(v.buffer[2:2], b)
  169. if err != nil {
  170. return 0, err
  171. }
  172. serial.Uint16ToBytes(uint16(len(cipherChunk)), v.buffer[:0])
  173. _, err = v.writer.Write(v.buffer[:2+len(cipherChunk)])
  174. return len(b), err
  175. }