decryptionreader.go 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. package vmess
  2. import (
  3. "bytes"
  4. "crypto/aes"
  5. "crypto/cipher"
  6. "fmt"
  7. "io"
  8. )
  // blockSize is the number of bytes in one AES block. Ciphertext is pulled from
  // the underlying stream and decrypted in units of exactly this size.
  const (
  	blockSize = aes.BlockSize
  )

  // DecryptionReader is an io.Reader that yields the plaintext of an
  // AES-encrypted stream. Each blockSize-byte block read from the wrapped reader
  // is passed on its own to the block cipher's Decrypt method (this type applies
  // no chaining mode or IV), and the resulting plaintext is held in an internal
  // buffer until the caller consumes it.
  type DecryptionReader struct {
  	// cipher decrypts one block at a time.
  	cipher cipher.Block
  	// reader is the source of ciphertext.
  	reader io.Reader
  	// buffer holds decrypted bytes that have not been handed to the caller yet.
  	buffer *bytes.Buffer
  	// scratch is reused for every block so readBlock does not allocate.
  	scratch [blockSize]byte
  }

  // Compile-time check that *DecryptionReader satisfies io.Reader.
  var _ io.Reader = (*DecryptionReader)(nil)

  // NewDecryptionReader returns a DecryptionReader that decrypts the ciphertext
  // supplied by reader using an AES cipher built from key.
  //
  // The error from aes.NewCipher is returned unchanged when key is rejected; in
  // that case the returned reader is nil.
  func NewDecryptionReader(reader io.Reader, key []byte) (*DecryptionReader, error) {
  	// Named blockCipher so the local does not shadow the imported cipher package.
  	blockCipher, err := aes.NewCipher(key)
  	if err != nil {
  		return nil, err
  	}
  	return &DecryptionReader{
  		cipher: blockCipher,
  		reader: reader,
  		buffer: bytes.NewBuffer(make([]byte, 0, 2*blockSize)),
  	}, nil
  }

  // readBlock pulls exactly one ciphertext block from the underlying reader,
  // decrypts it and appends the plaintext to the internal buffer.
  //
  // It returns the underlying reader's error unchanged (io.EOF when the stream
  // ends cleanly on a block boundary). If the stream ends in the middle of a
  // block, the returned error wraps io.ErrUnexpectedEOF.
  func (r *DecryptionReader) readBlock() error {
  	block := r.scratch[:]
  	// io.ReadFull keeps reading until the block is complete: a single Read on the
  	// underlying stream may legitimately return fewer bytes than requested.
  	n, err := io.ReadFull(r.reader, block)
  	if err != nil {
  		if err == io.ErrUnexpectedEOF {
  			return fmt.Errorf("Expected to read %d bytes, but got %d bytes: %w", blockSize, n, err)
  		}
  		return err
  	}
  	r.cipher.Decrypt(block, block)
  	// bytes.Buffer.Write always returns a nil error, so it is not checked.
  	r.buffer.Write(block)
  	return nil
  }

  // Read implements io.Reader. It copies decrypted bytes into p, decrypting as
  // many further blocks as needed, and returns only once p is full or an error
  // occurs. Decrypted bytes that do not fit in p stay buffered for the next call.
  //
  // When an error interrupts the call, the number of bytes already written to p
  // is returned together with that error. A zero-length p returns (0, nil)
  // without touching the underlying reader.
  func (r *DecryptionReader) Read(p []byte) (int, error) {
  	total := 0
  	for total < len(p) {
  		if r.buffer.Len() == 0 {
  			if err := r.readBlock(); err != nil {
  				return total, err
  			}
  		}
  		n, err := r.buffer.Read(p[total:])
  		total += n
  		if err != nil {
  			return total, err
  		}
  	}
  	return total, nil
  }