connection.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538
  1. package kcp
  2. import (
  3. "errors"
  4. "io"
  5. "net"
  6. "sync"
  7. "sync/atomic"
  8. "time"
  9. "github.com/v2ray/v2ray-core/common/alloc"
  10. "github.com/v2ray/v2ray-core/common/log"
  11. "github.com/v2ray/v2ray-core/transport/internet"
  12. )
  13. var (
  14. ErrIOTimeout = errors.New("Read/Write timeout")
  15. ErrClosedListener = errors.New("Listener closed.")
  16. ErrClosedConnection = errors.New("Connection closed.")
  17. )
  18. type State int32
  19. func (this State) Is(states ...State) bool {
  20. for _, state := range states {
  21. if this == state {
  22. return true
  23. }
  24. }
  25. return false
  26. }
  27. const (
  28. StateActive State = 0
  29. StateReadyToClose State = 1
  30. StatePeerClosed State = 2
  31. StateTerminating State = 3
  32. StatePeerTerminating State = 4
  33. StateTerminated State = 5
  34. )
  35. const (
  36. headerSize uint32 = 2
  37. )
  38. func nowMillisec() int64 {
  39. now := time.Now()
  40. return now.Unix()*1000 + int64(now.Nanosecond()/1000000)
  41. }
  42. type RoundTripInfo struct {
  43. sync.RWMutex
  44. variation uint32
  45. srtt uint32
  46. rto uint32
  47. minRtt uint32
  48. updatedTimestamp uint32
  49. }
  50. func (this *RoundTripInfo) UpdatePeerRTO(rto uint32, current uint32) {
  51. this.Lock()
  52. defer this.Unlock()
  53. if current-this.updatedTimestamp < 3000 {
  54. return
  55. }
  56. this.updatedTimestamp = current
  57. this.rto = rto
  58. }
  59. func (this *RoundTripInfo) Update(rtt uint32, current uint32) {
  60. if rtt > 0x7FFFFFFF {
  61. return
  62. }
  63. this.Lock()
  64. defer this.Unlock()
  65. // https://tools.ietf.org/html/rfc6298
  66. if this.srtt == 0 {
  67. this.srtt = rtt
  68. this.variation = rtt / 2
  69. } else {
  70. delta := rtt - this.srtt
  71. if this.srtt > rtt {
  72. delta = this.srtt - rtt
  73. }
  74. this.variation = (3*this.variation + delta) / 4
  75. this.srtt = (7*this.srtt + rtt) / 8
  76. if this.srtt < this.minRtt {
  77. this.srtt = this.minRtt
  78. }
  79. }
  80. var rto uint32
  81. if this.minRtt < 4*this.variation {
  82. rto = this.srtt + 4*this.variation
  83. } else {
  84. rto = this.srtt + this.variation
  85. }
  86. if rto > 10000 {
  87. rto = 10000
  88. }
  89. this.rto = rto * 3 / 2
  90. this.updatedTimestamp = current
  91. }
  92. func (this *RoundTripInfo) Timeout() uint32 {
  93. this.RLock()
  94. defer this.RUnlock()
  95. return this.rto
  96. }
  97. func (this *RoundTripInfo) SmoothedTime() uint32 {
  98. this.RLock()
  99. defer this.RUnlock()
  100. return this.srtt
  101. }
  102. // Connection is a KCP connection over UDP.
  103. type Connection struct {
  104. block internet.Authenticator
  105. local, remote net.Addr
  106. rd time.Time
  107. wd time.Time // write deadline
  108. writer io.WriteCloser
  109. since int64
  110. dataInputCond *sync.Cond
  111. dataOutputCond *sync.Cond
  112. conv uint16
  113. state State
  114. stateBeginTime uint32
  115. lastIncomingTime uint32
  116. lastPingTime uint32
  117. mss uint32
  118. roundTrip *RoundTripInfo
  119. interval uint32
  120. receivingWorker *ReceivingWorker
  121. sendingWorker *SendingWorker
  122. fastresend uint32
  123. congestionControl bool
  124. output *BufferedSegmentWriter
  125. }
  126. // NewConnection create a new KCP connection between local and remote.
  127. func NewConnection(conv uint16, writerCloser io.WriteCloser, local *net.UDPAddr, remote *net.UDPAddr, block internet.Authenticator) *Connection {
  128. log.Info("KCP|Connection: creating connection ", conv)
  129. conn := new(Connection)
  130. conn.local = local
  131. conn.remote = remote
  132. conn.block = block
  133. conn.writer = writerCloser
  134. conn.since = nowMillisec()
  135. conn.dataInputCond = sync.NewCond(new(sync.Mutex))
  136. conn.dataOutputCond = sync.NewCond(new(sync.Mutex))
  137. authWriter := &AuthenticationWriter{
  138. Authenticator: block,
  139. Writer: writerCloser,
  140. }
  141. conn.conv = conv
  142. conn.output = NewSegmentWriter(authWriter)
  143. conn.mss = authWriter.Mtu() - DataSegmentOverhead
  144. conn.roundTrip = &RoundTripInfo{
  145. rto: 100,
  146. minRtt: effectiveConfig.Tti,
  147. }
  148. conn.interval = effectiveConfig.Tti
  149. conn.receivingWorker = NewReceivingWorker(conn)
  150. conn.fastresend = 2
  151. conn.congestionControl = effectiveConfig.Congestion
  152. conn.sendingWorker = NewSendingWorker(conn)
  153. go conn.updateTask()
  154. return conn
  155. }
  156. func (this *Connection) Elapsed() uint32 {
  157. return uint32(nowMillisec() - this.since)
  158. }
  159. // Read implements the Conn Read method.
  160. func (this *Connection) Read(b []byte) (int, error) {
  161. if this == nil {
  162. return 0, io.EOF
  163. }
  164. for {
  165. if this.State().Is(StateReadyToClose, StateTerminating, StateTerminated) {
  166. return 0, io.EOF
  167. }
  168. nBytes := this.receivingWorker.Read(b)
  169. if nBytes > 0 {
  170. return nBytes, nil
  171. }
  172. if this.State() == StatePeerTerminating {
  173. return 0, io.EOF
  174. }
  175. var timer *time.Timer
  176. if !this.rd.IsZero() {
  177. duration := this.rd.Sub(time.Now())
  178. if duration <= 0 {
  179. return 0, ErrIOTimeout
  180. }
  181. timer = time.AfterFunc(duration, this.dataInputCond.Signal)
  182. }
  183. this.dataInputCond.L.Lock()
  184. this.dataInputCond.Wait()
  185. this.dataInputCond.L.Unlock()
  186. if timer != nil {
  187. timer.Stop()
  188. }
  189. if !this.rd.IsZero() && this.rd.Before(time.Now()) {
  190. return 0, ErrIOTimeout
  191. }
  192. }
  193. }
  194. // Write implements the Conn Write method.
  195. func (this *Connection) Write(b []byte) (int, error) {
  196. totalWritten := 0
  197. for {
  198. if this == nil || this.State() != StateActive {
  199. return totalWritten, io.ErrClosedPipe
  200. }
  201. nBytes := this.sendingWorker.Push(b[totalWritten:])
  202. if nBytes > 0 {
  203. totalWritten += nBytes
  204. if totalWritten == len(b) {
  205. return totalWritten, nil
  206. }
  207. }
  208. var timer *time.Timer
  209. if !this.wd.IsZero() {
  210. duration := this.wd.Sub(time.Now())
  211. if duration <= 0 {
  212. return totalWritten, ErrIOTimeout
  213. }
  214. timer = time.AfterFunc(duration, this.dataOutputCond.Signal)
  215. }
  216. this.dataOutputCond.L.Lock()
  217. this.dataOutputCond.Wait()
  218. this.dataOutputCond.L.Unlock()
  219. if timer != nil {
  220. timer.Stop()
  221. }
  222. if !this.wd.IsZero() && this.wd.Before(time.Now()) {
  223. return totalWritten, ErrIOTimeout
  224. }
  225. }
  226. }
  227. func (this *Connection) SetState(state State) {
  228. current := this.Elapsed()
  229. atomic.StoreInt32((*int32)(&this.state), int32(state))
  230. atomic.StoreUint32(&this.stateBeginTime, current)
  231. log.Debug("KCP|Connection: #", this.conv, " entering state ", state, " at ", current)
  232. switch state {
  233. case StateReadyToClose:
  234. this.receivingWorker.CloseRead()
  235. case StatePeerClosed:
  236. this.sendingWorker.CloseWrite()
  237. case StateTerminating:
  238. this.receivingWorker.CloseRead()
  239. this.sendingWorker.CloseWrite()
  240. case StatePeerTerminating:
  241. this.sendingWorker.CloseWrite()
  242. case StateTerminated:
  243. this.receivingWorker.CloseRead()
  244. this.sendingWorker.CloseWrite()
  245. }
  246. }
  247. // Close closes the connection.
  248. func (this *Connection) Close() error {
  249. if this == nil {
  250. return ErrClosedConnection
  251. }
  252. this.dataInputCond.Broadcast()
  253. this.dataOutputCond.Broadcast()
  254. state := this.State()
  255. if state.Is(StateReadyToClose, StateTerminating, StateTerminated) {
  256. return ErrClosedConnection
  257. }
  258. log.Info("KCP|Connection: Closing connection to ", this.remote)
  259. if state == StateActive {
  260. this.SetState(StateReadyToClose)
  261. }
  262. if state == StatePeerClosed {
  263. this.SetState(StateTerminating)
  264. }
  265. if state == StatePeerTerminating {
  266. this.SetState(StateTerminated)
  267. }
  268. return nil
  269. }
  270. // LocalAddr returns the local network address. The Addr returned is shared by all invocations of LocalAddr, so do not modify it.
  271. func (this *Connection) LocalAddr() net.Addr {
  272. if this == nil {
  273. return nil
  274. }
  275. return this.local
  276. }
  277. // RemoteAddr returns the remote network address. The Addr returned is shared by all invocations of RemoteAddr, so do not modify it.
  278. func (this *Connection) RemoteAddr() net.Addr {
  279. if this == nil {
  280. return nil
  281. }
  282. return this.remote
  283. }
  284. // SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline.
  285. func (this *Connection) SetDeadline(t time.Time) error {
  286. if err := this.SetReadDeadline(t); err != nil {
  287. return err
  288. }
  289. if err := this.SetWriteDeadline(t); err != nil {
  290. return err
  291. }
  292. return nil
  293. }
  294. // SetReadDeadline implements the Conn SetReadDeadline method.
  295. func (this *Connection) SetReadDeadline(t time.Time) error {
  296. if this == nil || this.State() != StateActive {
  297. return ErrClosedConnection
  298. }
  299. this.rd = t
  300. return nil
  301. }
  302. // SetWriteDeadline implements the Conn SetWriteDeadline method.
  303. func (this *Connection) SetWriteDeadline(t time.Time) error {
  304. if this == nil || this.State() != StateActive {
  305. return ErrClosedConnection
  306. }
  307. this.wd = t
  308. return nil
  309. }
  310. // kcp update, input loop
  311. func (this *Connection) updateTask() {
  312. for this.State() != StateTerminated {
  313. this.flush()
  314. interval := time.Duration(effectiveConfig.Tti) * time.Millisecond
  315. if this.State() == StateTerminating {
  316. interval = time.Second
  317. }
  318. time.Sleep(interval)
  319. }
  320. this.Terminate()
  321. }
  322. func (this *Connection) FetchInputFrom(conn io.Reader) {
  323. go func() {
  324. payload := alloc.NewLocalBuffer(2048)
  325. defer payload.Release()
  326. for this.State() != StateTerminated {
  327. payload.Reset()
  328. nBytes, err := conn.Read(payload.Value)
  329. if err != nil {
  330. return
  331. }
  332. payload.Slice(0, nBytes)
  333. if this.block.Open(payload) {
  334. this.Input(payload.Value)
  335. }
  336. }
  337. }()
  338. }
  339. func (this *Connection) Reusable() bool {
  340. return false
  341. }
  342. func (this *Connection) SetReusable(b bool) {}
  343. func (this *Connection) Terminate() {
  344. if this == nil || this.writer == nil {
  345. return
  346. }
  347. log.Info("KCP|Connection: Terminating connection to ", this.RemoteAddr())
  348. this.SetState(StateTerminated)
  349. this.dataInputCond.Broadcast()
  350. this.dataOutputCond.Broadcast()
  351. this.writer.Close()
  352. }
  353. func (this *Connection) HandleOption(opt SegmentOption) {
  354. if (opt & SegmentOptionClose) == SegmentOptionClose {
  355. this.OnPeerClosed()
  356. }
  357. }
  358. func (this *Connection) OnPeerClosed() {
  359. state := this.State()
  360. if state == StateReadyToClose {
  361. this.SetState(StateTerminating)
  362. }
  363. if state == StateActive {
  364. this.SetState(StatePeerClosed)
  365. }
  366. }
  367. // Input when you received a low level packet (eg. UDP packet), call it
  368. func (this *Connection) Input(data []byte) int {
  369. current := this.Elapsed()
  370. atomic.StoreUint32(&this.lastIncomingTime, current)
  371. var seg Segment
  372. for {
  373. seg, data = ReadSegment(data)
  374. if seg == nil {
  375. break
  376. }
  377. switch seg := seg.(type) {
  378. case *DataSegment:
  379. this.HandleOption(seg.Option)
  380. this.receivingWorker.ProcessSegment(seg)
  381. this.dataInputCond.Signal()
  382. case *AckSegment:
  383. this.HandleOption(seg.Option)
  384. this.sendingWorker.ProcessSegment(current, seg)
  385. this.dataOutputCond.Signal()
  386. case *CmdOnlySegment:
  387. this.HandleOption(seg.Option)
  388. if seg.Command == CommandTerminate {
  389. state := this.State()
  390. if state == StateActive ||
  391. state == StatePeerClosed {
  392. this.SetState(StatePeerTerminating)
  393. } else if state == StateReadyToClose {
  394. this.SetState(StateTerminating)
  395. } else if state == StateTerminating {
  396. this.SetState(StateTerminated)
  397. }
  398. }
  399. this.sendingWorker.ProcessReceivingNext(seg.ReceivinNext)
  400. this.receivingWorker.ProcessSendingNext(seg.SendingNext)
  401. this.roundTrip.UpdatePeerRTO(seg.PeerRTO, current)
  402. seg.Release()
  403. default:
  404. }
  405. }
  406. return 0
  407. }
  408. func (this *Connection) flush() {
  409. current := this.Elapsed()
  410. if this.State() == StateTerminated {
  411. return
  412. }
  413. if this.State() == StateActive && current-atomic.LoadUint32(&this.lastIncomingTime) >= 30000 {
  414. this.Close()
  415. }
  416. if this.State() == StateReadyToClose && this.sendingWorker.IsEmpty() {
  417. this.SetState(StateTerminating)
  418. }
  419. if this.State() == StateTerminating {
  420. log.Debug("KCP|Connection: #", this.conv, " sending terminating cmd.")
  421. seg := NewCmdOnlySegment()
  422. defer seg.Release()
  423. seg.Conv = this.conv
  424. seg.Command = CommandTerminate
  425. this.output.Write(seg)
  426. this.output.Flush()
  427. if current-atomic.LoadUint32(&this.stateBeginTime) > 8000 {
  428. this.SetState(StateTerminated)
  429. }
  430. return
  431. }
  432. if this.State() == StatePeerTerminating && current-atomic.LoadUint32(&this.stateBeginTime) > 4000 {
  433. this.SetState(StateTerminating)
  434. }
  435. if this.State() == StateReadyToClose && current-atomic.LoadUint32(&this.stateBeginTime) > 15000 {
  436. this.SetState(StateTerminating)
  437. }
  438. // flush acknowledges
  439. this.receivingWorker.Flush(current)
  440. this.sendingWorker.Flush(current)
  441. if this.sendingWorker.PingNecessary() || this.receivingWorker.PingNecessary() || current-atomic.LoadUint32(&this.lastPingTime) >= 5000 {
  442. seg := NewCmdOnlySegment()
  443. seg.Conv = this.conv
  444. seg.Command = CommandPing
  445. seg.ReceivinNext = this.receivingWorker.nextNumber
  446. seg.SendingNext = this.sendingWorker.firstUnacknowledged
  447. seg.PeerRTO = this.roundTrip.Timeout()
  448. if this.State() == StateReadyToClose {
  449. seg.Option = SegmentOptionClose
  450. }
  451. this.output.Write(seg)
  452. this.lastPingTime = current
  453. this.sendingWorker.MarkPingNecessary(false)
  454. this.receivingWorker.MarkPingNecessary(false)
  455. seg.Release()
  456. }
  457. // flash remain segments
  458. this.output.Flush()
  459. }
  460. func (this *Connection) State() State {
  461. return State(atomic.LoadInt32((*int32)(&this.state)))
  462. }