Selaa lähdekoodia

cleanup websocket code

Darien Raymond 8 vuotta sitten
vanhempi
commit
cdb0b94046

+ 4 - 8
transport/internet/websocket/dialer.go

@@ -84,12 +84,8 @@ func wsDial(ctx context.Context, dest v2net.Destination) (net.Conn, error) {
 		}
 		return nil, err
 	}
-	return func() net.Conn {
-		connv2ray := &wsconn{
-			wsc:         conn,
-			connClosing: false,
-		}
-		connv2ray.setup()
-		return connv2ray
-	}(), nil
+
+	return &wsconn{
+		wsc: conn,
+	}, nil
 }

+ 1 - 3
transport/internet/websocket/hub.go

@@ -120,9 +120,7 @@ func (wsl *WSListener) converttovws(w http.ResponseWriter, r *http.Request) (*ws
 		return nil, err
 	}
 
-	wrapedConn := &wsconn{wsc: conn, connClosing: false}
-	wrapedConn.setup()
-	return wrapedConn, nil
+	return &wsconn{wsc: conn}, nil
 }
 
 func (v *WSListener) Accept() (internet.Connection, error) {

+ 45 - 137
transport/internet/websocket/wsconn.go

@@ -1,179 +1,87 @@
 package websocket
 
 import (
-	"bufio"
 	"io"
 	"net"
 	"sync"
 	"time"
 
 	"github.com/gorilla/websocket"
-	"v2ray.com/core/app/log"
 	"v2ray.com/core/common/errors"
 )
 
 type wsconn struct {
-	wsc         *websocket.Conn
-	readBuffer  *bufio.Reader
-	connClosing bool
-	rlock       *sync.Mutex
-	wlock       *sync.Mutex
+	wsc    *websocket.Conn
+	reader io.Reader
+	rlock  sync.Mutex
+	wlock  sync.Mutex
 }
 
-func (ws *wsconn) Read(b []byte) (n int, err error) {
-	ws.rlock.Lock()
-	n, err = ws.read(b)
-	ws.rlock.Unlock()
-	return n, err
+func (c *wsconn) Read(b []byte) (int, error) {
+	c.rlock.Lock()
+	defer c.rlock.Unlock()
 
-}
-
-func (ws *wsconn) read(b []byte) (n int, err error) {
-	if ws.connClosing {
-		return 0, io.EOF
-	}
-
-	n, err = ws.readNext(b)
-	return n, err
-}
-
-func (ws *wsconn) getNewReadBuffer() error {
-	_, r, err := ws.wsc.NextReader()
-	if err != nil {
-		log.Warning("WebSocket|Connection: Failed to get reader.", err)
-		ws.connClosing = true
-		ws.Close()
-		return err
-	}
-	ws.readBuffer = bufio.NewReader(r)
-	return nil
-}
-
-func (ws *wsconn) readNext(b []byte) (n int, err error) {
-	if ws.readBuffer == nil {
-		err = ws.getNewReadBuffer()
+	for {
+		reader, err := c.getReader()
 		if err != nil {
 			return 0, err
 		}
-	}
-
-	n, err = ws.readBuffer.Read(b)
-
-	if err == nil {
-		return n, err
-	}
 
-	if errors.Cause(err) == io.EOF {
-		ws.readBuffer = nil
-		if n == 0 {
-			return ws.readNext(b)
+		nBytes, err := reader.Read(b)
+		if errors.Cause(err) == io.EOF {
+			continue
 		}
-		return n, nil
+		return nBytes, err
 	}
-	return n, err
-
 }
 
-func (ws *wsconn) Write(b []byte) (n int, err error) {
-	ws.wlock.Lock()
-	if ws.connClosing {
-		return 0, io.ErrClosedPipe
+func (c *wsconn) getReader() (io.Reader, error) {
+	if c.reader != nil {
+		return c.reader, nil
 	}
 
-	n, err = ws.write(b)
-	ws.wlock.Unlock()
-	return n, err
-}
-
-func (ws *wsconn) write(b []byte) (n int, err error) {
-	wr, err := ws.wsc.NextWriter(websocket.BinaryMessage)
-	if err != nil {
-		log.Warning("WebSocket|Connection: Failed to get writer.", err)
-		ws.connClosing = true
-		ws.Close()
-		return 0, err
-	}
-	n, err = wr.Write(b)
+	_, reader, err := c.wsc.NextReader()
 	if err != nil {
-		return 0, err
+		return nil, err
 	}
-	err = wr.Close()
-	if err != nil {
+	c.reader = reader
+	return reader, nil
+}
+
+func (c *wsconn) Write(b []byte) (int, error) {
+	c.wlock.Lock()
+	defer c.wlock.Unlock()
+
+	if err := c.wsc.WriteMessage(websocket.BinaryMessage, b); err != nil {
 		return 0, err
 	}
-	return n, err
+	return len(b), nil
 }
 
-func (ws *wsconn) Close() error {
-	ws.connClosing = true
-	ws.wlock.Lock()
-	ws.wsc.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add((time.Second * 5)))
-	ws.wlock.Unlock()
-	err := ws.wsc.Close()
-	return err
+func (c *wsconn) Close() error {
+	c.wsc.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5))
+	return c.wsc.Close()
 }
-func (ws *wsconn) LocalAddr() net.Addr {
-	return ws.wsc.LocalAddr()
+
+func (c *wsconn) LocalAddr() net.Addr {
+	return c.wsc.LocalAddr()
 }
-func (ws *wsconn) RemoteAddr() net.Addr {
-	return ws.wsc.RemoteAddr()
+
+func (c *wsconn) RemoteAddr() net.Addr {
+	return c.wsc.RemoteAddr()
 }
-func (ws *wsconn) SetDeadline(t time.Time) error {
-	if err := ws.SetReadDeadline(t); err != nil {
+
+func (c *wsconn) SetDeadline(t time.Time) error {
+	if err := c.SetReadDeadline(t); err != nil {
 		return err
 	}
-	return ws.SetWriteDeadline(t)
+	return c.SetWriteDeadline(t)
 }
-func (ws *wsconn) SetReadDeadline(t time.Time) error {
-	return ws.wsc.SetReadDeadline(t)
-}
-func (ws *wsconn) SetWriteDeadline(t time.Time) error {
-	return ws.wsc.SetWriteDeadline(t)
-}
-
-func (ws *wsconn) setup() {
-	ws.connClosing = false
-
-	/*
-		https://godoc.org/github.com/gorilla/websocket#Conn.NextReader
-		https://godoc.org/github.com/gorilla/websocket#Conn.NextWriter
-
-		Both Read and write access are both exclusive.
-		And in both case it will need a lock.
 
-	*/
-	ws.rlock = &sync.Mutex{}
-	ws.wlock = &sync.Mutex{}
-
-	ws.pingPong()
+func (c *wsconn) SetReadDeadline(t time.Time) error {
+	return c.wsc.SetReadDeadline(t)
 }
 
-func (ws *wsconn) pingPong() {
-	pongRcv := make(chan int, 1)
-	ws.wsc.SetPongHandler(func(data string) error {
-		pongRcv <- 0
-		return nil
-	})
-
-	go func() {
-		for !ws.connClosing {
-			ws.wlock.Lock()
-			ws.wsc.WriteMessage(websocket.PingMessage, nil)
-			ws.wlock.Unlock()
-			tick := time.After(time.Second * 3)
-
-			select {
-			case <-pongRcv:
-			case <-tick:
-				if !ws.connClosing {
-					log.Debug("WS:Closing as ping is not responded~" + ws.wsc.UnderlyingConn().LocalAddr().String() + "-" + ws.wsc.UnderlyingConn().RemoteAddr().String())
-				}
-				ws.Close()
-			}
-			<-time.After(time.Second * 27)
-		}
-
-		return
-	}()
-
+func (c *wsconn) SetWriteDeadline(t time.Time) error {
+	return c.wsc.SetWriteDeadline(t)
 }