io.go 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. package kcp
  2. import (
  3. "crypto/cipher"
  4. "crypto/rand"
  5. "io"
  6. "github.com/v2fly/v2ray-core/v4/common"
  7. "github.com/v2fly/v2ray-core/v4/common/buf"
  8. "github.com/v2fly/v2ray-core/v4/transport/internet"
  9. )
  10. type PacketReader interface {
  11. Read([]byte) []Segment
  12. }
  // PacketWriter sends a payload as a single packet and reports how many
  // bytes it adds around that payload.
  type PacketWriter interface {
  	// Write sends one payload as one packet.
  	io.Writer

  	// Overhead returns the number of bytes the writer adds to every payload.
  	Overhead() int
  }
  17. type KCPPacketReader struct { // nolint: golint
  18. Security cipher.AEAD
  19. Header internet.PacketHeader
  20. }
  21. func (r *KCPPacketReader) Read(b []byte) []Segment {
  22. if r.Header != nil {
  23. if int32(len(b)) <= r.Header.Size() {
  24. return nil
  25. }
  26. b = b[r.Header.Size():]
  27. }
  28. if r.Security != nil {
  29. nonceSize := r.Security.NonceSize()
  30. overhead := r.Security.Overhead()
  31. if len(b) <= nonceSize+overhead {
  32. return nil
  33. }
  34. out, err := r.Security.Open(b[nonceSize:nonceSize], b[:nonceSize], b[nonceSize:], nil)
  35. if err != nil {
  36. return nil
  37. }
  38. b = out
  39. }
  40. var result []Segment
  41. for len(b) > 0 {
  42. seg, x := ReadSegment(b)
  43. if seg == nil {
  44. break
  45. }
  46. result = append(result, seg)
  47. b = x
  48. }
  49. return result
  50. }
  51. type KCPPacketWriter struct { // nolint: golint
  52. Header internet.PacketHeader
  53. Security cipher.AEAD
  54. Writer io.Writer
  55. }
  56. func (w *KCPPacketWriter) Overhead() int {
  57. overhead := 0
  58. if w.Header != nil {
  59. overhead += int(w.Header.Size())
  60. }
  61. if w.Security != nil {
  62. overhead += w.Security.Overhead()
  63. }
  64. return overhead
  65. }
  66. func (w *KCPPacketWriter) Write(b []byte) (int, error) {
  67. bb := buf.StackNew()
  68. defer bb.Release()
  69. if w.Header != nil {
  70. w.Header.Serialize(bb.Extend(w.Header.Size()))
  71. }
  72. if w.Security != nil {
  73. nonceSize := w.Security.NonceSize()
  74. common.Must2(bb.ReadFullFrom(rand.Reader, int32(nonceSize)))
  75. nonce := bb.BytesFrom(int32(-nonceSize))
  76. encrypted := bb.Extend(int32(w.Security.Overhead() + len(b)))
  77. w.Security.Seal(encrypted[:0], nonce, b, nil)
  78. } else {
  79. bb.Write(b)
  80. }
  81. _, err := w.Writer.Write(bb.Bytes())
  82. return len(b), err
  83. }