io.go 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. //go:build !confonly
  2. // +build !confonly
  3. package kcp
  4. import (
  5. "crypto/cipher"
  6. "crypto/rand"
  7. "io"
  8. "github.com/v2fly/v2ray-core/v4/common"
  9. "github.com/v2fly/v2ray-core/v4/common/buf"
  10. "github.com/v2fly/v2ray-core/v4/transport/internet"
  11. )
  12. type PacketReader interface {
  13. Read([]byte) []Segment
  14. }
  // PacketWriter is an io.Writer that sends framed KCP packets and can
// report how many framing bytes it adds to each one.
type PacketWriter interface {
	io.Writer

	// Overhead returns the number of bytes the writer adds to each packet
	// on top of the payload given to Write.
	Overhead() int
}
  19. type KCPPacketReader struct { // nolint: golint
  20. Security cipher.AEAD
  21. Header internet.PacketHeader
  22. }
  23. func (r *KCPPacketReader) Read(b []byte) []Segment {
  24. if r.Header != nil {
  25. if int32(len(b)) <= r.Header.Size() {
  26. return nil
  27. }
  28. b = b[r.Header.Size():]
  29. }
  30. if r.Security != nil {
  31. nonceSize := r.Security.NonceSize()
  32. overhead := r.Security.Overhead()
  33. if len(b) <= nonceSize+overhead {
  34. return nil
  35. }
  36. out, err := r.Security.Open(b[nonceSize:nonceSize], b[:nonceSize], b[nonceSize:], nil)
  37. if err != nil {
  38. return nil
  39. }
  40. b = out
  41. }
  42. var result []Segment
  43. for len(b) > 0 {
  44. seg, x := ReadSegment(b)
  45. if seg == nil {
  46. break
  47. }
  48. result = append(result, seg)
  49. b = x
  50. }
  51. return result
  52. }
  53. type KCPPacketWriter struct { // nolint: golint
  54. Header internet.PacketHeader
  55. Security cipher.AEAD
  56. Writer io.Writer
  57. }
  58. func (w *KCPPacketWriter) Overhead() int {
  59. overhead := 0
  60. if w.Header != nil {
  61. overhead += int(w.Header.Size())
  62. }
  63. if w.Security != nil {
  64. overhead += w.Security.Overhead()
  65. }
  66. return overhead
  67. }
  68. func (w *KCPPacketWriter) Write(b []byte) (int, error) {
  69. bb := buf.StackNew()
  70. defer bb.Release()
  71. if w.Header != nil {
  72. w.Header.Serialize(bb.Extend(w.Header.Size()))
  73. }
  74. if w.Security != nil {
  75. nonceSize := w.Security.NonceSize()
  76. common.Must2(bb.ReadFullFrom(rand.Reader, int32(nonceSize)))
  77. nonce := bb.BytesFrom(int32(-nonceSize))
  78. encrypted := bb.Extend(int32(w.Security.Overhead() + len(b)))
  79. w.Security.Seal(encrypted[:0], nonce, b, nil)
  80. } else {
  81. bb.Write(b)
  82. }
  83. _, err := w.Writer.Write(bb.Bytes())
  84. return len(b), err
  85. }