reader.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. package io
  2. import (
  3. "hash"
  4. "hash/fnv"
  5. "io"
  6. "github.com/v2ray/v2ray-core/common/alloc"
  7. "github.com/v2ray/v2ray-core/common/serial"
  8. "github.com/v2ray/v2ray-core/transport"
  9. )
  10. // @Private
  11. func AllocBuffer(size int) *alloc.Buffer {
  12. if size < 8*1024-16 {
  13. return alloc.NewBuffer()
  14. }
  15. return alloc.NewLargeBuffer()
  16. }
  17. // @Private
  18. type Validator struct {
  19. actualAuth hash.Hash32
  20. expectedAuth uint32
  21. }
  22. func NewValidator(expectedAuth uint32) *Validator {
  23. return &Validator{
  24. actualAuth: fnv.New32a(),
  25. expectedAuth: expectedAuth,
  26. }
  27. }
  28. func (this *Validator) Consume(b []byte) {
  29. this.actualAuth.Write(b)
  30. }
  31. func (this *Validator) Validate() bool {
  32. return this.actualAuth.Sum32() == this.expectedAuth
  33. }
  34. type AuthChunkReader struct {
  35. reader io.Reader
  36. last *alloc.Buffer
  37. chunkLength int
  38. validator *Validator
  39. }
  40. func NewAuthChunkReader(reader io.Reader) *AuthChunkReader {
  41. return &AuthChunkReader{
  42. reader: reader,
  43. chunkLength: -1,
  44. }
  45. }
  46. func (this *AuthChunkReader) Read() (*alloc.Buffer, error) {
  47. var buffer *alloc.Buffer
  48. if this.last != nil {
  49. buffer = this.last
  50. this.last = nil
  51. } else {
  52. buffer = AllocBuffer(this.chunkLength).Clear()
  53. }
  54. _, err := buffer.FillFrom(this.reader)
  55. if err != nil {
  56. buffer.Release()
  57. return nil, err
  58. }
  59. if this.chunkLength == -1 {
  60. for buffer.Len() < 6 {
  61. _, err := buffer.FillFrom(this.reader)
  62. if err != nil {
  63. buffer.Release()
  64. return nil, err
  65. }
  66. }
  67. length := serial.BytesLiteral(buffer.Value[:2]).Uint16Value()
  68. this.chunkLength = int(length) - 4
  69. this.validator = NewValidator(serial.BytesLiteral(buffer.Value[2:6]).Uint32Value())
  70. buffer.SliceFrom(6)
  71. }
  72. if this.chunkLength == 0 {
  73. buffer.Release()
  74. return nil, io.EOF
  75. }
  76. if buffer.Len() <= this.chunkLength {
  77. this.validator.Consume(buffer.Value)
  78. this.chunkLength -= buffer.Len()
  79. if this.chunkLength == 0 {
  80. if !this.validator.Validate() {
  81. buffer.Release()
  82. return nil, transport.ErrorCorruptedPacket
  83. }
  84. this.chunkLength = -1
  85. this.validator = nil
  86. }
  87. } else {
  88. this.validator.Consume(buffer.Value[:this.chunkLength])
  89. if !this.validator.Validate() {
  90. buffer.Release()
  91. return nil, transport.ErrorCorruptedPacket
  92. }
  93. leftLength := buffer.Len() - this.chunkLength
  94. this.last = AllocBuffer(leftLength).Clear()
  95. this.last.Append(buffer.Value[this.chunkLength:])
  96. buffer.Slice(0, this.chunkLength)
  97. this.chunkLength = -1
  98. this.validator = nil
  99. }
  100. return buffer, nil
  101. }
  102. func (this *AuthChunkReader) Release() {
  103. this.reader = nil
  104. this.last.Release()
  105. this.last = nil
  106. this.validator = nil
  107. }