conn_test.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696
  1. // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package websocket
  5. import (
  6. "bufio"
  7. "bytes"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "io/ioutil"
  12. "net"
  13. "reflect"
  14. "sync"
  15. "testing"
  16. "testing/iotest"
  17. "time"
  18. )
  19. var _ net.Error = errWriteTimeout
  20. type fakeNetConn struct {
  21. io.Reader
  22. io.Writer
  23. }
  24. func (c fakeNetConn) Close() error { return nil }
  25. func (c fakeNetConn) LocalAddr() net.Addr { return localAddr }
  26. func (c fakeNetConn) RemoteAddr() net.Addr { return remoteAddr }
  27. func (c fakeNetConn) SetDeadline(t time.Time) error { return nil }
  28. func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil }
  29. func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil }
  30. type fakeAddr int
  31. var (
  32. localAddr = fakeAddr(1)
  33. remoteAddr = fakeAddr(2)
  34. )
  35. func (a fakeAddr) Network() string {
  36. return "net"
  37. }
  38. func (a fakeAddr) String() string {
  39. return "str"
  40. }
  41. // newTestConn creates a connnection backed by a fake network connection using
  42. // default values for buffering.
  43. func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn {
  44. return newConn(fakeNetConn{Reader: r, Writer: w}, isServer, 1024, 1024, nil, nil, nil)
  45. }
  46. func TestFraming(t *testing.T) {
  47. frameSizes := []int{
  48. 0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535,
  49. // 65536, 65537
  50. }
  51. var readChunkers = []struct {
  52. name string
  53. f func(io.Reader) io.Reader
  54. }{
  55. {"half", iotest.HalfReader},
  56. {"one", iotest.OneByteReader},
  57. {"asis", func(r io.Reader) io.Reader { return r }},
  58. }
  59. writeBuf := make([]byte, 65537)
  60. for i := range writeBuf {
  61. writeBuf[i] = byte(i)
  62. }
  63. var writers = []struct {
  64. name string
  65. f func(w io.Writer, n int) (int, error)
  66. }{
  67. {"iocopy", func(w io.Writer, n int) (int, error) {
  68. nn, err := io.Copy(w, bytes.NewReader(writeBuf[:n]))
  69. return int(nn), err
  70. }},
  71. {"write", func(w io.Writer, n int) (int, error) {
  72. return w.Write(writeBuf[:n])
  73. }},
  74. {"string", func(w io.Writer, n int) (int, error) {
  75. return io.WriteString(w, string(writeBuf[:n]))
  76. }},
  77. }
  78. for _, compress := range []bool{false, true} {
  79. for _, isServer := range []bool{true, false} {
  80. for _, chunker := range readChunkers {
  81. var connBuf bytes.Buffer
  82. wc := newTestConn(nil, &connBuf, isServer)
  83. rc := newTestConn(chunker.f(&connBuf), nil, !isServer)
  84. //if compress {
  85. // wc.newCompressionWriter = compressNoContextTakeover
  86. // rc.newDecompressionReader = decompressNoContextTakeover
  87. //}
  88. for _, n := range frameSizes {
  89. for _, writer := range writers {
  90. name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name)
  91. w, err := wc.NextWriter(TextMessage)
  92. if err != nil {
  93. t.Errorf("%s: wc.NextWriter() returned %v", name, err)
  94. continue
  95. }
  96. nn, err := writer.f(w, n)
  97. if err != nil || nn != n {
  98. t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err)
  99. continue
  100. }
  101. err = w.Close()
  102. if err != nil {
  103. t.Errorf("%s: w.Close() returned %v", name, err)
  104. continue
  105. }
  106. opCode, r, err := rc.NextReader()
  107. if err != nil || opCode != TextMessage {
  108. t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
  109. continue
  110. }
  111. t.Logf("frame size: %d", n)
  112. rbuf, err := ioutil.ReadAll(r)
  113. if err != nil {
  114. t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
  115. continue
  116. }
  117. if len(rbuf) != n {
  118. t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n)
  119. continue
  120. }
  121. for i, b := range rbuf {
  122. if byte(i) != b {
  123. t.Errorf("%s: bad byte at offset %d", name, i)
  124. break
  125. }
  126. }
  127. }
  128. }
  129. }
  130. }
  131. }
  132. }
  133. func TestControl(t *testing.T) {
  134. const message = "this is a ping/pong messsage"
  135. for _, isServer := range []bool{true, false} {
  136. for _, isWriteControl := range []bool{true, false} {
  137. name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl)
  138. var connBuf bytes.Buffer
  139. wc := newTestConn(nil, &connBuf, isServer)
  140. rc := newTestConn(&connBuf, nil, !isServer)
  141. if isWriteControl {
  142. wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
  143. } else {
  144. w, err := wc.NextWriter(PongMessage)
  145. if err != nil {
  146. t.Errorf("%s: wc.NextWriter() returned %v", name, err)
  147. continue
  148. }
  149. if _, err := w.Write([]byte(message)); err != nil {
  150. t.Errorf("%s: w.Write() returned %v", name, err)
  151. continue
  152. }
  153. if err := w.Close(); err != nil {
  154. t.Errorf("%s: w.Close() returned %v", name, err)
  155. continue
  156. }
  157. var actualMessage string
  158. rc.SetPongHandler(func(s string) error { actualMessage = s; return nil })
  159. rc.NextReader()
  160. if actualMessage != message {
  161. t.Errorf("%s: pong=%q, want %q", name, actualMessage, message)
  162. continue
  163. }
  164. }
  165. }
  166. }
  167. }
  168. // simpleBufferPool is an implementation of BufferPool for TestWriteBufferPool.
  169. type simpleBufferPool struct {
  170. v interface{}
  171. }
  172. func (p *simpleBufferPool) Get() interface{} {
  173. v := p.v
  174. p.v = nil
  175. return v
  176. }
  177. func (p *simpleBufferPool) Put(v interface{}) {
  178. p.v = v
  179. }
  180. func TestWriteBufferPool(t *testing.T) {
  181. const message = "Now is the time for all good people to come to the aid of the party."
  182. var buf bytes.Buffer
  183. var pool simpleBufferPool
  184. rc := newTestConn(&buf, nil, false)
  185. // Specify writeBufferSize smaller than message size to ensure that pooling
  186. // works with fragmented messages.
  187. wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, len(message)-1, &pool, nil, nil)
  188. if wc.writeBuf != nil {
  189. t.Fatal("writeBuf not nil after create")
  190. }
  191. // Part 1: test NextWriter/Write/Close
  192. w, err := wc.NextWriter(TextMessage)
  193. if err != nil {
  194. t.Fatalf("wc.NextWriter() returned %v", err)
  195. }
  196. if wc.writeBuf == nil {
  197. t.Fatal("writeBuf is nil after NextWriter")
  198. }
  199. writeBufAddr := &wc.writeBuf[0]
  200. if _, err := io.WriteString(w, message); err != nil {
  201. t.Fatalf("io.WriteString(w, message) returned %v", err)
  202. }
  203. if err := w.Close(); err != nil {
  204. t.Fatalf("w.Close() returned %v", err)
  205. }
  206. if wc.writeBuf != nil {
  207. t.Fatal("writeBuf not nil after w.Close()")
  208. }
  209. if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
  210. t.Fatal("writeBuf not returned to pool")
  211. }
  212. opCode, p, err := rc.ReadMessage()
  213. if opCode != TextMessage || err != nil {
  214. t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
  215. }
  216. if s := string(p); s != message {
  217. t.Fatalf("message is %s, want %s", s, message)
  218. }
  219. // Part 2: Test WriteMessage.
  220. if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil {
  221. t.Fatalf("wc.WriteMessage() returned %v", err)
  222. }
  223. if wc.writeBuf != nil {
  224. t.Fatal("writeBuf not nil after wc.WriteMessage()")
  225. }
  226. if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
  227. t.Fatal("writeBuf not returned to pool after WriteMessage")
  228. }
  229. opCode, p, err = rc.ReadMessage()
  230. if opCode != TextMessage || err != nil {
  231. t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
  232. }
  233. if s := string(p); s != message {
  234. t.Fatalf("message is %s, want %s", s, message)
  235. }
  236. }
  237. // TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool.
  238. func TestWriteBufferPoolSync(t *testing.T) {
  239. var buf bytes.Buffer
  240. var pool sync.Pool
  241. wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil)
  242. rc := newTestConn(&buf, nil, false)
  243. const message = "Hello World!"
  244. for i := 0; i < 3; i++ {
  245. if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil {
  246. t.Fatalf("wc.WriteMessage() returned %v", err)
  247. }
  248. opCode, p, err := rc.ReadMessage()
  249. if opCode != TextMessage || err != nil {
  250. t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
  251. }
  252. if s := string(p); s != message {
  253. t.Fatalf("message is %s, want %s", s, message)
  254. }
  255. }
  256. }
  257. // errorWriter is an io.Writer than returns an error on all writes.
  258. type errorWriter struct{}
  259. func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("error") }
  260. // TestWriteBufferPoolError ensures that buffer is returned to pool after error
  261. // on write.
  262. func TestWriteBufferPoolError(t *testing.T) {
  263. // Part 1: Test NextWriter/Write/Close
  264. var pool simpleBufferPool
  265. wc := newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)
  266. w, err := wc.NextWriter(TextMessage)
  267. if err != nil {
  268. t.Fatalf("wc.NextWriter() returned %v", err)
  269. }
  270. if wc.writeBuf == nil {
  271. t.Fatal("writeBuf is nil after NextWriter")
  272. }
  273. writeBufAddr := &wc.writeBuf[0]
  274. if _, err := io.WriteString(w, "Hello"); err != nil {
  275. t.Fatalf("io.WriteString(w, message) returned %v", err)
  276. }
  277. if err := w.Close(); err == nil {
  278. t.Fatalf("w.Close() did not return error")
  279. }
  280. if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
  281. t.Fatal("writeBuf not returned to pool")
  282. }
  283. // Part 2: Test WriteMessage
  284. wc = newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)
  285. if err := wc.WriteMessage(TextMessage, []byte("Hello")); err == nil {
  286. t.Fatalf("wc.WriteMessage did not return error")
  287. }
  288. if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
  289. t.Fatal("writeBuf not returned to pool")
  290. }
  291. }
  292. func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
  293. const bufSize = 512
  294. expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
  295. var b1, b2 bytes.Buffer
  296. wc := newConn(&fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
  297. rc := newTestConn(&b1, &b2, true)
  298. w, _ := wc.NextWriter(BinaryMessage)
  299. w.Write(make([]byte, bufSize+bufSize/2))
  300. wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second))
  301. w.Close()
  302. op, r, err := rc.NextReader()
  303. if op != BinaryMessage || err != nil {
  304. t.Fatalf("NextReader() returned %d, %v", op, err)
  305. }
  306. _, err = io.Copy(ioutil.Discard, r)
  307. if !reflect.DeepEqual(err, expectedErr) {
  308. t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
  309. }
  310. _, _, err = rc.NextReader()
  311. if !reflect.DeepEqual(err, expectedErr) {
  312. t.Fatalf("NextReader() returned %v, want %v", err, expectedErr)
  313. }
  314. }
  315. func TestEOFWithinFrame(t *testing.T) {
  316. const bufSize = 64
  317. for n := 0; ; n++ {
  318. var b bytes.Buffer
  319. wc := newTestConn(nil, &b, false)
  320. rc := newTestConn(&b, nil, true)
  321. w, _ := wc.NextWriter(BinaryMessage)
  322. w.Write(make([]byte, bufSize))
  323. w.Close()
  324. if n >= b.Len() {
  325. break
  326. }
  327. b.Truncate(n)
  328. op, r, err := rc.NextReader()
  329. if err == errUnexpectedEOF {
  330. continue
  331. }
  332. if op != BinaryMessage || err != nil {
  333. t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
  334. }
  335. _, err = io.Copy(ioutil.Discard, r)
  336. if err != errUnexpectedEOF {
  337. t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
  338. }
  339. _, _, err = rc.NextReader()
  340. if err != errUnexpectedEOF {
  341. t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF)
  342. }
  343. }
  344. }
  345. func TestEOFBeforeFinalFrame(t *testing.T) {
  346. const bufSize = 512
  347. var b1, b2 bytes.Buffer
  348. wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
  349. rc := newTestConn(&b1, &b2, true)
  350. w, _ := wc.NextWriter(BinaryMessage)
  351. w.Write(make([]byte, bufSize+bufSize/2))
  352. op, r, err := rc.NextReader()
  353. if op != BinaryMessage || err != nil {
  354. t.Fatalf("NextReader() returned %d, %v", op, err)
  355. }
  356. _, err = io.Copy(ioutil.Discard, r)
  357. if err != errUnexpectedEOF {
  358. t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
  359. }
  360. _, _, err = rc.NextReader()
  361. if err != errUnexpectedEOF {
  362. t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF)
  363. }
  364. }
  365. func TestWriteAfterMessageWriterClose(t *testing.T) {
  366. wc := newTestConn(nil, &bytes.Buffer{}, false)
  367. w, _ := wc.NextWriter(BinaryMessage)
  368. io.WriteString(w, "hello")
  369. if err := w.Close(); err != nil {
  370. t.Fatalf("unxpected error closing message writer, %v", err)
  371. }
  372. if _, err := io.WriteString(w, "world"); err == nil {
  373. t.Fatalf("no error writing after close")
  374. }
  375. w, _ = wc.NextWriter(BinaryMessage)
  376. io.WriteString(w, "hello")
  377. // close w by getting next writer
  378. _, err := wc.NextWriter(BinaryMessage)
  379. if err != nil {
  380. t.Fatalf("unexpected error getting next writer, %v", err)
  381. }
  382. if _, err := io.WriteString(w, "world"); err == nil {
  383. t.Fatalf("no error writing after close")
  384. }
  385. }
  386. func TestReadLimit(t *testing.T) {
  387. t.Run("Test ReadLimit is enforced", func(t *testing.T) {
  388. const readLimit = 512
  389. message := make([]byte, readLimit+1)
  390. var b1, b2 bytes.Buffer
  391. wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil)
  392. rc := newTestConn(&b1, &b2, true)
  393. rc.SetReadLimit(readLimit)
  394. // Send message at the limit with interleaved pong.
  395. w, _ := wc.NextWriter(BinaryMessage)
  396. w.Write(message[:readLimit-1])
  397. wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
  398. w.Write(message[:1])
  399. w.Close()
  400. // Send message larger than the limit.
  401. wc.WriteMessage(BinaryMessage, message[:readLimit+1])
  402. op, _, err := rc.NextReader()
  403. if op != BinaryMessage || err != nil {
  404. t.Fatalf("1: NextReader() returned %d, %v", op, err)
  405. }
  406. op, r, err := rc.NextReader()
  407. if op != BinaryMessage || err != nil {
  408. t.Fatalf("2: NextReader() returned %d, %v", op, err)
  409. }
  410. _, err = io.Copy(ioutil.Discard, r)
  411. if err != ErrReadLimit {
  412. t.Fatalf("io.Copy() returned %v", err)
  413. }
  414. })
  415. t.Run("Test that ReadLimit cannot be overflowed", func(t *testing.T) {
  416. const readLimit = 1
  417. var b1, b2 bytes.Buffer
  418. rc := newTestConn(&b1, &b2, true)
  419. rc.SetReadLimit(readLimit)
  420. // First, send a non-final binary message
  421. b1.Write([]byte("\x02\x81"))
  422. // Mask key
  423. b1.Write([]byte("\x00\x00\x00\x00"))
  424. // First payload
  425. b1.Write([]byte("A"))
  426. // Next, send a negative-length, non-final continuation frame
  427. b1.Write([]byte("\x00\xFF\x80\x00\x00\x00\x00\x00\x00\x00"))
  428. // Mask key
  429. b1.Write([]byte("\x00\x00\x00\x00"))
  430. // Next, send a too long, final continuation frame
  431. b1.Write([]byte("\x80\xFF\x00\x00\x00\x00\x00\x00\x00\x05"))
  432. // Mask key
  433. b1.Write([]byte("\x00\x00\x00\x00"))
  434. // Too-long payload
  435. b1.Write([]byte("BCDEF"))
  436. op, r, err := rc.NextReader()
  437. if op != BinaryMessage || err != nil {
  438. t.Fatalf("1: NextReader() returned %d, %v", op, err)
  439. }
  440. var buf [10]byte
  441. var read int
  442. n, err := r.Read(buf[:])
  443. if err != nil && err != ErrReadLimit {
  444. t.Fatalf("unexpected error testing read limit: %v", err)
  445. }
  446. read += n
  447. n, err = r.Read(buf[:])
  448. if err != nil && err != ErrReadLimit {
  449. t.Fatalf("unexpected error testing read limit: %v", err)
  450. }
  451. read += n
  452. if err == nil && read > readLimit {
  453. t.Fatalf("read limit exceeded: limit %d, read %d", readLimit, read)
  454. }
  455. })
  456. }
  457. func TestAddrs(t *testing.T) {
  458. c := newTestConn(nil, nil, true)
  459. if c.LocalAddr() != localAddr {
  460. t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr)
  461. }
  462. if c.RemoteAddr() != remoteAddr {
  463. t.Errorf("RemoteAddr = %v, want %v", c.RemoteAddr(), remoteAddr)
  464. }
  465. }
  466. func TestUnderlyingConn(t *testing.T) {
  467. var b1, b2 bytes.Buffer
  468. fc := fakeNetConn{Reader: &b1, Writer: &b2}
  469. c := newConn(fc, true, 1024, 1024, nil, nil, nil)
  470. ul := c.UnderlyingConn()
  471. if ul != fc {
  472. t.Fatalf("Underlying conn is not what it should be.")
  473. }
  474. }
  475. func TestBufioReadBytes(t *testing.T) {
  476. // Test calling bufio.ReadBytes for value longer than read buffer size.
  477. m := make([]byte, 512)
  478. m[len(m)-1] = '\n'
  479. var b1, b2 bytes.Buffer
  480. wc := newConn(fakeNetConn{Writer: &b1}, false, len(m)+64, len(m)+64, nil, nil, nil)
  481. rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil)
  482. w, _ := wc.NextWriter(BinaryMessage)
  483. w.Write(m)
  484. w.Close()
  485. op, r, err := rc.NextReader()
  486. if op != BinaryMessage || err != nil {
  487. t.Fatalf("NextReader() returned %d, %v", op, err)
  488. }
  489. br := bufio.NewReader(r)
  490. p, err := br.ReadBytes('\n')
  491. if err != nil {
  492. t.Fatalf("ReadBytes() returned %v", err)
  493. }
  494. if len(p) != len(m) {
  495. t.Fatalf("read returned %d bytes, want %d bytes", len(p), len(m))
  496. }
  497. }
  498. var closeErrorTests = []struct {
  499. err error
  500. codes []int
  501. ok bool
  502. }{
  503. {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true},
  504. {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false},
  505. {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true},
  506. {errors.New("hello"), []int{CloseNormalClosure}, false},
  507. }
  508. func TestCloseError(t *testing.T) {
  509. for _, tt := range closeErrorTests {
  510. ok := IsCloseError(tt.err, tt.codes...)
  511. if ok != tt.ok {
  512. t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
  513. }
  514. }
  515. }
  516. var unexpectedCloseErrorTests = []struct {
  517. err error
  518. codes []int
  519. ok bool
  520. }{
  521. {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false},
  522. {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true},
  523. {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false},
  524. {errors.New("hello"), []int{CloseNormalClosure}, false},
  525. }
  526. func TestUnexpectedCloseErrors(t *testing.T) {
  527. for _, tt := range unexpectedCloseErrorTests {
  528. ok := IsUnexpectedCloseError(tt.err, tt.codes...)
  529. if ok != tt.ok {
  530. t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
  531. }
  532. }
  533. }
  534. type blockingWriter struct {
  535. c1, c2 chan struct{}
  536. }
  537. func (w blockingWriter) Write(p []byte) (int, error) {
  538. // Allow main to continue
  539. close(w.c1)
  540. // Wait for panic in main
  541. <-w.c2
  542. return len(p), nil
  543. }
  544. func TestConcurrentWritePanic(t *testing.T) {
  545. w := blockingWriter{make(chan struct{}), make(chan struct{})}
  546. c := newTestConn(nil, w, false)
  547. go func() {
  548. c.WriteMessage(TextMessage, []byte{})
  549. }()
  550. // wait for goroutine to block in write.
  551. <-w.c1
  552. defer func() {
  553. close(w.c2)
  554. if v := recover(); v != nil {
  555. return
  556. }
  557. }()
  558. c.WriteMessage(TextMessage, []byte{})
  559. t.Fatal("should not get here")
  560. }
  561. type failingReader struct{}
  562. func (r failingReader) Read(p []byte) (int, error) {
  563. return 0, io.EOF
  564. }
  565. func TestFailedConnectionReadPanic(t *testing.T) {
  566. c := newTestConn(failingReader{}, nil, false)
  567. defer func() {
  568. if v := recover(); v != nil {
  569. return
  570. }
  571. }()
  572. for i := 0; i < 20000; i++ {
  573. c.ReadMessage()
  574. }
  575. t.Fatal("should not get here")
  576. }