Browse Source

enforce timeout for http header processing

Darien Raymond 6 years ago
parent
commit
6146366a4a
2 changed files with 17 additions and 12 deletions
  1. 5 3
      transport/internet/http/hub.go
  2. 12 9
      transport/internet/websocket/hub.go

+ 5 - 3
transport/internet/http/hub.go

@@ -5,6 +5,7 @@ import (
 	"io"
 	"net/http"
 	"strings"
+	"time"
 
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/net"
@@ -105,9 +106,10 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti
 	}
 
 	server := &http.Server{
-		Addr:      serial.Concat(address, ":", port),
-		TLSConfig: config.GetTLSConfig(tls.WithNextProto("h2")),
-		Handler:   listener,
+		Addr:              serial.Concat(address, ":", port),
+		TLSConfig:         config.GetTLSConfig(tls.WithNextProto("h2")),
+		Handler:           listener,
+		ReadHeaderTimeout: time.Second * 4,
 	}
 
 	listener.server = server

+ 12 - 9
transport/internet/websocket/hub.go

@@ -25,7 +25,7 @@ type requestHandler struct {
 var upgrader = &websocket.Upgrader{
 	ReadBufferSize:   4 * 1024,
 	WriteBufferSize:  4 * 1024,
-	HandshakeTimeout: time.Second * 8,
+	HandshakeTimeout: time.Second * 4,
 }
 
 func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
@@ -50,6 +50,7 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
 
 type Listener struct {
 	sync.Mutex
+	server   http.Server
 	listener net.Listener
 	config   *Config
 	addConn  internet.ConnHandler
@@ -74,8 +75,17 @@ func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSet
 		listener: listener,
 	}
 
+	l.server = http.Server{
+		Handler: &requestHandler{
+			path: wsSettings.GetNormalizedPath(),
+			ln:   l,
+		},
+		ReadHeaderTimeout: time.Second * 4,
+		MaxHeaderBytes:    2048,
+	}
+
 	go func() {
-		if err := l.serve(); err != nil {
+		if err := l.server.Serve(l.listener); err != nil {
 			newError("failed to serve http for WebSocket").Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx))
 		}
 	}()
@@ -99,13 +109,6 @@ func listenTCP(ctx context.Context, address net.Address, port net.Port, tlsConfi
 	return listener, nil
 }
 
-func (ln *Listener) serve() error {
-	return http.Serve(ln.listener, &requestHandler{
-		path: ln.config.GetNormalizedPath(),
-		ln:   ln,
-	})
-}
-
 // Addr implements net.Listener.Addr().
 func (ln *Listener) Addr() net.Addr {
 	return ln.listener.Addr()