auth.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. package crypto
  2. import (
  3. "crypto/cipher"
  4. "io"
  5. "golang.org/x/crypto/sha3"
  6. "v2ray.com/core/common"
  7. "v2ray.com/core/common/buf"
  8. "v2ray.com/core/common/serial"
  9. )
  10. var (
  11. errInsufficientBuffer = newError("insufficient buffer")
  12. )
  type (
	// BytesGenerator yields a byte slice on every call to Next.
	// AEADAuthenticator uses one for nonces and one for additional data.
	BytesGenerator interface {
		Next() []byte
	}

	// NoOpBytesGenerator always yields an empty (zero-length, non-nil) slice.
	NoOpBytesGenerator struct {
		buffer [1]byte
	}

	// StaticBytesGenerator yields its Content on every call.
	StaticBytesGenerator struct {
		Content []byte
	}
)

// Next returns a zero-length slice of the receiver's backing array.
func (g NoOpBytesGenerator) Next() []byte {
	empty := g.buffer[:0]
	return empty
}

// Next returns Content itself rather than a copy, so callers share its storage.
func (g StaticBytesGenerator) Next() []byte {
	content := g.Content
	return content
}
  // Authenticator seals and opens the individual chunks exchanged by
// AuthenticationWriter and AuthenticationReader.
//
// Callers in this file rely on Seal/Open appending their output to dst (so
// passing a dst with enough capacity lets them work in place). They also
// rely on Overhead being the exact number of bytes Seal adds to a payload.
type Authenticator interface {
	// Seal protects plainText and appends the result to dst.
	Seal(dst, plainText []byte) ([]byte, error)
	// Open verifies cipherText and appends the recovered payload to dst.
	Open(dst, cipherText []byte) ([]byte, error)
	// Overhead reports how many bytes Seal adds on top of the payload.
	Overhead() int
	// NonceSize reports the nonce length the authenticator expects.
	NonceSize() int
}
  34. type AEADAuthenticator struct {
  35. cipher.AEAD
  36. NonceGenerator BytesGenerator
  37. AdditionalDataGenerator BytesGenerator
  38. }
  39. func (v *AEADAuthenticator) Open(dst, cipherText []byte) ([]byte, error) {
  40. iv := v.NonceGenerator.Next()
  41. if len(iv) != v.AEAD.NonceSize() {
  42. return nil, newError("invalid AEAD nonce size: ", len(iv))
  43. }
  44. additionalData := v.AdditionalDataGenerator.Next()
  45. return v.AEAD.Open(dst, iv, cipherText, additionalData)
  46. }
  47. func (v *AEADAuthenticator) Seal(dst, plainText []byte) ([]byte, error) {
  48. iv := v.NonceGenerator.Next()
  49. if len(iv) != v.AEAD.NonceSize() {
  50. return nil, newError("invalid AEAD nonce size: ", len(iv))
  51. }
  52. additionalData := v.AdditionalDataGenerator.Next()
  53. return v.AEAD.Seal(dst, iv, plainText, additionalData), nil
  54. }
  // Uint16Generator yields a sequence of uint16 values. AuthenticationReader
// and AuthenticationWriter draw one value per chunk and XOR it with the
// 2-byte chunk size.
type Uint16Generator interface {
	Next() uint16
}

// StaticUint16Generator returns its own value from every Next call.
type StaticUint16Generator uint16

