Explorar o código

cancel sessions after inactivity

Darien Raymond %!s(int64=8) %!d(string=hai) anos
pai
achega
c462e35aad

+ 6 - 3
common/buf/io.go

@@ -4,6 +4,7 @@ import (
 	"io"
 
 	"v2ray.com/core/common/errors"
+	"v2ray.com/core/common/signal"
 )
 
 // Reader extends io.Reader with alloc.Buffer.
@@ -33,13 +34,15 @@ func ReadFullFrom(reader io.Reader, size int) Supplier {
 }
 
 // Pipe dumps all content from reader to writer, until an error happens.
-func Pipe(reader Reader, writer Writer) error {
+func Pipe(timer *signal.ActivityTimer, reader Reader, writer Writer) error {
 	for {
 		buffer, err := reader.Read()
 		if err != nil {
 			return err
 		}
 
+		timer.UpdateActivity()
+
 		if buffer.IsEmpty() {
 			buffer.Release()
 			continue
@@ -54,8 +57,8 @@ func Pipe(reader Reader, writer Writer) error {
 }
 
 // PipeUntilEOF behaves the same as Pipe(). The only difference is PipeUntilEOF returns nil on EOF.
-func PipeUntilEOF(reader Reader, writer Writer) error {
-	err := Pipe(reader, writer)
+func PipeUntilEOF(timer *signal.ActivityTimer, reader Reader, writer Writer) error {
+	err := Pipe(timer, reader, writer)
 	if err != nil && errors.Cause(err) != io.EOF {
 		return err
 	}

+ 45 - 0
common/signal/timer.go

@@ -0,0 +1,45 @@
+package signal
+
+import (
+	"context"
+	"time"
+)
+
+type ActivityTimer struct {
+	updated chan bool
+	timeout time.Duration
+	ctx     context.Context
+	cancel  context.CancelFunc
+}
+
+func (t *ActivityTimer) UpdateActivity() {
+	select {
+	case t.updated <- true:
+	default:
+	}
+}
+
+func (t *ActivityTimer) run() {
+	for {
+		time.Sleep(t.timeout)
+		select {
+		case <-t.ctx.Done():
+			return
+		case <-t.updated:
+		default:
+			t.cancel()
+			return
+		}
+	}
+}
+
+func CancelAfterInactivity(ctx context.Context, cancel context.CancelFunc, timeout time.Duration) *ActivityTimer {
+	timer := &ActivityTimer{
+		ctx:     ctx,
+		cancel:  cancel,
+		timeout: timeout,
+		updated: make(chan bool, 1),
+	}
+	go timer.run()
+	return timer
+}

+ 21 - 4
proxy/dokodemo/dokodemo.go

@@ -2,6 +2,8 @@ package dokodemo
 
 import (
 	"context"
+	"runtime"
+	"time"
 
 	"v2ray.com/core/app"
 	"v2ray.com/core/app/dispatcher"
@@ -52,11 +54,24 @@ func (d *DokodemoDoor) Network() net.NetworkList {
 func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn internet.Connection) error {
 	log.Debug("Dokodemo: processing connection from: ", conn.RemoteAddr())
 	conn.SetReusable(false)
-	ctx = proxy.ContextWithDestination(ctx, net.Destination{
+	dest := net.Destination{
 		Network: network,
 		Address: d.address,
 		Port:    d.port,
-	})
+	}
+	if d.config.FollowRedirect {
+		if origDest := proxy.OriginalDestinationFromContext(ctx); origDest.IsValid() {
+			dest = origDest
+		}
+	}
+	if !dest.IsValid() || dest.Address == nil {
+		log.Info("Dokodemo: Invalid destination. Discarding...")
+		return errors.New("Dokodemo: Unable to get destination.")
+	}
+	ctx = proxy.ContextWithDestination(ctx, dest)
+	ctx, cancel := context.WithCancel(ctx)
+	timer := signal.CancelAfterInactivity(ctx, cancel, time.Minute*2)
+
 	inboundRay := d.packetDispatcher.DispatchToOutbound(ctx)
 
 	requestDone := signal.ExecuteAsync(func() error {
@@ -65,7 +80,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
 		timedReader := net.NewTimeOutReader(d.config.Timeout, conn)
 		chunkReader := buf.NewReader(timedReader)
 
-		if err := buf.Pipe(chunkReader, inboundRay.InboundInput()); err != nil {
+		if err := buf.PipeUntilEOF(timer, chunkReader, inboundRay.InboundInput()); err != nil {
 			log.Info("Dokodemo: Failed to transport request: ", err)
 			return err
 		}
@@ -76,7 +91,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
 	responseDone := signal.ExecuteAsync(func() error {
 		v2writer := buf.NewWriter(conn)
 
-		if err := buf.PipeUntilEOF(inboundRay.InboundOutput(), v2writer); err != nil {
+		if err := buf.PipeUntilEOF(timer, inboundRay.InboundOutput(), v2writer); err != nil {
 			log.Info("Dokodemo: Failed to transport response: ", err)
 			return err
 		}
@@ -90,6 +105,8 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
 		return err
 	}
 
+	runtime.KeepAlive(timer)
+
 	return nil
 }
 

+ 15 - 14
proxy/freedom/freedom.go

@@ -2,7 +2,9 @@ package freedom
 
 import (
 	"context"
-	"io"
+	"time"
+
+	"runtime"
 
 	"v2ray.com/core/app"
 	"v2ray.com/core/app/dns"
@@ -108,29 +110,26 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay) erro
 
 	conn.SetReusable(false)
 
+	ctx, cancel := context.WithCancel(ctx)
+	timeout := time.Second * time.Duration(v.timeout)
+	if timeout == 0 {
+		timeout = time.Minute * 10
+	}
+	timer := signal.CancelAfterInactivity(ctx, cancel, timeout)
+
 	requestDone := signal.ExecuteAsync(func() error {
 		v2writer := buf.NewWriter(conn)
-		if err := buf.PipeUntilEOF(input, v2writer); err != nil {
+		if err := buf.PipeUntilEOF(timer, input, v2writer); err != nil {
 			return err
 		}
 		return nil
 	})
 
-	var reader io.Reader = conn
-
-	timeout := v.timeout
-	if destination.Network == net.Network_UDP {
-		timeout = 16
-	}
-	if timeout > 0 {
-		reader = net.NewTimeOutReader(timeout /* seconds */, conn)
-	}
-
 	responseDone := signal.ExecuteAsync(func() error {
 		defer output.Close()
 
-		v2reader := buf.NewReader(reader)
-		if err := buf.Pipe(v2reader, output); err != nil {
+		v2reader := buf.NewReader(conn)
+		if err := buf.PipeUntilEOF(timer, v2reader, output); err != nil {
 			return err
 		}
 		return nil
@@ -143,6 +142,8 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay) erro
 		return err
 	}
 
+	runtime.KeepAlive(timer)
+
 	return nil
 }
 

+ 8 - 2
proxy/http/server.go

@@ -5,9 +5,11 @@ import (
 	"io"
 	"net"
 	"net/http"
+	"runtime"
 	"strconv"
 	"strings"
 	"sync"
+	"time"
 
 	"v2ray.com/core/app"
 	"v2ray.com/core/app/dispatcher"
@@ -130,13 +132,15 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade
 		return err
 	}
 
+	ctx, cancel := context.WithCancel(ctx)
+	timer := signal.CancelAfterInactivity(ctx, cancel, time.Minute*2)
 	ray := s.packetDispatcher.DispatchToOutbound(ctx)
 
 	requestDone := signal.ExecuteAsync(func() error {
 		defer ray.InboundInput().Close()
 
 		v2reader := buf.NewReader(reader)
-		if err := buf.Pipe(v2reader, ray.InboundInput()); err != nil {
+		if err := buf.PipeUntilEOF(timer, v2reader, ray.InboundInput()); err != nil {
 			return err
 		}
 		return nil
@@ -144,7 +148,7 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade
 
 	responseDone := signal.ExecuteAsync(func() error {
 		v2writer := buf.NewWriter(writer)
-		if err := buf.PipeUntilEOF(ray.InboundOutput(), v2writer); err != nil {
+		if err := buf.PipeUntilEOF(timer, ray.InboundOutput(), v2writer); err != nil {
 			return err
 		}
 		return nil
@@ -157,6 +161,8 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade
 		return err
 	}
 
+	runtime.KeepAlive(timer)
+
 	return nil
 }
 

+ 14 - 7
proxy/shadowsocks/client.go

@@ -3,6 +3,10 @@ package shadowsocks
 import (
 	"context"
 
+	"time"
+
+	"runtime"
+
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/bufio"
@@ -88,6 +92,9 @@ func (v *Client) Process(ctx context.Context, outboundRay ray.OutboundRay) error
 		request.Option |= RequestOptionOneTimeAuth
 	}
 
+	ctx, cancel := context.WithCancel(ctx)
+	timer := signal.CancelAfterInactivity(ctx, cancel, time.Minute*2)
+
 	if request.Command == protocol.RequestCommandTCP {
 		bufferedWriter := bufio.NewWriter(conn)
 		bodyWriter, err := WriteTCPRequest(request, bufferedWriter)
@@ -99,7 +106,7 @@ func (v *Client) Process(ctx context.Context, outboundRay ray.OutboundRay) error
 		bufferedWriter.SetBuffered(false)
 
 		requestDone := signal.ExecuteAsync(func() error {
-			if err := buf.PipeUntilEOF(outboundRay.OutboundInput(), bodyWriter); err != nil {
+			if err := buf.PipeUntilEOF(timer, outboundRay.OutboundInput(), bodyWriter); err != nil {
 				return err
 			}
 			return nil
@@ -113,7 +120,7 @@ func (v *Client) Process(ctx context.Context, outboundRay ray.OutboundRay) error
 				return err
 			}
 
-			if err := buf.Pipe(responseReader, outboundRay.OutboundOutput()); err != nil {
+			if err := buf.PipeUntilEOF(timer, responseReader, outboundRay.OutboundOutput()); err != nil {
 				return err
 			}
 
@@ -136,24 +143,22 @@ func (v *Client) Process(ctx context.Context, outboundRay ray.OutboundRay) error
 		}
 
 		requestDone := signal.ExecuteAsync(func() error {
-			if err := buf.PipeUntilEOF(outboundRay.OutboundInput(), writer); err != nil {
+			if err := buf.PipeUntilEOF(timer, outboundRay.OutboundInput(), writer); err != nil {
 				log.Info("Shadowsocks|Client: Failed to transport all UDP request: ", err)
 				return err
 			}
 			return nil
 		})
 
-		timedReader := net.NewTimeOutReader(16, conn)
-
 		responseDone := signal.ExecuteAsync(func() error {
 			defer outboundRay.OutboundOutput().Close()
 
 			reader := &UDPReader{
-				Reader: timedReader,
+				Reader: conn,
 				User:   user,
 			}
 
-			if err := buf.Pipe(reader, outboundRay.OutboundOutput()); err != nil {
+			if err := buf.PipeUntilEOF(timer, reader, outboundRay.OutboundOutput()); err != nil {
 				log.Info("Shadowsocks|Client: Failed to transport all UDP response: ", err)
 				return err
 			}
@@ -168,6 +173,8 @@ func (v *Client) Process(ctx context.Context, outboundRay ray.OutboundRay) error
 		return nil
 	}
 
+	runtime.KeepAlive(timer)
+
 	return nil
 }
 

+ 10 - 2
proxy/shadowsocks/server.go

@@ -2,6 +2,9 @@ package shadowsocks
 
 import (
 	"context"
+	"time"
+
+	"runtime"
 
 	"v2ray.com/core/app"
 	"v2ray.com/core/app/dispatcher"
@@ -154,6 +157,9 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection)
 
 	ctx = proxy.ContextWithDestination(ctx, dest)
 	ctx = protocol.ContextWithUser(ctx, request.User)
+
+	ctx, cancel := context.WithCancel(ctx)
+	timer := signal.CancelAfterInactivity(ctx, cancel, time.Minute*2)
 	ray := s.packetDispatcher.DispatchToOutbound(ctx)
 
 	requestDone := signal.ExecuteAsync(func() error {
@@ -177,7 +183,7 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection)
 			return err
 		}
 
-		if err := buf.Pipe(ray.InboundOutput(), responseWriter); err != nil {
+		if err := buf.PipeUntilEOF(timer, ray.InboundOutput(), responseWriter); err != nil {
 			log.Info("Shadowsocks|Server: Failed to transport all TCP response: ", err)
 			return err
 		}
@@ -188,7 +194,7 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection)
 	responseDone := signal.ExecuteAsync(func() error {
 		defer ray.InboundInput().Close()
 
-		if err := buf.PipeUntilEOF(bodyReader, ray.InboundInput()); err != nil {
+		if err := buf.PipeUntilEOF(timer, bodyReader, ray.InboundInput()); err != nil {
 			log.Info("Shadowsocks|Server: Failed to transport all TCP request: ", err)
 			return err
 		}
@@ -202,6 +208,8 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection)
 		return err
 	}
 
+	runtime.KeepAlive(timer)
+
 	return nil
 }
 

+ 12 - 4
proxy/socks/client.go

@@ -3,6 +3,9 @@ package socks
 import (
 	"context"
 
+	"runtime"
+	"time"
+
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/log"
@@ -79,15 +82,18 @@ func (c *Client) Process(ctx context.Context, ray ray.OutboundRay) error {
 		return err
 	}
 
+	ctx, cancel := context.WithCancel(ctx)
+	timer := signal.CancelAfterInactivity(ctx, cancel, time.Minute*2)
+
 	var requestFunc func() error
 	var responseFunc func() error
 	if request.Command == protocol.RequestCommandTCP {
 		requestFunc = func() error {
-			return buf.PipeUntilEOF(ray.OutboundInput(), buf.NewWriter(conn))
+			return buf.PipeUntilEOF(timer, ray.OutboundInput(), buf.NewWriter(conn))
 		}
 		responseFunc = func() error {
 			defer ray.OutboundOutput().Close()
-			return buf.Pipe(buf.NewReader(conn), ray.OutboundOutput())
+			return buf.PipeUntilEOF(timer, buf.NewReader(conn), ray.OutboundOutput())
 		}
 	} else if request.Command == protocol.RequestCommandUDP {
 		udpConn, err := dialer.Dial(ctx, udpRequest.Destination())
@@ -97,12 +103,12 @@ func (c *Client) Process(ctx context.Context, ray ray.OutboundRay) error {
 		}
 		defer udpConn.Close()
 		requestFunc = func() error {
-			return buf.PipeUntilEOF(ray.OutboundInput(), &UDPWriter{request: request, writer: udpConn})
+			return buf.PipeUntilEOF(timer, ray.OutboundInput(), &UDPWriter{request: request, writer: udpConn})
 		}
 		responseFunc = func() error {
 			defer ray.OutboundOutput().Close()
 			reader := &UDPReader{reader: net.NewTimeOutReader(16, udpConn)}
-			return buf.Pipe(reader, ray.OutboundOutput())
+			return buf.PipeUntilEOF(timer, reader, ray.OutboundOutput())
 		}
 	}
 
@@ -113,6 +119,8 @@ func (c *Client) Process(ctx context.Context, ray ray.OutboundRay) error {
 		return err
 	}
 
+	runtime.KeepAlive(timer)
+
 	return nil
 }
 

+ 8 - 2
proxy/socks/server.go

@@ -3,6 +3,7 @@ package socks
 import (
 	"context"
 	"io"
+	"runtime"
 	"time"
 
 	"v2ray.com/core/app"
@@ -115,6 +116,9 @@ func (*Server) handleUDP() error {
 }
 
 func (v *Server) transport(ctx context.Context, reader io.Reader, writer io.Writer) error {
+	ctx, cancel := context.WithCancel(ctx)
+	timer := signal.CancelAfterInactivity(ctx, cancel, time.Minute*2)
+
 	ray := v.packetDispatcher.DispatchToOutbound(ctx)
 	input := ray.InboundInput()
 	output := ray.InboundOutput()
@@ -123,7 +127,7 @@ func (v *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
 		defer input.Close()
 
 		v2reader := buf.NewReader(reader)
-		if err := buf.Pipe(v2reader, input); err != nil {
+		if err := buf.PipeUntilEOF(timer, v2reader, input); err != nil {
 			log.Info("Socks|Server: Failed to transport all TCP request: ", err)
 			return err
 		}
@@ -132,7 +136,7 @@ func (v *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
 
 	responseDone := signal.ExecuteAsync(func() error {
 		v2writer := buf.NewWriter(writer)
-		if err := buf.PipeUntilEOF(output, v2writer); err != nil {
+		if err := buf.PipeUntilEOF(timer, output, v2writer); err != nil {
 			log.Info("Socks|Server: Failed to transport all TCP response: ", err)
 			return err
 		}
@@ -146,6 +150,8 @@ func (v *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
 		return err
 	}
 
+	runtime.KeepAlive(timer)
+
 	return nil
 }
 

+ 13 - 6
proxy/vmess/inbound/inbound.go

@@ -5,6 +5,9 @@ import (
 	"io"
 	"sync"
 
+	"runtime"
+	"time"
+
 	"v2ray.com/core/app"
 	"v2ray.com/core/app/dispatcher"
 	"v2ray.com/core/app/proxyman"
@@ -129,17 +132,17 @@ func (v *VMessInboundHandler) GetUser(email string) *protocol.User {
 	return user
 }
 
-func transferRequest(session *encoding.ServerSession, request *protocol.RequestHeader, input io.Reader, output ray.OutputStream) error {
+func transferRequest(timer *signal.ActivityTimer, session *encoding.ServerSession, request *protocol.RequestHeader, input io.Reader, output ray.OutputStream) error {
 	defer output.Close()
 
 	bodyReader := session.DecodeRequestBody(request, input)
-	if err := buf.Pipe(bodyReader, output); err != nil {
+	if err := buf.PipeUntilEOF(timer, bodyReader, output); err != nil {
 		return err
 	}
 	return nil
 }
 
-func transferResponse(session *encoding.ServerSession, request *protocol.RequestHeader, response *protocol.ResponseHeader, input ray.InputStream, output io.Writer) error {
+func transferResponse(timer *signal.ActivityTimer, session *encoding.ServerSession, request *protocol.RequestHeader, response *protocol.ResponseHeader, input ray.InputStream, output io.Writer) error {
 	session.EncodeResponseHeader(response, output)
 
 	bodyWriter := session.EncodeResponseBody(request, output)
@@ -161,7 +164,7 @@ func transferResponse(session *encoding.ServerSession, request *protocol.Request
 		}
 	}
 
-	if err := buf.PipeUntilEOF(input, bodyWriter); err != nil {
+	if err := buf.PipeUntilEOF(timer, input, bodyWriter); err != nil {
 		return err
 	}
 
@@ -196,6 +199,8 @@ func (v *VMessInboundHandler) Process(ctx context.Context, network net.Network,
 
 	ctx = proxy.ContextWithDestination(ctx, request.Destination())
 	ctx = protocol.ContextWithUser(ctx, request.User)
+	ctx, cancel := context.WithCancel(ctx)
+	timer := signal.CancelAfterInactivity(ctx, cancel, time.Minute*2)
 	ray := v.packetDispatcher.DispatchToOutbound(ctx)
 
 	input := ray.InboundInput()
@@ -206,7 +211,7 @@ func (v *VMessInboundHandler) Process(ctx context.Context, network net.Network,
 	reader.SetBuffered(false)
 
 	requestDone := signal.ExecuteAsync(func() error {
-		return transferRequest(session, request, reader, input)
+		return transferRequest(timer, session, request, reader, input)
 	})
 
 	writer := bufio.NewWriter(connection)
@@ -219,7 +224,7 @@ func (v *VMessInboundHandler) Process(ctx context.Context, network net.Network,
 	}
 
 	responseDone := signal.ExecuteAsync(func() error {
-		return transferResponse(session, request, response, output, writer)
+		return transferResponse(timer, session, request, response, output, writer)
 	})
 
 	if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
@@ -236,6 +241,8 @@ func (v *VMessInboundHandler) Process(ctx context.Context, network net.Network,
 		return err
 	}
 
+	runtime.KeepAlive(timer)
+
 	return nil
 }
 

+ 7 - 2
proxy/vmess/outbound/outbound.go

@@ -2,6 +2,7 @@ package outbound
 
 import (
 	"context"
+	"runtime"
 	"time"
 
 	"v2ray.com/core/app"
@@ -101,6 +102,9 @@ func (v *VMessOutboundHandler) Process(ctx context.Context, outboundRay ray.Outb
 
 	session := encoding.NewClientSession(protocol.DefaultIDHash)
 
+	ctx, cancel := context.WithCancel(ctx)
+	timer := signal.CancelAfterInactivity(ctx, cancel, time.Minute*2)
+
 	requestDone := signal.ExecuteAsync(func() error {
 		writer := bufio.NewWriter(conn)
 		session.EncodeRequestHeader(request, writer)
@@ -119,7 +123,7 @@ func (v *VMessOutboundHandler) Process(ctx context.Context, outboundRay ray.Outb
 
 		writer.SetBuffered(false)
 
-		if err := buf.PipeUntilEOF(input, bodyWriter); err != nil {
+		if err := buf.PipeUntilEOF(timer, input, bodyWriter); err != nil {
 			return err
 		}
 
@@ -145,7 +149,7 @@ func (v *VMessOutboundHandler) Process(ctx context.Context, outboundRay ray.Outb
 
 		reader.SetBuffered(false)
 		bodyReader := session.DecodeResponseBody(request, reader)
-		if err := buf.Pipe(bodyReader, output); err != nil {
+		if err := buf.PipeUntilEOF(timer, bodyReader, output); err != nil {
 			return err
 		}
 
@@ -157,6 +161,7 @@ func (v *VMessOutboundHandler) Process(ctx context.Context, outboundRay ray.Outb
 		conn.SetReusable(false)
 		return err
 	}
+	runtime.KeepAlive(timer)
 
 	return nil
 }