dialer.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. package kcp
  2. import (
  3. "context"
  4. "crypto/cipher"
  5. "crypto/tls"
  6. "net"
  7. "sync"
  8. "sync/atomic"
  9. "v2ray.com/core/app/log"
  10. "v2ray.com/core/common"
  11. "v2ray.com/core/common/buf"
  12. "v2ray.com/core/common/dice"
  13. "v2ray.com/core/common/errors"
  14. v2net "v2ray.com/core/common/net"
  15. "v2ray.com/core/transport/internet"
  16. v2tls "v2ray.com/core/transport/internet/tls"
  17. )
  18. var (
  19. globalConv = uint32(dice.RandomUint16())
  20. )
  21. type ClientConnection struct {
  22. sync.RWMutex
  23. net.Conn
  24. input func([]Segment)
  25. reader PacketReader
  26. writer PacketWriter
  27. }
  28. func (o *ClientConnection) Overhead() int {
  29. o.RLock()
  30. defer o.RUnlock()
  31. if o.writer == nil {
  32. return 0
  33. }
  34. return o.writer.Overhead()
  35. }
  36. func (o *ClientConnection) Write(b []byte) (int, error) {
  37. o.RLock()
  38. defer o.RUnlock()
  39. if o.writer == nil {
  40. return len(b), nil
  41. }
  42. return o.writer.Write(b)
  43. }
  44. func (o *ClientConnection) Read([]byte) (int, error) {
  45. panic("KCP|ClientConnection: Read should not be called.")
  46. }
  47. func (o *ClientConnection) Close() error {
  48. return o.Conn.Close()
  49. }
  50. func (o *ClientConnection) Reset(inputCallback func([]Segment)) {
  51. o.Lock()
  52. o.input = inputCallback
  53. o.Unlock()
  54. }
  55. func (o *ClientConnection) ResetSecurity(header internet.PacketHeader, security cipher.AEAD) {
  56. o.Lock()
  57. if o.reader == nil {
  58. o.reader = new(KCPPacketReader)
  59. }
  60. o.reader.(*KCPPacketReader).Header = header
  61. o.reader.(*KCPPacketReader).Security = security
  62. if o.writer == nil {
  63. o.writer = new(KCPPacketWriter)
  64. }
  65. o.writer.(*KCPPacketWriter).Header = header
  66. o.writer.(*KCPPacketWriter).Security = security
  67. o.writer.(*KCPPacketWriter).Writer = o.Conn
  68. o.Unlock()
  69. }
  70. func (o *ClientConnection) Run() {
  71. payload := buf.NewSmall()
  72. defer payload.Release()
  73. for {
  74. err := payload.Reset(buf.ReadFrom(o.Conn))
  75. if err != nil {
  76. payload.Release()
  77. return
  78. }
  79. o.RLock()
  80. if o.input != nil {
  81. segments := o.reader.Read(payload.Bytes())
  82. if len(segments) > 0 {
  83. o.input(segments)
  84. }
  85. }
  86. o.RUnlock()
  87. }
  88. }
  89. func DialKCP(ctx context.Context, dest v2net.Destination) (internet.Connection, error) {
  90. dest.Network = v2net.Network_UDP
  91. log.Trace(errors.New("KCP|Dialer: Dialing KCP to ", dest))
  92. src := internet.DialerSourceFromContext(ctx)
  93. rawConn, err := internet.DialSystem(ctx, src, dest)
  94. if err != nil {
  95. log.Trace(errors.New("KCP|Dialer: Failed to dial to dest: ", err).AtError())
  96. return nil, err
  97. }
  98. conn := &ClientConnection{
  99. Conn: rawConn,
  100. }
  101. go conn.Run()
  102. kcpSettings := internet.TransportSettingsFromContext(ctx).(*Config)
  103. header, err := kcpSettings.GetPackerHeader()
  104. if err != nil {
  105. return nil, errors.New("KCP|Dialer: Failed to create packet header.").Base(err)
  106. }
  107. security, err := kcpSettings.GetSecurity()
  108. if err != nil {
  109. return nil, errors.New("KCP|Dialer: Failed to create security.").Base(err)
  110. }
  111. conn.ResetSecurity(header, security)
  112. conv := uint16(atomic.AddUint32(&globalConv, 1))
  113. session := NewConnection(conv, conn, kcpSettings)
  114. var iConn internet.Connection
  115. iConn = session
  116. if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
  117. switch securitySettings := securitySettings.(type) {
  118. case *v2tls.Config:
  119. config := securitySettings.GetTLSConfig()
  120. if dest.Address.Family().IsDomain() {
  121. config.ServerName = dest.Address.Domain()
  122. }
  123. tlsConn := tls.Client(iConn, config)
  124. iConn = tlsConn
  125. }
  126. }
  127. return iConn, nil
  128. }
  129. func init() {
  130. common.Must(internet.RegisterTransportDialer(internet.TransportProtocol_MKCP, DialKCP))
  131. }