auth.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. package crypto
  2. import (
  3. "crypto/cipher"
  4. "io"
  5. "v2ray.com/core/common"
  6. "v2ray.com/core/common/buf"
  7. "v2ray.com/core/common/errors"
  8. "v2ray.com/core/common/serial"
  9. )
  10. var (
  11. errInsufficientBuffer = errors.New("Insufficient buffer.")
  12. )
  13. type BytesGenerator interface {
  14. Next() []byte
  15. }
  16. type NoOpBytesGenerator struct {
  17. buffer [1]byte
  18. }
  19. func (v NoOpBytesGenerator) Next() []byte {
  20. return v.buffer[:0]
  21. }
  22. type StaticBytesGenerator struct {
  23. Content []byte
  24. }
  25. func (v StaticBytesGenerator) Next() []byte {
  26. return v.Content
  27. }
  28. type Authenticator interface {
  29. NonceSize() int
  30. Overhead() int
  31. Open(dst, cipherText []byte) ([]byte, error)
  32. Seal(dst, plainText []byte) ([]byte, error)
  33. }
  34. type AEADAuthenticator struct {
  35. cipher.AEAD
  36. NonceGenerator BytesGenerator
  37. AdditionalDataGenerator BytesGenerator
  38. }
  39. func (v *AEADAuthenticator) Open(dst, cipherText []byte) ([]byte, error) {
  40. iv := v.NonceGenerator.Next()
  41. if len(iv) != v.AEAD.NonceSize() {
  42. return nil, errors.New("Crypto:AEADAuthenticator: Invalid nonce size: ", len(iv))
  43. }
  44. additionalData := v.AdditionalDataGenerator.Next()
  45. return v.AEAD.Open(dst, iv, cipherText, additionalData)
  46. }
  47. func (v *AEADAuthenticator) Seal(dst, plainText []byte) ([]byte, error) {
  48. iv := v.NonceGenerator.Next()
  49. if len(iv) != v.AEAD.NonceSize() {
  50. return nil, errors.New("Crypto:AEADAuthenticator: Invalid nonce size: ", len(iv))
  51. }
  52. additionalData := v.AdditionalDataGenerator.Next()
  53. return v.AEAD.Seal(dst, iv, plainText, additionalData), nil
  54. }
  55. type AuthenticationReader struct {
  56. auth Authenticator
  57. buffer *buf.Buffer
  58. reader io.Reader
  59. sizeMask uint16
  60. chunk []byte
  61. }
  62. const (
  63. readerBufferSize = 32 * 1024
  64. )
  65. func NewAuthenticationReader(auth Authenticator, reader io.Reader, sizeMask uint16) *AuthenticationReader {
  66. return &AuthenticationReader{
  67. auth: auth,
  68. buffer: buf.NewLocal(readerBufferSize),
  69. reader: reader,
  70. sizeMask: sizeMask,
  71. }
  72. }
  73. func (v *AuthenticationReader) NextChunk() error {
  74. if v.buffer.Len() < 2 {
  75. return errInsufficientBuffer
  76. }
  77. size := int(serial.BytesToUint16(v.buffer.BytesTo(2)) ^ v.sizeMask)
  78. if size > v.buffer.Len()-2 {
  79. return errInsufficientBuffer
  80. }
  81. if size > readerBufferSize-2 {
  82. return errors.New("Crypto:AuthenticationReader: Size too large: ", size)
  83. }
  84. if size == v.auth.Overhead() {
  85. return io.EOF
  86. }
  87. if size < v.auth.Overhead() {
  88. return errors.New("AuthenticationReader: invalid packet size:", size)
  89. }
  90. cipherChunk := v.buffer.BytesRange(2, size+2)
  91. plainChunk, err := v.auth.Open(cipherChunk[:0], cipherChunk)
  92. if err != nil {
  93. return err
  94. }
  95. v.chunk = plainChunk
  96. v.buffer.SliceFrom(size + 2)
  97. return nil
  98. }
  99. func (v *AuthenticationReader) CopyChunk(b []byte) int {
  100. if len(v.chunk) == 0 {
  101. return 0
  102. }
  103. nBytes := copy(b, v.chunk)
  104. if nBytes == len(v.chunk) {
  105. v.chunk = nil
  106. } else {
  107. v.chunk = v.chunk[nBytes:]
  108. }
  109. return nBytes
  110. }
  111. func (v *AuthenticationReader) EnsureChunk() error {
  112. atHead := false
  113. if v.buffer.IsEmpty() {
  114. v.buffer.Clear()
  115. atHead = true
  116. }
  117. for {
  118. err := v.NextChunk()
  119. if err != errInsufficientBuffer {
  120. return err
  121. }
  122. leftover := v.buffer.Bytes()
  123. if !atHead && len(leftover) > 0 {
  124. common.Must(v.buffer.Reset(func(b []byte) (int, error) {
  125. return copy(b, leftover), nil
  126. }))
  127. }
  128. if err := v.buffer.AppendSupplier(buf.ReadFrom(v.reader)); err != nil {
  129. return err
  130. }
  131. }
  132. }
  133. func (v *AuthenticationReader) Read(b []byte) (int, error) {
  134. if len(v.chunk) > 0 {
  135. nBytes := v.CopyChunk(b)
  136. return nBytes, nil
  137. }
  138. err := v.EnsureChunk()
  139. if err != nil {
  140. return 0, err
  141. }
  142. return v.CopyChunk(b), nil
  143. }
  144. type AuthenticationWriter struct {
  145. auth Authenticator
  146. buffer []byte
  147. writer io.Writer
  148. sizeMask uint16
  149. }
  150. func NewAuthenticationWriter(auth Authenticator, writer io.Writer, sizeMask uint16) *AuthenticationWriter {
  151. return &AuthenticationWriter{
  152. auth: auth,
  153. buffer: make([]byte, 32*1024),
  154. writer: writer,
  155. sizeMask: sizeMask,
  156. }
  157. }
  158. func (v *AuthenticationWriter) Write(b []byte) (int, error) {
  159. cipherChunk, err := v.auth.Seal(v.buffer[2:2], b)
  160. if err != nil {
  161. return 0, err
  162. }
  163. size := uint16(len(cipherChunk)) ^ v.sizeMask
  164. serial.Uint16ToBytes(size, v.buffer[:0])
  165. _, err = v.writer.Write(v.buffer[:2+len(cipherChunk)])
  166. return len(b), err
  167. }