validation.go 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. package io
  2. import (
  3. "errors"
  4. "hash/fnv"
  5. "io"
  6. "github.com/v2ray/v2ray-core/common/alloc"
  7. "github.com/v2ray/v2ray-core/transport"
  8. )
  // TruncatedPayload is returned by parsePayload when the data is shorter than
  // the 6-byte frame header or than the payload length the header declares.
  //
  // NOTE(review): Go convention would call this ErrTruncatedPayload with a
  // lowercase, unpunctuated message; kept as-is for compatibility with callers.
  var TruncatedPayload = errors.New("Truncated payload.")
  12. type ValidationReader struct {
  13. reader io.Reader
  14. buffer *alloc.Buffer
  15. }
  16. func NewValidationReader(reader io.Reader) *ValidationReader {
  17. return &ValidationReader{
  18. reader: reader,
  19. buffer: alloc.NewLargeBuffer().Clear(),
  20. }
  21. }
  22. func (this *ValidationReader) Read(data []byte) (int, error) {
  23. nBytes, err := this.reader.Read(data)
  24. if err != nil {
  25. return nBytes, err
  26. }
  27. nBytesActual := 0
  28. dataActual := data[:]
  29. for {
  30. payload, rest, err := parsePayload(data)
  31. if err != nil {
  32. return nBytesActual, err
  33. }
  34. copy(dataActual, payload)
  35. nBytesActual += len(payload)
  36. dataActual = dataActual[nBytesActual:]
  37. if len(rest) == 0 {
  38. break
  39. }
  40. data = rest
  41. }
  42. return nBytesActual, nil
  43. }
  44. func parsePayload(data []byte) (payload []byte, rest []byte, err error) {
  45. dataLen := len(data)
  46. if dataLen < 6 {
  47. err = TruncatedPayload
  48. return
  49. }
  50. payloadLen := int(data[0])<<8 + int(data[1])
  51. if dataLen < payloadLen+6 {
  52. err = TruncatedPayload
  53. return
  54. }
  55. payload = data[6 : 6+payloadLen]
  56. rest = data[6+payloadLen:]
  57. fnv1a := fnv.New32a()
  58. fnv1a.Write(payload)
  59. actualHash := fnv1a.Sum32()
  60. expectedHash := uint32(data[2])<<24 + uint32(data[3])<<16 + uint32(data[4])<<8 + uint32(data[5])
  61. if actualHash != expectedHash {
  62. err = transport.ErrorCorruptedPacket
  63. return
  64. }
  65. return
  66. }