Parcourir la source

simplify websocket dialer and hub

Darien Raymond il y a 8 ans
Parent
commit
60f3562ac1

+ 6 - 0
transport/internet/websocket/connection.go

@@ -24,6 +24,12 @@ type connection struct {
 	mergingWriter buf.Writer
 }
 
+func newConnection(conn *websocket.Conn) *connection {
+	return &connection{
+		wsc: conn,
+	}
+}
+
 // Read implements net.Conn.Read()
 func (c *connection) Read(b []byte) (int, error) {
 	for {

+ 8 - 8
transport/internet/websocket/dialer.go

@@ -2,6 +2,7 @@ package websocket
 
 import (
 	"context"
+	"time"
 
 	"github.com/gorilla/websocket"
 	"v2ray.com/core/app/log"
@@ -30,14 +31,13 @@ func dialWebsocket(ctx context.Context, dest net.Destination) (net.Conn, error)
 	src := internet.DialerSourceFromContext(ctx)
 	wsSettings := internet.TransportSettingsFromContext(ctx).(*Config)
 
-	commonDial := func(network, addr string) (net.Conn, error) {
-		return internet.DialSystem(ctx, src, dest)
-	}
-
-	dialer := websocket.Dialer{
-		NetDial:         commonDial,
-		ReadBufferSize:  32 * 1024,
-		WriteBufferSize: 32 * 1024,
+	dialer := &websocket.Dialer{
+		NetDial: func(network, addr string) (net.Conn, error) {
+			return internet.DialSystem(ctx, src, dest)
+		},
+		ReadBufferSize:   32 * 1024,
+		WriteBufferSize:  32 * 1024,
+		HandshakeTimeout: time.Second * 8,
 	}
 
 	protocol := "ws"

+ 9 - 16
transport/internet/websocket/hub.go

@@ -6,6 +6,7 @@ import (
 	"net/http"
 	"strconv"
 	"sync"
+	"time"
 
 	"github.com/gorilla/websocket"
 	"v2ray.com/core/app/log"
@@ -20,18 +21,24 @@ type requestHandler struct {
 	ln   *Listener
 }
 
+var upgrader = &websocket.Upgrader{
+	ReadBufferSize:   32 * 1024,
+	WriteBufferSize:  32 * 1024,
+	HandshakeTimeout: time.Second * 8,
+}
+
 func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
 	if request.URL.Path != h.path {
 		writer.WriteHeader(http.StatusNotFound)
 		return
 	}
-	conn, err := converttovws(writer, request)
+	conn, err := upgrader.Upgrade(writer, request, nil)
 	if err != nil {
 		log.Trace(newError("failed to convert to WebSocket connection").Base(err))
 		return
 	}
 
-	h.ln.addConn(h.ln.ctx, internet.Connection(conn))
+	h.ln.addConn(h.ln.ctx, newConnection(conn))
 }
 
 type Listener struct {
@@ -92,20 +99,6 @@ func (ln *Listener) listenws(address net.Address, port net.Port) error {
 	return nil
 }
 
-func converttovws(w http.ResponseWriter, r *http.Request) (*connection, error) {
-	var upgrader = websocket.Upgrader{
-		ReadBufferSize:  32 * 1024,
-		WriteBufferSize: 32 * 1024,
-	}
-	conn, err := upgrader.Upgrade(w, r, nil)
-
-	if err != nil {
-		return nil, err
-	}
-
-	return &connection{wsc: conn}, nil
-}
-
 func (ln *Listener) Addr() net.Addr {
 	return ln.listener.Addr()
 }