sess.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563
  1. package kcp
  2. import (
  3. crand "crypto/rand"
  4. "encoding/binary"
  5. "errors"
  6. "hash/crc32"
  7. "io"
  8. "log"
  9. "math/rand"
  10. "net"
  11. "sync"
  12. "time"
  13. "golang.org/x/net/ipv4"
  14. )
  // Sentinel errors returned by UDPSession I/O methods.
  var (
  	// errTimeout is returned by Read/Write once the relevant deadline has passed.
  	errTimeout = errors.New("i/o timeout")
  	// errBrokenPipe is returned when the session has already been closed.
  	errBrokenPipe = errors.New("broken pipe")
  )

  const (
  	basePort = 20000 // minimum port for listening
  	maxPort  = 65535 // maximum port for listening

  	defaultWndSize = 128 // default send/receive window size, in packets

  	// Crypto header layout, prepended to every KCP segment when a cipher is set:
  	// [ otp (otpSize bytes) | crc32 (crcSize bytes) | KCP data ... ]
  	otpSize         = 16
  	crcSize         = 4
  	cryptHeaderSize = otpSize + crcSize

  	connTimeout  = 60 * time.Second // silence longer than this closes the session on next input
  	mtuLimit     = 4096             // size of each receive buffer
  	rxQueueLimit = 8192             // capacity of the packet / output queues
  	rxFecLimit   = 2048
  )
  31. type (
  32. // UDPSession defines a KCP session implemented by UDP
  33. UDPSession struct {
  34. kcp *KCP // the core ARQ
  35. conn *net.UDPConn // the underlying UDP socket
  36. block BlockCrypt
  37. needUpdate bool
  38. l *Listener // point to server listener if it's a server socket
  39. local, remote net.Addr
  40. rd time.Time // read deadline
  41. wd time.Time // write deadline
  42. sockbuff []byte // kcp receiving is based on packet, I turn it into stream
  43. die chan struct{}
  44. isClosed bool
  45. mu sync.Mutex
  46. chReadEvent chan struct{}
  47. chWriteEvent chan struct{}
  48. chTicker chan time.Time
  49. chUDPOutput chan []byte
  50. headerSize int
  51. lastInputTs time.Time
  52. ackNoDelay bool
  53. }
  54. )
  55. // newUDPSession create a new udp session for client or server
  56. func newUDPSession(conv uint32, l *Listener, conn *net.UDPConn, remote *net.UDPAddr, block BlockCrypt) *UDPSession {
  57. sess := new(UDPSession)
  58. sess.chTicker = make(chan time.Time, 1)
  59. sess.chUDPOutput = make(chan []byte, rxQueueLimit)
  60. sess.die = make(chan struct{})
  61. sess.local = conn.LocalAddr()
  62. sess.chReadEvent = make(chan struct{}, 1)
  63. sess.chWriteEvent = make(chan struct{}, 1)
  64. sess.remote = remote
  65. sess.conn = conn
  66. sess.l = l
  67. sess.block = block
  68. sess.lastInputTs = time.Now()
  69. // caculate header size
  70. if sess.block != nil {
  71. sess.headerSize += cryptHeaderSize
  72. }
  73. sess.kcp = NewKCP(conv, func(buf []byte, size int) {
  74. if size >= IKCP_OVERHEAD {
  75. ext := make([]byte, sess.headerSize+size)
  76. copy(ext[sess.headerSize:], buf)
  77. sess.chUDPOutput <- ext
  78. }
  79. })
  80. sess.kcp.WndSize(defaultWndSize, defaultWndSize)
  81. sess.kcp.SetMtu(IKCP_MTU_DEF - sess.headerSize)
  82. go sess.updateTask()
  83. go sess.outputTask()
  84. if l == nil { // it's a client connection
  85. go sess.readLoop()
  86. }
  87. return sess
  88. }
  89. // Read implements the Conn Read method.
  90. func (s *UDPSession) Read(b []byte) (n int, err error) {
  91. for {
  92. s.mu.Lock()
  93. if len(s.sockbuff) > 0 { // copy from buffer
  94. n = copy(b, s.sockbuff)
  95. s.sockbuff = s.sockbuff[n:]
  96. s.mu.Unlock()
  97. return n, nil
  98. }
  99. if s.isClosed {
  100. s.mu.Unlock()
  101. return 0, errBrokenPipe
  102. }
  103. if !s.rd.IsZero() {
  104. if time.Now().After(s.rd) { // timeout
  105. s.mu.Unlock()
  106. return 0, errTimeout
  107. }
  108. }
  109. if n := s.kcp.PeekSize(); n > 0 { // data arrived
  110. if len(b) >= n {
  111. s.kcp.Recv(b)
  112. } else {
  113. buf := make([]byte, n)
  114. s.kcp.Recv(buf)
  115. n = copy(b, buf)
  116. s.sockbuff = buf[n:] // store remaining bytes into sockbuff for next read
  117. }
  118. s.mu.Unlock()
  119. return n, nil
  120. }
  121. var timeout <-chan time.Time
  122. if !s.rd.IsZero() {
  123. delay := s.rd.Sub(time.Now())
  124. timeout = time.After(delay)
  125. }
  126. s.mu.Unlock()
  127. // wait for read event or timeout
  128. select {
  129. case <-s.chReadEvent:
  130. case <-timeout:
  131. case <-s.die:
  132. }
  133. }
  134. }
  135. // Write implements the Conn Write method.
  136. func (s *UDPSession) Write(b []byte) (n int, err error) {
  137. for {
  138. s.mu.Lock()
  139. if s.isClosed {
  140. s.mu.Unlock()
  141. return 0, errBrokenPipe
  142. }
  143. if !s.wd.IsZero() {
  144. if time.Now().After(s.wd) { // timeout
  145. s.mu.Unlock()
  146. return 0, errTimeout
  147. }
  148. }
  149. if s.kcp.WaitSnd() < int(s.kcp.snd_wnd) {
  150. n = len(b)
  151. max := s.kcp.mss << 8
  152. for {
  153. if len(b) <= int(max) { // in most cases
  154. s.kcp.Send(b)
  155. break
  156. } else {
  157. s.kcp.Send(b[:max])
  158. b = b[max:]
  159. }
  160. }
  161. s.kcp.current = currentMs()
  162. s.kcp.flush()
  163. s.mu.Unlock()
  164. return n, nil
  165. }
  166. var timeout <-chan time.Time
  167. if !s.wd.IsZero() {
  168. delay := s.wd.Sub(time.Now())
  169. timeout = time.After(delay)
  170. }
  171. s.mu.Unlock()
  172. // wait for write event or timeout
  173. select {
  174. case <-s.chWriteEvent:
  175. case <-timeout:
  176. case <-s.die:
  177. }
  178. }
  179. }
  180. // Close closes the connection.
  181. func (s *UDPSession) Close() error {
  182. s.mu.Lock()
  183. defer s.mu.Unlock()
  184. if s.isClosed {
  185. return errBrokenPipe
  186. }
  187. close(s.die)
  188. s.isClosed = true
  189. if s.l == nil { // client socket close
  190. s.conn.Close()
  191. }
  192. return nil
  193. }
  194. // LocalAddr returns the local network address. The Addr returned is shared by all invocations of LocalAddr, so do not modify it.
  195. func (s *UDPSession) LocalAddr() net.Addr {
  196. return s.local
  197. }
  198. // RemoteAddr returns the remote network address. The Addr returned is shared by all invocations of RemoteAddr, so do not modify it.
  199. func (s *UDPSession) RemoteAddr() net.Addr { return s.remote }
  200. // SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline.
  201. func (s *UDPSession) SetDeadline(t time.Time) error {
  202. s.mu.Lock()
  203. defer s.mu.Unlock()
  204. s.rd = t
  205. s.wd = t
  206. return nil
  207. }
  208. // SetReadDeadline implements the Conn SetReadDeadline method.
  209. func (s *UDPSession) SetReadDeadline(t time.Time) error {
  210. s.mu.Lock()
  211. defer s.mu.Unlock()
  212. s.rd = t
  213. return nil
  214. }
  215. // SetWriteDeadline implements the Conn SetWriteDeadline method.
  216. func (s *UDPSession) SetWriteDeadline(t time.Time) error {
  217. s.mu.Lock()
  218. defer s.mu.Unlock()
  219. s.wd = t
  220. return nil
  221. }
  222. // SetWindowSize set maximum window size
  223. func (s *UDPSession) SetWindowSize(sndwnd, rcvwnd int) {
  224. s.mu.Lock()
  225. defer s.mu.Unlock()
  226. s.kcp.WndSize(sndwnd, rcvwnd)
  227. }
  228. // SetMtu sets the maximum transmission unit
  229. func (s *UDPSession) SetMtu(mtu int) {
  230. s.mu.Lock()
  231. defer s.mu.Unlock()
  232. s.kcp.SetMtu(mtu - s.headerSize)
  233. }
  234. // SetACKNoDelay changes ack flush option, set true to flush ack immediately,
  235. func (s *UDPSession) SetACKNoDelay(nodelay bool) {
  236. s.mu.Lock()
  237. defer s.mu.Unlock()
  238. s.ackNoDelay = nodelay
  239. }
  240. // SetNoDelay calls nodelay() of kcp
  241. func (s *UDPSession) SetNoDelay(nodelay, interval, resend, nc int) {
  242. s.mu.Lock()
  243. defer s.mu.Unlock()
  244. s.kcp.NoDelay(nodelay, interval, resend, nc)
  245. }
  246. // SetDSCP sets the DSCP field of IP header
  247. func (s *UDPSession) SetDSCP(tos int) {
  248. s.mu.Lock()
  249. defer s.mu.Unlock()
  250. if err := ipv4.NewConn(s.conn).SetTOS(tos << 2); err != nil {
  251. log.Println("set tos:", err)
  252. }
  253. }
  254. func (s *UDPSession) outputTask() {
  255. // ping
  256. ticker := time.NewTicker(5 * time.Second)
  257. defer ticker.Stop()
  258. for {
  259. select {
  260. case ext := <-s.chUDPOutput:
  261. if s.block != nil {
  262. io.ReadFull(crand.Reader, ext[:otpSize]) // OTP
  263. checksum := crc32.ChecksumIEEE(ext[cryptHeaderSize:])
  264. binary.LittleEndian.PutUint32(ext[otpSize:], checksum)
  265. s.block.Encrypt(ext, ext)
  266. }
  267. //if rand.Intn(100) < 80 {
  268. n, err := s.conn.WriteTo(ext, s.remote)
  269. if err != nil {
  270. log.Println(err, n)
  271. }
  272. //}
  273. case <-ticker.C:
  274. sz := rand.Intn(IKCP_MTU_DEF - s.headerSize - IKCP_OVERHEAD)
  275. sz += s.headerSize + IKCP_OVERHEAD
  276. ping := make([]byte, sz)
  277. io.ReadFull(crand.Reader, ping)
  278. if s.block != nil {
  279. checksum := crc32.ChecksumIEEE(ping[cryptHeaderSize:])
  280. binary.LittleEndian.PutUint32(ping[otpSize:], checksum)
  281. s.block.Encrypt(ping, ping)
  282. }
  283. n, err := s.conn.WriteTo(ping, s.remote)
  284. if err != nil {
  285. log.Println(err, n)
  286. }
  287. case <-s.die:
  288. return
  289. }
  290. }
  291. }
  292. // kcp update, input loop
  293. func (s *UDPSession) updateTask() {
  294. var tc <-chan time.Time
  295. if s.l == nil { // client
  296. ticker := time.NewTicker(10 * time.Millisecond)
  297. tc = ticker.C
  298. defer ticker.Stop()
  299. } else {
  300. tc = s.chTicker
  301. }
  302. var nextupdate uint32
  303. for {
  304. select {
  305. case <-tc:
  306. s.mu.Lock()
  307. current := currentMs()
  308. if current >= nextupdate || s.needUpdate {
  309. s.kcp.Update(current)
  310. nextupdate = s.kcp.Check(current)
  311. }
  312. if s.kcp.WaitSnd() < int(s.kcp.snd_wnd) {
  313. s.notifyWriteEvent()
  314. }
  315. s.needUpdate = false
  316. s.mu.Unlock()
  317. case <-s.die:
  318. if s.l != nil { // has listener
  319. s.l.chDeadlinks <- s.remote
  320. }
  321. return
  322. }
  323. }
  324. }
  325. // GetConv gets conversation id of a session
  326. func (s *UDPSession) GetConv() uint32 {
  327. return s.kcp.conv
  328. }
  329. func (s *UDPSession) notifyReadEvent() {
  330. select {
  331. case s.chReadEvent <- struct{}{}:
  332. default:
  333. }
  334. }
  335. func (s *UDPSession) notifyWriteEvent() {
  336. select {
  337. case s.chWriteEvent <- struct{}{}:
  338. default:
  339. }
  340. }
  341. func (s *UDPSession) kcpInput(data []byte) {
  342. now := time.Now()
  343. if now.Sub(s.lastInputTs) > connTimeout {
  344. s.Close()
  345. return
  346. }
  347. s.lastInputTs = now
  348. s.mu.Lock()
  349. s.kcp.current = currentMs()
  350. s.kcp.Input(data)
  351. if s.ackNoDelay {
  352. s.kcp.current = currentMs()
  353. s.kcp.flush()
  354. } else {
  355. s.needUpdate = true
  356. }
  357. s.mu.Unlock()
  358. s.notifyReadEvent()
  359. }
  360. func (s *UDPSession) receiver(ch chan []byte) {
  361. for {
  362. data := make([]byte, mtuLimit)
  363. if n, _, err := s.conn.ReadFromUDP(data); err == nil && n >= s.headerSize+IKCP_OVERHEAD {
  364. ch <- data[:n]
  365. } else if err != nil {
  366. return
  367. }
  368. }
  369. }
  370. // read loop for client session
  371. func (s *UDPSession) readLoop() {
  372. chPacket := make(chan []byte, rxQueueLimit)
  373. go s.receiver(chPacket)
  374. for {
  375. select {
  376. case data := <-chPacket:
  377. dataValid := false
  378. if s.block != nil {
  379. s.block.Decrypt(data, data)
  380. data = data[otpSize:]
  381. checksum := crc32.ChecksumIEEE(data[crcSize:])
  382. if checksum == binary.LittleEndian.Uint32(data) {
  383. data = data[crcSize:]
  384. dataValid = true
  385. }
  386. } else if s.block == nil {
  387. dataValid = true
  388. }
  389. if dataValid {
  390. s.kcpInput(data)
  391. }
  392. case <-s.die:
  393. return
  394. }
  395. }
  396. }
  397. type (
  398. // Listener defines a server listening for connections
  399. Listener struct {
  400. block BlockCrypt
  401. conn *net.UDPConn
  402. sessions map[string]*UDPSession
  403. chAccepts chan *UDPSession
  404. chDeadlinks chan net.Addr
  405. headerSize int
  406. die chan struct{}
  407. }
  408. packet struct {
  409. from *net.UDPAddr
  410. data []byte
  411. }
  412. )
  413. // monitor incoming data for all connections of server
  414. func (l *Listener) monitor() {
  415. chPacket := make(chan packet, rxQueueLimit)
  416. go l.receiver(chPacket)
  417. ticker := time.NewTicker(10 * time.Millisecond)
  418. defer ticker.Stop()
  419. for {
  420. select {
  421. case p := <-chPacket:
  422. data := p.data
  423. from := p.from
  424. dataValid := false
  425. if l.block != nil {
  426. l.block.Decrypt(data, data)
  427. data = data[otpSize:]
  428. checksum := crc32.ChecksumIEEE(data[crcSize:])
  429. if checksum == binary.LittleEndian.Uint32(data) {
  430. data = data[crcSize:]
  431. dataValid = true
  432. }
  433. } else if l.block == nil {
  434. dataValid = true
  435. }
  436. if dataValid {
  437. addr := from.String()
  438. s, ok := l.sessions[addr]
  439. if !ok { // new session
  440. var conv uint32
  441. convValid := false
  442. conv = binary.LittleEndian.Uint32(data)
  443. convValid = true
  444. if convValid {
  445. s := newUDPSession(conv, l, l.conn, from, l.block)
  446. s.kcpInput(data)
  447. l.sessions[addr] = s
  448. l.chAccepts <- s
  449. }
  450. } else {
  451. s.kcpInput(data)
  452. }
  453. }
  454. case deadlink := <-l.chDeadlinks:
  455. delete(l.sessions, deadlink.String())
  456. case <-l.die:
  457. return
  458. case <-ticker.C:
  459. now := time.Now()
  460. for _, s := range l.sessions {
  461. select {
  462. case s.chTicker <- now:
  463. default:
  464. }
  465. }
  466. }
  467. }
  468. }
  469. func (l *Listener) receiver(ch chan packet) {
  470. for {
  471. data := make([]byte, mtuLimit)
  472. if n, from, err := l.conn.ReadFromUDP(data); err == nil && n >= l.headerSize+IKCP_OVERHEAD {
  473. ch <- packet{from, data[:n]}
  474. } else if err != nil {
  475. return
  476. }
  477. }
  478. }
  479. // Accept implements the Accept method in the Listener interface; it waits for the next call and returns a generic Conn.
  480. func (l *Listener) Accept() (*UDPSession, error) {
  481. select {
  482. case c := <-l.chAccepts:
  483. return c, nil
  484. case <-l.die:
  485. return nil, errors.New("listener stopped")
  486. }
  487. }
  488. // Close stops listening on the UDP address. Already Accepted connections are not closed.
  489. func (l *Listener) Close() error {
  490. if err := l.conn.Close(); err == nil {
  491. close(l.die)
  492. return nil
  493. } else {
  494. return err
  495. }
  496. }
  497. // Addr returns the listener's network address, The Addr returned is shared by all invocations of Addr, so do not modify it.
  498. func (l *Listener) Addr() net.Addr {
  499. return l.conn.LocalAddr()
  500. }
  // currentMs returns the wall-clock time in milliseconds since the Unix
  // epoch, truncated to 32 bits, so the value wraps around about every
  // 49.7 days.
  func currentMs() uint32 {
  	ms := time.Now().UnixNano() / 1e6
  	return uint32(ms)
  }