decryptionreader.go 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. package vmess
  2. import (
  3. "bytes"
  4. "crypto/aes"
  5. "crypto/cipher"
  6. "fmt"
  7. "io"
  8. v2io "github.com/v2ray/v2ray-core/io"
  9. )
  10. const (
  11. blockSize = 16 // Decryption block size, inherited from AES
  12. )
  13. // DecryptionReader is a byte stream reader to decrypt AES-128 CBC (for now)
  14. // encrypted content.
  15. type DecryptionReader struct {
  16. reader *v2io.CryptionReader
  17. buffer *bytes.Buffer
  18. }
  19. // NewDecryptionReader creates a new DescriptionReader by given byte Reader and
  20. // AES key.
  21. func NewDecryptionReader(reader io.Reader, key []byte, iv []byte) (*DecryptionReader, error) {
  22. decryptionReader := new(DecryptionReader)
  23. aesCipher, err := aes.NewCipher(key)
  24. if err != nil {
  25. return nil, err
  26. }
  27. aesStream := cipher.NewCFBDecrypter(aesCipher, iv)
  28. decryptionReader.reader = v2io.NewCryptionReader(aesStream, reader)
  29. decryptionReader.buffer = bytes.NewBuffer(make([]byte, 0, 2*blockSize))
  30. return decryptionReader, nil
  31. }
  32. func (reader *DecryptionReader) readBlock() error {
  33. buffer := make([]byte, blockSize)
  34. nBytes, err := reader.reader.Read(buffer)
  35. if err != nil && err != io.EOF {
  36. return err
  37. }
  38. if nBytes < blockSize {
  39. return fmt.Errorf("Expected to read %d bytes, but got %d bytes", blockSize, nBytes)
  40. }
  41. reader.buffer.Write(buffer)
  42. return err
  43. }
  44. // Read returns decrypted bytes of given length
  45. func (reader *DecryptionReader) Read(p []byte) (int, error) {
  46. nBytes, err := reader.buffer.Read(p)
  47. if err != nil && err != io.EOF {
  48. return nBytes, err
  49. }
  50. if nBytes < len(p) {
  51. err = reader.readBlock()
  52. if err != nil {
  53. return nBytes, err
  54. }
  55. moreBytes, err := reader.buffer.Read(p[nBytes:])
  56. if err != nil {
  57. return nBytes, err
  58. }
  59. nBytes += moreBytes
  60. if nBytes != len(p) {
  61. return nBytes, fmt.Errorf("Unable to read %d bytes", len(p))
  62. }
  63. }
  64. return nBytes, err
  65. }