|  | @@ -7,19 +7,21 @@ import (
 | 
											
												
													
														|  |  	"crypto/rand"
 |  |  	"crypto/rand"
 | 
											
												
													
														|  |  	mrand "math/rand"
 |  |  	mrand "math/rand"
 | 
											
												
													
														|  |  	"testing"
 |  |  	"testing"
 | 
											
												
													
														|  | 
 |  | +  
 | 
											
												
													
														|  | 
 |  | +  "github.com/v2ray/v2ray-core/testing/unit"
 | 
											
												
													
														|  |  )
 |  |  )
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  func randomBytes(p []byte, t *testing.T) {
 |  |  func randomBytes(p []byte, t *testing.T) {
 | 
											
												
													
														|  | 
 |  | +  assert := unit.Assert(t)
 | 
											
												
													
														|  | 
 |  | +  
 | 
											
												
													
														|  |  	nBytes, err := rand.Read(p)
 |  |  	nBytes, err := rand.Read(p)
 | 
											
												
													
														|  | -	if err != nil {
 |  | 
 | 
											
												
													
														|  | -		t.Fatal(err)
 |  | 
 | 
											
												
													
														|  | -	}
 |  | 
 | 
											
												
													
														|  | -	if nBytes != len(p) {
 |  | 
 | 
											
												
													
														|  | -		t.Error("Unable to generate %d bytes of random buffer", len(p))
 |  | 
 | 
											
												
													
														|  | -	}
 |  | 
 | 
											
												
													
														|  | 
 |  | +  assert.Error(err).IsNil()
 | 
											
												
													
														|  | 
 |  | +  assert.Int(nBytes).Named("# bytes of random buffer").Equals(len(p))
 | 
											
												
													
														|  |  }
 |  |  }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  func TestNormalReading(t *testing.T) {
 |  |  func TestNormalReading(t *testing.T) {
 | 
											
												
													
														|  | 
 |  | +  assert := unit.Assert(t)
 | 
											
												
													
														|  | 
 |  | +  
 | 
											
												
													
														|  |  	testSize := 256
 |  |  	testSize := 256
 | 
											
												
													
														|  |  	plaintext := make([]byte, testSize)
 |  |  	plaintext := make([]byte, testSize)
 | 
											
												
													
														|  |  	randomBytes(plaintext, t)
 |  |  	randomBytes(plaintext, t)
 | 
											
										
											
												
													
														|  | @@ -31,9 +33,8 @@ func TestNormalReading(t *testing.T) {
 | 
											
												
													
														|  |  	randomBytes(iv, t)
 |  |  	randomBytes(iv, t)
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  	aesBlock, err := aes.NewCipher(key)
 |  |  	aesBlock, err := aes.NewCipher(key)
 | 
											
												
													
														|  | -	if err != nil {
 |  | 
 | 
											
												
													
														|  | -		t.Fatal(err)
 |  | 
 | 
											
												
													
														|  | -	}
 |  | 
 | 
											
												
													
														|  | 
 |  | +  assert.Error(err).IsNil()
 | 
											
												
													
														|  | 
 |  | +  
 | 
											
												
													
														|  |  	aesMode := cipher.NewCBCEncrypter(aesBlock, iv)
 |  |  	aesMode := cipher.NewCBCEncrypter(aesBlock, iv)
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  	ciphertext := make([]byte, testSize)
 |  |  	ciphertext := make([]byte, testSize)
 | 
											
										
											
												
													
														|  | @@ -43,9 +44,7 @@ func TestNormalReading(t *testing.T) {
 | 
											
												
													
														|  |  	copy(ciphertextcopy, ciphertext)
 |  |  	copy(ciphertextcopy, ciphertext)
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  	reader, err := NewDecryptionReader(bytes.NewReader(ciphertextcopy), key, iv)
 |  |  	reader, err := NewDecryptionReader(bytes.NewReader(ciphertextcopy), key, iv)
 | 
											
												
													
														|  | -	if err != nil {
 |  | 
 | 
											
												
													
														|  | -		t.Fatal(err)
 |  | 
 | 
											
												
													
														|  | -	}
 |  | 
 | 
											
												
													
														|  | 
 |  | +	assert.Error(err).IsNil()
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  	readtext := make([]byte, testSize)
 |  |  	readtext := make([]byte, testSize)
 | 
											
												
													
														|  |  	readSize := 0
 |  |  	readSize := 0
 | 
											
										
											
												
													
														|  | @@ -55,15 +54,9 @@ func TestNormalReading(t *testing.T) {
 | 
											
												
													
														|  |  			nBytes = testSize - readSize
 |  |  			nBytes = testSize - readSize
 | 
											
												
													
														|  |  		}
 |  |  		}
 | 
											
												
													
														|  |  		bytesRead, err := reader.Read(readtext[readSize : readSize+nBytes])
 |  |  		bytesRead, err := reader.Read(readtext[readSize : readSize+nBytes])
 | 
											
												
													
														|  | -		if err != nil {
 |  | 
 | 
											
												
													
														|  | -			t.Fatal(err)
 |  | 
 | 
											
												
													
														|  | -		}
 |  | 
 | 
											
												
													
														|  | -		if bytesRead != nBytes {
 |  | 
 | 
											
												
													
														|  | -			t.Errorf("Expected to read %d bytes, but only read %d bytes", nBytes, bytesRead)
 |  | 
 | 
											
												
													
														|  | -		}
 |  | 
 | 
											
												
													
														|  | 
 |  | +		assert.Error(err).IsNil()
 | 
											
												
													
														|  | 
 |  | +    assert.Int(bytesRead).Equals(nBytes)
 | 
											
												
													
														|  |  		readSize += nBytes
 |  |  		readSize += nBytes
 | 
											
												
													
														|  |  	}
 |  |  	}
 | 
											
												
													
														|  | -	if !bytes.Equal(readtext, plaintext) {
 |  | 
 | 
											
												
													
														|  | -		t.Errorf("Expected plaintext %v, but got %v", plaintext, readtext)
 |  | 
 | 
											
												
													
														|  | -	}
 |  | 
 | 
											
												
													
														|  | 
 |  | +  assert.Bytes(readtext).Named("Plaintext").Equals(plaintext)
 | 
											
												
													
														|  |  }
 |  |  }
 |