|
|
@@ -6,26 +6,31 @@ import (
|
|
|
"crypto/cipher"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
+
|
|
|
+ v2io "github.com/v2ray/v2ray-core/io"
|
|
|
)
|
|
|
|
|
|
const (
|
|
|
- blockSize = 16
|
|
|
+ blockSize = 16 // Decryption block size, inherited from AES
|
|
|
)
|
|
|
|
|
|
+// DecryptionReader is a byte stream reader to decrypt AES-128 CBC (for now)
|
|
|
+// encrypted content.
|
|
|
type DecryptionReader struct {
|
|
|
- cipher cipher.Block
|
|
|
- reader io.Reader
|
|
|
+ reader *v2io.CryptionReader
|
|
|
buffer *bytes.Buffer
|
|
|
}
|
|
|
|
|
|
-func NewDecryptionReader(reader io.Reader, key []byte) (*DecryptionReader, error) {
|
|
|
+// NewDecryptionReader creates a new DescriptionReader by given byte Reader and
|
|
|
+// AES key.
|
|
|
+func NewDecryptionReader(reader io.Reader, key []byte, iv []byte) (*DecryptionReader, error) {
|
|
|
decryptionReader := new(DecryptionReader)
|
|
|
- cipher, err := aes.NewCipher(key)
|
|
|
+ aesCipher, err := aes.NewCipher(key)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
- decryptionReader.cipher = cipher
|
|
|
- decryptionReader.reader = reader
|
|
|
+ aesBlockMode := cipher.NewCBCDecrypter(aesCipher, iv)
|
|
|
+ decryptionReader.reader = v2io.NewCryptionReader(aesBlockMode, reader)
|
|
|
decryptionReader.buffer = bytes.NewBuffer(make([]byte, 0, 2*blockSize))
|
|
|
return decryptionReader, nil
|
|
|
}
|
|
|
@@ -33,26 +38,20 @@ func NewDecryptionReader(reader io.Reader, key []byte) (*DecryptionReader, error
|
|
|
func (reader *DecryptionReader) readBlock() error {
|
|
|
buffer := make([]byte, blockSize)
|
|
|
nBytes, err := reader.reader.Read(buffer)
|
|
|
- if err != nil {
|
|
|
+ if err != nil && err != io.EOF {
|
|
|
return err
|
|
|
}
|
|
|
if nBytes < blockSize {
|
|
|
return fmt.Errorf("Expected to read %d bytes, but got %d bytes", blockSize, nBytes)
|
|
|
}
|
|
|
- reader.cipher.Decrypt(buffer, buffer)
|
|
|
reader.buffer.Write(buffer)
|
|
|
- return nil
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
+// Read returns decrypted bytes of given length
|
|
|
func (reader *DecryptionReader) Read(p []byte) (int, error) {
|
|
|
- if reader.buffer.Len() == 0 {
|
|
|
- err := reader.readBlock()
|
|
|
- if err != nil {
|
|
|
- return 0, err
|
|
|
- }
|
|
|
- }
|
|
|
nBytes, err := reader.buffer.Read(p)
|
|
|
- if err != nil {
|
|
|
+ if err != nil && err != io.EOF {
|
|
|
return nBytes, err
|
|
|
}
|
|
|
if nBytes < len(p) {
|