reader.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. package io
  2. import (
  3. "hash"
  4. "hash/fnv"
  5. "io"
  6. "github.com/v2ray/v2ray-core/common/alloc"
  7. "github.com/v2ray/v2ray-core/common/serial"
  8. "github.com/v2ray/v2ray-core/transport"
  9. )
  10. // @Private
  11. func AllocBuffer(size int) *alloc.Buffer {
  12. if size < 8*1024-16 {
  13. return alloc.NewBuffer()
  14. }
  15. return alloc.NewLargeBuffer()
  16. }
  // Validator accumulates a 32-bit FNV-1a hash over the bytes it is fed and
// compares the result with an expected value.
//
// @Private
type Validator struct {
	// actualAuth is the running hash of everything passed to Consume.
	actualAuth hash.Hash32
	// expectedAuth is the hash the consumed bytes must produce.
	expectedAuth uint32
}

// NewValidator returns a Validator whose Validate succeeds only if the bytes
// later passed to Consume hash (FNV-1a, 32-bit) to expectedAuth.
func NewValidator(expectedAuth uint32) *Validator {
	v := new(Validator)
	v.expectedAuth = expectedAuth
	v.actualAuth = fnv.New32a()
	return v
}

// Consume feeds b into the running hash. It may be called repeatedly; the
// hash covers the concatenation of every slice consumed so far.
func (this *Validator) Consume(b []byte) {
	// hash.Hash documents that Write never returns an error.
	_, _ = this.actualAuth.Write(b)
}

// Validate reports whether the bytes consumed so far hash to the expected
// value. Computing the sum does not reset the running hash.
func (this *Validator) Validate() bool {
	sum := this.actualAuth.Sum32()
	return sum == this.expectedAuth
}
  34. type AuthChunkReader struct {
  35. reader io.Reader
  36. last *alloc.Buffer
  37. chunkLength int
  38. validator *Validator
  39. }
  40. func NewAuthChunkReader(reader io.Reader) *AuthChunkReader {
  41. return &AuthChunkReader{
  42. reader: reader,
  43. chunkLength: -1,
  44. }
  45. }
  46. func (this *AuthChunkReader) Read() (*alloc.Buffer, error) {
  47. var buffer *alloc.Buffer
  48. if this.last != nil {
  49. buffer = this.last
  50. this.last = nil
  51. } else {
  52. buffer = AllocBuffer(this.chunkLength).Clear()
  53. }
  54. if this.chunkLength == -1 {
  55. for buffer.Len() < 6 {
  56. _, err := buffer.FillFrom(this.reader)
  57. if err != nil {
  58. buffer.Release()
  59. return nil, err
  60. }
  61. }
  62. length := serial.BytesLiteral(buffer.Value[:2]).Uint16Value()
  63. this.chunkLength = int(length) - 4
  64. this.validator = NewValidator(serial.BytesLiteral(buffer.Value[2:6]).Uint32Value())
  65. buffer.SliceFrom(6)
  66. } else if buffer.Len() < this.chunkLength {
  67. _, err := buffer.FillFrom(this.reader)
  68. if err != nil {
  69. buffer.Release()
  70. return nil, err
  71. }
  72. }
  73. if this.chunkLength == 0 {
  74. buffer.Release()
  75. return nil, io.EOF
  76. }
  77. if buffer.Len() <= this.chunkLength {
  78. this.validator.Consume(buffer.Value)
  79. this.chunkLength -= buffer.Len()
  80. if this.chunkLength == 0 {
  81. if !this.validator.Validate() {
  82. buffer.Release()
  83. return nil, transport.ErrorCorruptedPacket
  84. }
  85. this.chunkLength = -1
  86. this.validator = nil
  87. }
  88. } else {
  89. this.validator.Consume(buffer.Value[:this.chunkLength])
  90. if !this.validator.Validate() {
  91. buffer.Release()
  92. return nil, transport.ErrorCorruptedPacket
  93. }
  94. leftLength := buffer.Len() - this.chunkLength
  95. this.last = AllocBuffer(leftLength).Clear()
  96. this.last.Append(buffer.Value[this.chunkLength:])
  97. buffer.Slice(0, this.chunkLength)
  98. this.chunkLength = -1
  99. this.validator = nil
  100. }
  101. return buffer, nil
  102. }
  103. func (this *AuthChunkReader) Release() {
  104. this.reader = nil
  105. this.last.Release()
  106. this.last = nil
  107. this.validator = nil
  108. }