Browse Source

Decryption reader for decoding vmess message

V2Ray 10 years ago
parent
commit
69bcce0b0d
2 changed files with 140 additions and 0 deletions
  1. 73 0
      io/vmess/decryptionreader.go
  2. 67 0
      io/vmess/decryptionreader_test.go

+ 73 - 0
io/vmess/decryptionreader.go

@@ -0,0 +1,73 @@
+package vmess
+
+import (
+  "bytes"
+	"crypto/aes"
+	"crypto/cipher"
+	"fmt"
+	"io"
+)
+
+const (
+	blockSize = 16
+)
+
+type DecryptionReader struct {
+	cipher             cipher.Block
+	reader             io.Reader
+	buffer             *bytes.Buffer
+}
+
+func NewDecryptionReader(reader io.Reader, key []byte) (*DecryptionReader, error) {
+	decryptionReader := new(DecryptionReader)
+	cipher, err := aes.NewCipher(key)
+	if err != nil {
+		return nil, err
+	}
+	decryptionReader.cipher = cipher
+	decryptionReader.reader = reader
+  decryptionReader.buffer = bytes.NewBuffer(make([]byte, 0, 2 * blockSize))
+	return decryptionReader, nil
+}
+
+func (reader *DecryptionReader) readBlock() error {
+  buffer := make([]byte, blockSize)
+  nBytes, err := reader.reader.Read(buffer)
+	if err != nil {
+		return err
+	}
+	if nBytes < blockSize {
+		return fmt.Errorf("Expected to read %d bytes, but got %d bytes", blockSize, nBytes)
+	}
+	reader.cipher.Decrypt(buffer, buffer)
+  reader.buffer.Write(buffer)
+	return nil
+}
+
+func (reader *DecryptionReader) Read(p []byte) (int, error) {
+  if reader.buffer.Len() == 0 {
+    err := reader.readBlock()
+    if err != nil {
+      return 0, err
+    }
+  }
+	nBytes, err := reader.buffer.Read(p)
+  if err != nil {
+    return nBytes, err
+  }
+  if nBytes < len(p) {
+    err = reader.readBlock()
+    if err != nil {
+      return nBytes, err
+    }
+    moreBytes, err := reader.buffer.Read(p[nBytes:])
+    if err != nil {
+      return nBytes, err
+    }
+    nBytes += moreBytes
+    if nBytes != len(p) {
+      return nBytes, fmt.Errorf("Unable to read %d bytes", len(p))
+    }
+  }
+  return nBytes, err
+}

+ 67 - 0
io/vmess/decryptionreader_test.go

@@ -0,0 +1,67 @@
+package vmess
+
+import (
+  "bytes"
+  "crypto/aes"
+  "crypto/rand"
+  mrand "math/rand"
+  "testing"
+)
+
+func randomBytes(p []byte, t *testing.T) {
+  nBytes, err := rand.Read(p)
+  if err != nil {
+    t.Fatal(err)
+  }
+  if nBytes != len(p) {
+    t.Error("Unable to generate %d bytes of random buffer", len(p))
+  }
+}
+
+func TestNormalReading(t *testing.T) {
+  testSize := 256
+  plaintext := make([]byte, testSize)
+  randomBytes(plaintext, t)
+  
+  keySize := 16
+  key := make([]byte, keySize)
+  randomBytes(key, t)
+  
+  cipher, err := aes.NewCipher(key)
+  if err != nil {
+    t.Fatal(err)
+  }
+  
+  ciphertext := make([]byte, testSize)
+  for encryptSize := 0; encryptSize < testSize; encryptSize += blockSize {
+    cipher.Encrypt(ciphertext[encryptSize:], plaintext[encryptSize:])
+  }
+  
+  ciphertextcopy := make([]byte, testSize)
+  copy(ciphertextcopy, ciphertext)
+  
+  reader, err := NewDecryptionReader(bytes.NewReader(ciphertextcopy), key)
+  if err != nil {
+    t.Fatal(err)
+  }
+  
+  readtext := make([]byte, testSize)
+  readSize := 0
+  for readSize < testSize {
+    nBytes := mrand.Intn(16) + 1
+    if nBytes > testSize - readSize {
+      nBytes = testSize - readSize
+    }
+    bytesRead, err := reader.Read(readtext[readSize:readSize + nBytes])
+    if err != nil {
+      t.Fatal(err)
+    }
+    if bytesRead != nBytes {
+      t.Errorf("Expected to read %d bytes, but only read %d bytes", nBytes, bytesRead)
+    }
+    readSize += nBytes
+  }
+  if ! bytes.Equal(readtext, plaintext) {
+    t.Errorf("Expected plaintext %v, but got %v", plaintext, readtext)
+  }
+}