// Next returns the generator's value unchanged.
func (g StaticUint16Generator) Next() uint16 {
	value := uint16(g)
	return value
}
  62. type ShakeUint16Generator struct {
  63. shake sha3.ShakeHash
  64. buffer [2]byte
  65. }
  66. func NewShakeUint16Generator(nonce []byte) *ShakeUint16Generator {
  67. shake := sha3.NewShake128()
  68. shake.Write(nonce)
  69. return &ShakeUint16Generator{
  70. shake: shake,
  71. }
  72. }
  73. func (g *ShakeUint16Generator) Next() uint16 {
  74. g.shake.Read(g.buffer[:])
  75. return serial.BytesToUint16(g.buffer[:])
  76. }
  77. type AuthenticationReader struct {
  78. auth Authenticator
  79. buffer *buf.Buffer
  80. reader io.Reader
  81. sizeMask Uint16Generator
  82. chunk []byte
  83. }
  84. const (
  85. readerBufferSize = 32 * 1024
  86. )
  87. func NewAuthenticationReader(auth Authenticator, reader io.Reader, sizeMask Uint16Generator) *AuthenticationReader {
  88. return &AuthenticationReader{
  89. auth: auth,
  90. buffer: buf.NewLocal(readerBufferSize),
  91. reader: reader,
  92. sizeMask: sizeMask,
  93. }
  94. }
  95. func (v *AuthenticationReader) nextChunk(mask uint16) error {
  96. if v.buffer.Len() < 2 {
  97. return errInsufficientBuffer
  98. }
  99. size := int(serial.BytesToUint16(v.buffer.BytesTo(2)) ^ mask)
  100. if size > v.buffer.Len()-2 {
  101. return errInsufficientBuffer
  102. }
  103. if size > readerBufferSize-2 {
  104. return newError("size too large: ", size)
  105. }
  106. if size == v.auth.Overhead() {
  107. return io.EOF
  108. }
  109. if size < v.auth.Overhead() {
  110. return newError("invalid packet size: ", size)
  111. }
  112. cipherChunk := v.buffer.BytesRange(2, size+2)
  113. plainChunk, err := v.auth.Open(cipherChunk[:0], cipherChunk)
  114. if err != nil {
  115. return err
  116. }
  117. v.chunk = plainChunk
  118. v.buffer.SliceFrom(size + 2)
  119. return nil
  120. }
  121. func (v *AuthenticationReader) copyChunk(b []byte) int {
  122. if len(v.chunk) == 0 {
  123. return 0
  124. }
  125. nBytes := copy(b, v.chunk)
  126. if nBytes == len(v.chunk) {
  127. v.chunk = nil
  128. } else {
  129. v.chunk = v.chunk[nBytes:]
  130. }
  131. return nBytes
  132. }
  133. func (v *AuthenticationReader) ensureChunk() error {
  134. atHead := false
  135. if v.buffer.IsEmpty() {
  136. v.buffer.Clear()
  137. atHead = true
  138. }
  139. mask := v.sizeMask.Next()
  140. for {
  141. err := v.nextChunk(mask)
  142. if err != errInsufficientBuffer {
  143. return err
  144. }
  145. leftover := v.buffer.Bytes()
  146. if !atHead && len(leftover) > 0 {
  147. common.Must(v.buffer.Reset(func(b []byte) (int, error) {
  148. return copy(b, leftover), nil
  149. }))
  150. }
  151. if err := v.buffer.AppendSupplier(buf.ReadFrom(v.reader)); err != nil {
  152. return err
  153. }
  154. }
  155. }
  156. func (v *AuthenticationReader) Read(b []byte) (int, error) {
  157. if len(v.chunk) > 0 {
  158. nBytes := v.copyChunk(b)
  159. return nBytes, nil
  160. }
  161. err := v.ensureChunk()
  162. if err != nil {
  163. return 0, err
  164. }
  165. return v.copyChunk(b), nil
  166. }
  167. func (r *AuthenticationReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
  168. err := r.ensureChunk()
  169. if err != nil {
  170. return nil, err
  171. }
  172. mb := buf.NewMultiBuffer()
  173. for len(r.chunk) > 0 {
  174. b := buf.New()
  175. nBytes, _ := b.Write(r.chunk)
  176. mb.Append(b)
  177. r.chunk = r.chunk[nBytes:]
  178. }
  179. r.chunk = nil
  180. return mb, nil
  181. }
  182. type AuthenticationWriter struct {
  183. auth Authenticator
  184. buffer []byte
  185. writer io.Writer
  186. sizeMask Uint16Generator
  187. }
  188. func NewAuthenticationWriter(auth Authenticator, writer io.Writer, sizeMask Uint16Generator) *AuthenticationWriter {
  189. return &AuthenticationWriter{
  190. auth: auth,
  191. buffer: make([]byte, 32*1024),
  192. writer: writer,
  193. sizeMask: sizeMask,
  194. }
  195. }
  196. // Write implements io.Writer.
  197. func (w *AuthenticationWriter) Write(b []byte) (int, error) {
  198. cipherChunk, err := w.auth.Seal(w.buffer[2:2], b)
  199. if err != nil {
  200. return 0, err
  201. }
  202. size := uint16(len(cipherChunk)) ^ w.sizeMask.Next()
  203. serial.Uint16ToBytes(size, w.buffer[:0])
  204. _, err = w.writer.Write(w.buffer[:2+len(cipherChunk)])
  205. return len(b), err
  206. }
  207. func (w *AuthenticationWriter) WriteMultiBuffer(mb buf.MultiBuffer) (int, error) {
  208. defer mb.Release()
  209. const StartIndex = 17 * 1024
  210. var totalBytes int
  211. for {
  212. payloadLen, _ := mb.Read(w.buffer[StartIndex:])
  213. nBytes, err := w.Write(w.buffer[StartIndex : StartIndex+payloadLen])
  214. totalBytes += nBytes
  215. if err != nil {
  216. return totalBytes, err
  217. }
  218. if mb.IsEmpty() {
  219. break
  220. }
  221. }
  222. return totalBytes, nil
  223. }