dialer.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. package kcp
  2. import (
  3. "context"
  4. "crypto/cipher"
  5. "crypto/tls"
  6. "net"
  7. "sync"
  8. "sync/atomic"
  9. "v2ray.com/core/app/log"
  10. "v2ray.com/core/common"
  11. "v2ray.com/core/common/buf"
  12. "v2ray.com/core/common/dice"
  13. v2net "v2ray.com/core/common/net"
  14. "v2ray.com/core/transport/internet"
  15. v2tls "v2ray.com/core/transport/internet/tls"
  16. )
  17. var (
  18. globalConv = uint32(dice.RandomUint16())
  19. )
  20. type ClientConnection struct {
  21. sync.RWMutex
  22. net.Conn
  23. input func([]Segment)
  24. reader PacketReader
  25. writer PacketWriter
  26. }
  27. func (o *ClientConnection) Overhead() int {
  28. o.RLock()
  29. defer o.RUnlock()
  30. if o.writer == nil {
  31. return 0
  32. }
  33. return o.writer.Overhead()
  34. }
  35. func (o *ClientConnection) Write(b []byte) (int, error) {
  36. o.RLock()
  37. defer o.RUnlock()
  38. if o.writer == nil {
  39. return len(b), nil
  40. }
  41. return o.writer.Write(b)
  42. }
  43. func (o *ClientConnection) Read([]byte) (int, error) {
  44. panic("KCP|ClientConnection: Read should not be called.")
  45. }
  46. func (o *ClientConnection) Close() error {
  47. return o.Conn.Close()
  48. }
  49. func (o *ClientConnection) Reset(inputCallback func([]Segment)) {
  50. o.Lock()
  51. o.input = inputCallback
  52. o.Unlock()
  53. }
  54. func (o *ClientConnection) ResetSecurity(header internet.PacketHeader, security cipher.AEAD) {
  55. o.Lock()
  56. if o.reader == nil {
  57. o.reader = new(KCPPacketReader)
  58. }
  59. o.reader.(*KCPPacketReader).Header = header
  60. o.reader.(*KCPPacketReader).Security = security
  61. if o.writer == nil {
  62. o.writer = new(KCPPacketWriter)
  63. }
  64. o.writer.(*KCPPacketWriter).Header = header
  65. o.writer.(*KCPPacketWriter).Security = security
  66. o.writer.(*KCPPacketWriter).Writer = o.Conn
  67. o.Unlock()
  68. }
  69. func (o *ClientConnection) Run() {
  70. payload := buf.NewSmall()
  71. defer payload.Release()
  72. for {
  73. err := payload.Reset(buf.ReadFrom(o.Conn))
  74. if err != nil {
  75. payload.Release()
  76. return
  77. }
  78. o.RLock()
  79. if o.input != nil {
  80. segments := o.reader.Read(payload.Bytes())
  81. if len(segments) > 0 {
  82. o.input(segments)
  83. }
  84. }
  85. o.RUnlock()
  86. }
  87. }
  88. func DialKCP(ctx context.Context, dest v2net.Destination) (internet.Connection, error) {
  89. dest.Network = v2net.Network_UDP
  90. log.Trace(newError("dialing mKCP to ", dest))
  91. src := internet.DialerSourceFromContext(ctx)
  92. rawConn, err := internet.DialSystem(ctx, src, dest)
  93. if err != nil {
  94. log.Trace(newError("failed to dial to dest: ", err).AtError())
  95. return nil, err
  96. }
  97. conn := &ClientConnection{
  98. Conn: rawConn,
  99. }
  100. go conn.Run()
  101. kcpSettings := internet.TransportSettingsFromContext(ctx).(*Config)
  102. header, err := kcpSettings.GetPackerHeader()
  103. if err != nil {
  104. return nil, newError("failed to create packet header").Base(err)
  105. }
  106. security, err := kcpSettings.GetSecurity()
  107. if err != nil {
  108. return nil, newError("failed to create security").Base(err)
  109. }
  110. conn.ResetSecurity(header, security)
  111. conv := uint16(atomic.AddUint32(&globalConv, 1))
  112. session := NewConnection(conv, conn, kcpSettings)
  113. var iConn internet.Connection
  114. iConn = session
  115. if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
  116. switch securitySettings := securitySettings.(type) {
  117. case *v2tls.Config:
  118. config := securitySettings.GetTLSConfig()
  119. if dest.Address.Family().IsDomain() {
  120. config.ServerName = dest.Address.Domain()
  121. }
  122. tlsConn := tls.Client(iConn, config)
  123. iConn = tlsConn
  124. }
  125. }
  126. return iConn, nil
  127. }
  128. func init() {
  129. common.Must(internet.RegisterTransportDialer(internet.TransportProtocol_MKCP, DialKCP))
  130. }