Ver Fonte

aggressively close connection when response is done

Darien Raymond há 8 anos atrás
pai
commit
109a37fe7e

+ 1 - 1
common/buf/copy.go

@@ -54,7 +54,7 @@ func IgnoreWriterError() CopyOption {
 	}
 }
 
-func UpdateActivity(timer signal.ActivityTimer) CopyOption {
+func UpdateActivity(timer signal.ActivityUpdater) CopyOption {
 	return func(handler *copyHandler) {
 		handler.onData = append(handler.onData, func(MultiBuffer) {
 			timer.Update()

+ 17 - 9
common/signal/timer.go

@@ -5,26 +5,30 @@ import (
 	"time"
 )
 
-type ActivityTimer interface {
+type ActivityUpdater interface {
 	Update()
 }
 
-type realActivityTimer struct {
+type ActivityTimer struct {
 	updated chan bool
-	timeout time.Duration
+	timeout chan time.Duration
 	ctx     context.Context
 	cancel  context.CancelFunc
 }
 
-func (t *realActivityTimer) Update() {
+func (t *ActivityTimer) Update() {
 	select {
 	case t.updated <- true:
 	default:
 	}
 }
 
-func (t *realActivityTimer) run() {
-	ticker := time.NewTicker(t.timeout)
+func (t *ActivityTimer) SetTimeout(timeout time.Duration) {
+	t.timeout <- timeout
+}
+
+func (t *ActivityTimer) run() {
+	ticker := time.NewTicker(<-t.timeout)
 	defer ticker.Stop()
 
 	for {
@@ -32,6 +36,9 @@ func (t *realActivityTimer) run() {
 		case <-ticker.C:
 		case <-t.ctx.Done():
 			return
+		case timeout := <-t.timeout:
+			ticker.Stop()
+			ticker = time.NewTicker(timeout)
 		}
 
 		select {
@@ -44,14 +51,15 @@ func (t *realActivityTimer) run() {
 	}
 }
 
-func CancelAfterInactivity(ctx context.Context, timeout time.Duration) (context.Context, ActivityTimer) {
+func CancelAfterInactivity(ctx context.Context, timeout time.Duration) (context.Context, *ActivityTimer) {
 	ctx, cancel := context.WithCancel(ctx)
-	timer := &realActivityTimer{
+	timer := &ActivityTimer{
 		ctx:     ctx,
 		cancel:  cancel,
-		timeout: timeout,
+		timeout: make(chan time.Duration, 1),
 		updated: make(chan bool, 1),
 	}
+	timer.timeout <- timeout
 	go timer.run()
 	return ctx, timer
 }

+ 32 - 0
common/signal/timer_test.go

@@ -0,0 +1,32 @@
+package signal_test
+
+import (
+	"context"
+	"runtime"
+	"testing"
+	"time"
+
+	. "v2ray.com/core/common/signal"
+	"v2ray.com/core/testing/assert"
+)
+
+func TestActivityTimer(t *testing.T) {
+	assert := assert.On(t)
+
+	ctx, timer := CancelAfterInactivity(context.Background(), time.Second*5)
+	time.Sleep(time.Second * 6)
+	assert.Error(ctx.Err()).IsNotNil()
+	runtime.KeepAlive(timer)
+}
+
+func TestActivityTimerUpdate(t *testing.T) {
+	assert := assert.On(t)
+
+	ctx, timer := CancelAfterInactivity(context.Background(), time.Second*10)
+	time.Sleep(time.Second * 3)
+	assert.Error(ctx.Err()).IsNil()
+	timer.SetTimeout(time.Second * 1)
+	time.Sleep(time.Second * 2)
+	assert.Error(ctx.Err()).IsNotNil()
+	runtime.KeepAlive(timer)
+}

+ 3 - 0
proxy/dokodemo/dokodemo.go

@@ -94,6 +94,9 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
 		if err := buf.Copy(inboundRay.InboundOutput(), writer, buf.UpdateActivity(timer)); err != nil {
 			return newError("failed to transport response").Base(err)
 		}
+
+		timer.SetTimeout(time.Second * 2)
+
 		return nil
 	})
 

+ 1 - 0
proxy/http/server.go

@@ -148,6 +148,7 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade
 		if err := buf.Copy(ray.InboundOutput(), v2writer, buf.UpdateActivity(timer)); err != nil {
 			return err
 		}
+		timer.SetTimeout(time.Second * 2)
 		return nil
 	})
 

+ 2 - 0
proxy/shadowsocks/server.go

@@ -177,6 +177,8 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
 			return newError("failed to transport all TCP response").Base(err)
 		}
 
+		timer.SetTimeout(time.Second * 2)
+
 		return nil
 	})
 

+ 1 - 0
proxy/socks/server.go

@@ -135,6 +135,7 @@ func (v *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
 		if err := buf.Copy(output, v2writer, buf.UpdateActivity(timer)); err != nil {
 			return newError("failed to transport all TCP response").Base(err)
 		}
+		timer.SetTimeout(time.Second * 2)
 		return nil
 	})
 

+ 2 - 2
proxy/vmess/inbound/inbound.go

@@ -127,7 +127,7 @@ func (v *Handler) GetUser(email string) *protocol.User {
 	return user
 }
 
-func transferRequest(timer signal.ActivityTimer, session *encoding.ServerSession, request *protocol.RequestHeader, input io.Reader, output ray.OutputStream) error {
+func transferRequest(timer signal.ActivityUpdater, session *encoding.ServerSession, request *protocol.RequestHeader, input io.Reader, output ray.OutputStream) error {
 	defer output.Close()
 
 	bodyReader := session.DecodeRequestBody(request, input)
@@ -137,7 +137,7 @@ func transferRequest(timer signal.ActivityTimer, session *encoding.ServerSession
 	return nil
 }
 
-func transferResponse(timer signal.ActivityTimer, session *encoding.ServerSession, request *protocol.RequestHeader, response *protocol.ResponseHeader, input buf.Reader, output io.Writer) error {
+func transferResponse(timer signal.ActivityUpdater, session *encoding.ServerSession, request *protocol.RequestHeader, response *protocol.ResponseHeader, input buf.Reader, output io.Writer) error {
 	session.EncodeResponseHeader(response, output)
 
 	bodyWriter := session.EncodeResponseBody(request, output)