|  | @@ -1,6 +1,7 @@
 | 
											
												
													
														|  |  package io
 |  |  package io
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  import (
 |  |  import (
 | 
											
												
													
														|  | 
 |  | +	"hash"
 | 
											
												
													
														|  |  	"hash/fnv"
 |  |  	"hash/fnv"
 | 
											
												
													
														|  |  	"io"
 |  |  	"io"
 | 
											
												
													
														|  |  
 |  |  
 | 
											
										
											
												
													
														|  | @@ -9,49 +10,115 @@ import (
 | 
											
												
													
														|  |  	"github.com/v2ray/v2ray-core/transport"
 |  |  	"github.com/v2ray/v2ray-core/transport"
 | 
											
												
													
														|  |  )
 |  |  )
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | 
 |  | +// @Private
 | 
											
												
													
														|  | 
 |  | +func AllocBuffer(size int) *alloc.Buffer {
 | 
											
												
													
														|  | 
 |  | +	if size < 8*1024-16 {
 | 
											
												
													
														|  | 
 |  | +		return alloc.NewBuffer()
 | 
											
												
													
														|  | 
 |  | +	}
 | 
											
												
													
														|  | 
 |  | +	return alloc.NewLargeBuffer()
 | 
											
												
													
														|  | 
 |  | +}
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +// @Private
 | 
											
												
													
														|  | 
 |  | +type Validator struct {
 | 
											
												
													
														|  | 
 |  | +	actualAuth   hash.Hash32
 | 
											
												
													
														|  | 
 |  | +	expectedAuth uint32
 | 
											
												
													
														|  | 
 |  | +}
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +func NewValidator(expectedAuth uint32) *Validator {
 | 
											
												
													
														|  | 
 |  | +	return &Validator{
 | 
											
												
													
														|  | 
 |  | +		actualAuth:   fnv.New32a(),
 | 
											
												
													
														|  | 
 |  | +		expectedAuth: expectedAuth,
 | 
											
												
													
														|  | 
 |  | +	}
 | 
											
												
													
														|  | 
 |  | +}
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +func (this *Validator) Consume(b []byte) {
 | 
											
												
													
														|  | 
 |  | +	this.actualAuth.Write(b)
 | 
											
												
													
														|  | 
 |  | +}
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +func (this *Validator) Validate() bool {
 | 
											
												
													
														|  | 
 |  | +	return this.actualAuth.Sum32() == this.expectedAuth
 | 
											
												
													
														|  | 
 |  | +}
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  |  type AuthChunkReader struct {
 |  |  type AuthChunkReader struct {
 | 
											
												
													
														|  | -	reader io.Reader
 |  | 
 | 
											
												
													
														|  | 
 |  | +	reader      io.Reader
 | 
											
												
													
														|  | 
 |  | +	last        *alloc.Buffer
 | 
											
												
													
														|  | 
 |  | +	chunkLength int
 | 
											
												
													
														|  | 
 |  | +	validator   *Validator
 | 
											
												
													
														|  |  }
 |  |  }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  func NewAuthChunkReader(reader io.Reader) *AuthChunkReader {
 |  |  func NewAuthChunkReader(reader io.Reader) *AuthChunkReader {
 | 
											
												
													
														|  |  	return &AuthChunkReader{
 |  |  	return &AuthChunkReader{
 | 
											
												
													
														|  | -		reader: reader,
 |  | 
 | 
											
												
													
														|  | 
 |  | +		reader:      reader,
 | 
											
												
													
														|  | 
 |  | +		chunkLength: -1,
 | 
											
												
													
														|  |  	}
 |  |  	}
 | 
											
												
													
														|  |  }
 |  |  }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  func (this *AuthChunkReader) Read() (*alloc.Buffer, error) {
 |  |  func (this *AuthChunkReader) Read() (*alloc.Buffer, error) {
 | 
											
												
													
														|  | -	buffer := alloc.NewBuffer()
 |  | 
 | 
											
												
													
														|  | -	if _, err := io.ReadFull(this.reader, buffer.Value[:2]); err != nil {
 |  | 
 | 
											
												
													
														|  | 
 |  | +	var buffer *alloc.Buffer
 | 
											
												
													
														|  | 
 |  | +	if this.last != nil {
 | 
											
												
													
														|  | 
 |  | +		buffer = this.last
 | 
											
												
													
														|  | 
 |  | +		this.last = nil
 | 
											
												
													
														|  | 
 |  | +	} else {
 | 
											
												
													
														|  | 
 |  | +		buffer = AllocBuffer(this.chunkLength).Clear()
 | 
											
												
													
														|  | 
 |  | +	}
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	_, err := buffer.FillFrom(this.reader)
 | 
											
												
													
														|  | 
 |  | +	if err != nil {
 | 
											
												
													
														|  |  		buffer.Release()
 |  |  		buffer.Release()
 | 
											
												
													
														|  |  		return nil, err
 |  |  		return nil, err
 | 
											
												
													
														|  |  	}
 |  |  	}
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -	length := serial.BytesLiteral(buffer.Value[:2]).Uint16Value()
 |  | 
 | 
											
												
													
														|  | -	if length <= 4 { // Length of authentication bytes.
 |  | 
 | 
											
												
													
														|  | -		return nil, io.EOF
 |  | 
 | 
											
												
													
														|  | 
 |  | +	if this.chunkLength == -1 {
 | 
											
												
													
														|  | 
 |  | +		for buffer.Len() < 6 {
 | 
											
												
													
														|  | 
 |  | +			_, err := buffer.FillFrom(this.reader)
 | 
											
												
													
														|  | 
 |  | +			if err != nil {
 | 
											
												
													
														|  | 
 |  | +				buffer.Release()
 | 
											
												
													
														|  | 
 |  | +				return nil, err
 | 
											
												
													
														|  | 
 |  | +			}
 | 
											
												
													
														|  | 
 |  | +		}
 | 
											
												
													
														|  | 
 |  | +		length := serial.BytesLiteral(buffer.Value[:2]).Uint16Value()
 | 
											
												
													
														|  | 
 |  | +		this.chunkLength = int(length) - 4
 | 
											
												
													
														|  | 
 |  | +		this.validator = NewValidator(serial.BytesLiteral(buffer.Value[2:6]).Uint32Value())
 | 
											
												
													
														|  | 
 |  | +		buffer.SliceFrom(6)
 | 
											
												
													
														|  |  	}
 |  |  	}
 | 
											
												
													
														|  | -	if length > 8*1024-16 {
 |  | 
 | 
											
												
													
														|  | -		buffer.Release()
 |  | 
 | 
											
												
													
														|  | -		buffer = alloc.NewLargeBuffer()
 |  | 
 | 
											
												
													
														|  | -	}
 |  | 
 | 
											
												
													
														|  | -	if _, err := io.ReadFull(this.reader, buffer.Value[:length]); err != nil {
 |  | 
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	if this.chunkLength == 0 {
 | 
											
												
													
														|  |  		buffer.Release()
 |  |  		buffer.Release()
 | 
											
												
													
														|  | -		return nil, err
 |  | 
 | 
											
												
													
														|  | 
 |  | +		return nil, io.EOF
 | 
											
												
													
														|  |  	}
 |  |  	}
 | 
											
												
													
														|  | -	buffer.Slice(0, int(length))
 |  | 
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -	fnvHash := fnv.New32a()
 |  | 
 | 
											
												
													
														|  | -	fnvHash.Write(buffer.Value[4:])
 |  | 
 | 
											
												
													
														|  | -	expAuth := serial.BytesLiteral(fnvHash.Sum(nil))
 |  | 
 | 
											
												
													
														|  | -	actualAuth := serial.BytesLiteral(buffer.Value[:4])
 |  | 
 | 
											
												
													
														|  | -	if !actualAuth.Equals(expAuth) {
 |  | 
 | 
											
												
													
														|  | -		buffer.Release()
 |  | 
 | 
											
												
													
														|  | -		return nil, transport.ErrorCorruptedPacket
 |  | 
 | 
											
												
													
														|  | 
 |  | +	if buffer.Len() <= this.chunkLength {
 | 
											
												
													
														|  | 
 |  | +		this.validator.Consume(buffer.Value)
 | 
											
												
													
														|  | 
 |  | +		this.chunkLength -= buffer.Len()
 | 
											
												
													
														|  | 
 |  | +		if this.chunkLength == 0 {
 | 
											
												
													
														|  | 
 |  | +			if !this.validator.Validate() {
 | 
											
												
													
														|  | 
 |  | +				buffer.Release()
 | 
											
												
													
														|  | 
 |  | +				return nil, transport.ErrorCorruptedPacket
 | 
											
												
													
														|  | 
 |  | +			}
 | 
											
												
													
														|  | 
 |  | +			this.chunkLength = -1
 | 
											
												
													
														|  | 
 |  | +			this.validator = nil
 | 
											
												
													
														|  | 
 |  | +		}
 | 
											
												
													
														|  | 
 |  | +	} else {
 | 
											
												
													
														|  | 
 |  | +		this.validator.Consume(buffer.Value[:this.chunkLength])
 | 
											
												
													
														|  | 
 |  | +		if !this.validator.Validate() {
 | 
											
												
													
														|  | 
 |  | +			buffer.Release()
 | 
											
												
													
														|  | 
 |  | +			return nil, transport.ErrorCorruptedPacket
 | 
											
												
													
														|  | 
 |  | +		}
 | 
											
												
													
														|  | 
 |  | +		leftLength := buffer.Len() - this.chunkLength
 | 
											
												
													
														|  | 
 |  | +		this.last = AllocBuffer(leftLength).Clear()
 | 
											
												
													
														|  | 
 |  | +		this.last.Append(buffer.Value[this.chunkLength:])
 | 
											
												
													
														|  | 
 |  | +		buffer.Slice(0, this.chunkLength)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +		this.chunkLength = -1
 | 
											
												
													
														|  | 
 |  | +		this.validator = nil
 | 
											
												
													
														|  |  	}
 |  |  	}
 | 
											
												
													
														|  | -	buffer.SliceFrom(4)
 |  | 
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  |  	return buffer, nil
 |  |  	return buffer, nil
 | 
											
												
													
														|  |  }
 |  |  }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  func (this *AuthChunkReader) Release() {
 |  |  func (this *AuthChunkReader) Release() {
 | 
											
												
													
														|  |  	this.reader = nil
 |  |  	this.reader = nil
 | 
											
												
													
														|  | 
 |  | +	this.last.Release()
 | 
											
												
													
														|  | 
 |  | +	this.last = nil
 | 
											
												
													
														|  | 
 |  | +	this.validator = nil
 | 
											
												
													
														|  |  }
 |  |  }
 |