Browse Source

Stats: Implements blocking/non-blocking messaging of Channel (#250)

Ye Zhihao 5 years ago
parent
commit
67f409de04

+ 9 - 4
app/router/command/command.go

@@ -6,6 +6,7 @@ package command
 
 import (
 	"context"
+	"time"
 
 	"google.golang.org/grpc"
 
@@ -38,7 +39,8 @@ func (s *routingServer) TestRoute(ctx context.Context, request *TestRouteRequest
 		return nil, err
 	}
 	if request.PublishResult && s.routingStats != nil {
-		s.routingStats.Publish(route)
+		ctx, _ := context.WithTimeout(context.Background(), 4*time.Second) // nolint: lostcancel
+		s.routingStats.Publish(ctx, route)
 	}
 	return AsProtobufMessage(request.FieldSelectors)(route), nil
 }
@@ -55,10 +57,13 @@ func (s *routingServer) SubscribeRoutingStats(request *SubscribeRoutingStatsRequ
 	defer stats.UnsubscribeClosableChannel(s.routingStats, subscriber) // nolint: errcheck
 	for {
 		select {
-		case value, received := <-subscriber:
+		case value, ok := <-subscriber:
+			if !ok {
+				return newError("Upstream closed the subscriber channel.")
+			}
 			route, ok := value.(routing.Route)
-			if !(received && ok) {
-				return newError("Receiving upstream statistics failed.")
+			if !ok {
+				return newError("Upstream sent malformed statistics.")
 			}
 			err := stream.Send(genMessage(route))
 			if err != nil {

+ 155 - 128
app/router/command/command_test.go

@@ -21,9 +21,9 @@ import (
 
 func TestServiceSubscribeRoutingStats(t *testing.T) {
 	c := stats.NewChannel(&stats.ChannelConfig{
-		SubscriberLimit:  1,
-		BufferSize:       16,
-		BroadcastTimeout: 100,
+		SubscriberLimit: 1,
+		BufferSize:      0,
+		Blocking:        true,
 	})
 	common.Must(c.Start())
 	defer c.Close()
@@ -55,122 +55,138 @@ func TestServiceSubscribeRoutingStats(t *testing.T) {
 
 	// Publisher goroutine
 	go func() {
-		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
-		defer cancel()
-		for { // Wait until there's one subscriber in routing stats channel
-			if len(c.Subscribers()) > 0 {
-				break
+		publishTestCases := func() error {
+			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+			defer cancel()
+			for { // Wait until there's one subscriber in routing stats channel
+				if len(c.Subscribers()) > 0 {
+					break
+				}
+				if ctx.Err() != nil {
+					return ctx.Err()
+				}
 			}
-			if ctx.Err() != nil {
-				errCh <- ctx.Err()
+			for _, tc := range testCases {
+				c.Publish(context.Background(), AsRoutingRoute(tc))
+				time.Sleep(time.Millisecond)
 			}
+			return nil
 		}
-		for _, tc := range testCases {
-			c.Publish(AsRoutingRoute(tc))
+
+		if err := publishTestCases(); err != nil {
+			errCh <- err
 		}
 
 		// Wait for next round of publishing
 		<-nextPub
 
-		ctx, cancel = context.WithTimeout(context.Background(), time.Second)
-		defer cancel()
-		for { // Wait until there's one subscriber in routing stats channel
-			if len(c.Subscribers()) > 0 {
-				break
-			}
-			if ctx.Err() != nil {
-				errCh <- ctx.Err()
-			}
-		}
-		for _, tc := range testCases {
-			c.Publish(AsRoutingRoute(tc))
+		if err := publishTestCases(); err != nil {
+			errCh <- err
 		}
 	}()
 
 	// Client goroutine
 	go func() {
+		defer lis.Close()
 		conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
 		if err != nil {
 			errCh <- err
+			return
 		}
-		defer lis.Close()
 		defer conn.Close()
 		client := NewRoutingServiceClient(conn)
 
 		// Test retrieving all fields
-		streamCtx, streamClose := context.WithCancel(context.Background())
-		stream, err := client.SubscribeRoutingStats(streamCtx, &SubscribeRoutingStatsRequest{})
-		if err != nil {
-			errCh <- err
-		}
+		testRetrievingAllFields := func() error {
+			streamCtx, streamClose := context.WithCancel(context.Background())
+
+			// Test the unsubscription of stream works well
+			defer func() {
+				streamClose()
+				timeOutCtx, timeout := context.WithTimeout(context.Background(), time.Second)
+				defer timeout()
+				for { // Wait until there's no subscriber in routing stats channel
+					if len(c.Subscribers()) == 0 {
+						break
+					}
+					if timeOutCtx.Err() != nil {
+						t.Error("unexpected subscribers not decreased in channel", timeOutCtx.Err())
+					}
+				}
+			}()
 
-		for _, tc := range testCases {
-			msg, err := stream.Recv()
+			stream, err := client.SubscribeRoutingStats(streamCtx, &SubscribeRoutingStatsRequest{})
 			if err != nil {
-				errCh <- err
+				return err
 			}
-			if r := cmp.Diff(msg, tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
-				t.Error(r)
-			}
-		}
 
-		// Test that double subscription will fail
-		errStream, err := client.SubscribeRoutingStats(context.Background(), &SubscribeRoutingStatsRequest{
-			FieldSelectors: []string{"ip", "port", "domain", "outbound"},
-		})
-		if err != nil {
-			errCh <- err
-		}
-		if _, err := errStream.Recv(); err == nil {
-			t.Error("unexpected successful subscription")
-		}
+			for _, tc := range testCases {
+				msg, err := stream.Recv()
+				if err != nil {
+					return err
+				}
+				if r := cmp.Diff(msg, tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
+					t.Error(r)
+				}
+			}
 
-		// Test the unsubscription of stream works well
-		streamClose()
-		timeOutCtx, timeout := context.WithTimeout(context.Background(), time.Second)
-		defer timeout()
-		for { // Wait until there's no subscriber in routing stats channel
-			if len(c.Subscribers()) == 0 {
-				break
+			// Test that double subscription will fail
+			errStream, err := client.SubscribeRoutingStats(context.Background(), &SubscribeRoutingStatsRequest{
+				FieldSelectors: []string{"ip", "port", "domain", "outbound"},
+			})
+			if err != nil {
+				return err
 			}
-			if timeOutCtx.Err() != nil {
-				t.Error("unexpected subscribers not decreased in channel")
-				errCh <- timeOutCtx.Err()
+			if _, err := errStream.Recv(); err == nil {
+				t.Error("unexpected successful subscription")
 			}
-		}
 
-		// Test retrieving only a subset of fields
-		streamCtx, streamClose = context.WithCancel(context.Background())
-		stream, err = client.SubscribeRoutingStats(streamCtx, &SubscribeRoutingStatsRequest{
-			FieldSelectors: []string{"ip", "port", "domain", "outbound"},
-		})
-		if err != nil {
-			errCh <- err
+			return nil
 		}
 
-		close(nextPub) // Send nextPub signal to start next round of publishing
-		for _, tc := range testCases {
-			msg, err := stream.Recv()
-			stat := &RoutingContext{ // Only a subset of stats is retrieved
-				SourceIPs:         tc.SourceIPs,
-				TargetIPs:         tc.TargetIPs,
-				SourcePort:        tc.SourcePort,
-				TargetPort:        tc.TargetPort,
-				TargetDomain:      tc.TargetDomain,
-				OutboundGroupTags: tc.OutboundGroupTags,
-				OutboundTag:       tc.OutboundTag,
-			}
+		// Test retrieving only a subset of fields
+		testRetrievingSubsetOfFields := func() error {
+			streamCtx, streamClose := context.WithCancel(context.Background())
+			defer streamClose()
+			stream, err := client.SubscribeRoutingStats(streamCtx, &SubscribeRoutingStatsRequest{
+				FieldSelectors: []string{"ip", "port", "domain", "outbound"},
+			})
 			if err != nil {
-				errCh <- err
+				return err
 			}
-			if r := cmp.Diff(msg, stat, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
-				t.Error(r)
+
+			// Send nextPub signal to start next round of publishing
+			close(nextPub)
+
+			for _, tc := range testCases {
+				msg, err := stream.Recv()
+				if err != nil {
+					return err
+				}
+				stat := &RoutingContext{ // Only a subset of stats is retrieved
+					SourceIPs:         tc.SourceIPs,
+					TargetIPs:         tc.TargetIPs,
+					SourcePort:        tc.SourcePort,
+					TargetPort:        tc.TargetPort,
+					TargetDomain:      tc.TargetDomain,
+					OutboundGroupTags: tc.OutboundGroupTags,
+					OutboundTag:       tc.OutboundTag,
+				}
+				if r := cmp.Diff(msg, stat, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
+					t.Error(r)
+				}
 			}
+
+			return nil
 		}
-		streamClose()
 
-		// Client passed all tests successfully
-		errCh <- nil
+		if err := testRetrievingAllFields(); err != nil {
+			errCh <- err
+		}
+		if err := testRetrievingSubsetOfFields(); err != nil {
+			errCh <- err
+		}
+		errCh <- nil // Client passed all tests successfully
 	}()
 
 	// Wait for goroutines to complete
@@ -186,9 +202,9 @@ func TestServiceSubscribeRoutingStats(t *testing.T) {
 
 func TestSerivceTestRoute(t *testing.T) {
 	c := stats.NewChannel(&stats.ChannelConfig{
-		SubscriberLimit:  1,
-		BufferSize:       16,
-		BroadcastTimeout: 100,
+		SubscriberLimit: 1,
+		BufferSize:      16,
+		Blocking:        true,
 	})
 	common.Must(c.Start())
 	defer c.Close()
@@ -249,11 +265,11 @@ func TestSerivceTestRoute(t *testing.T) {
 
 	// Client goroutine
 	go func() {
+		defer lis.Close()
 		conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
 		if err != nil {
 			errCh <- err
 		}
-		defer lis.Close()
 		defer conn.Close()
 		client := NewRoutingServiceClient(conn)
 
@@ -268,58 +284,69 @@ func TestSerivceTestRoute(t *testing.T) {
 		}
 
 		// Test simple TestRoute
-		for _, tc := range testCases {
-			route, err := client.TestRoute(context.Background(), &TestRouteRequest{RoutingContext: tc})
-			if err != nil {
-				errCh <- err
-			}
-			if r := cmp.Diff(route, tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
-				t.Error(r)
+		testSimple := func() error {
+			for _, tc := range testCases {
+				route, err := client.TestRoute(context.Background(), &TestRouteRequest{RoutingContext: tc})
+				if err != nil {
+					return err
+				}
+				if r := cmp.Diff(route, tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
+					t.Error(r)
+				}
 			}
+			return nil
 		}
 
 		// Test TestRoute with special options
-		sub, err := c.Subscribe()
-		if err != nil {
-			errCh <- err
-		}
-		for _, tc := range testCases {
-			route, err := client.TestRoute(context.Background(), &TestRouteRequest{
-				RoutingContext: tc,
-				FieldSelectors: []string{"ip", "port", "domain", "outbound"},
-				PublishResult:  true,
-			})
-			stat := &RoutingContext{ // Only a subset of stats is retrieved
-				SourceIPs:         tc.SourceIPs,
-				TargetIPs:         tc.TargetIPs,
-				SourcePort:        tc.SourcePort,
-				TargetPort:        tc.TargetPort,
-				TargetDomain:      tc.TargetDomain,
-				OutboundGroupTags: tc.OutboundGroupTags,
-				OutboundTag:       tc.OutboundTag,
-			}
+		testOptions := func() error {
+			sub, err := c.Subscribe()
 			if err != nil {
-				errCh <- err
+				return err
 			}
-			if r := cmp.Diff(route, stat, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
-				t.Error(r)
-			}
-			select { // Check that routing result has been published to statistics channel
-			case msg, received := <-sub:
-				if route, ok := msg.(routing.Route); received && ok {
-					if r := cmp.Diff(AsProtobufMessage(nil)(route), tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
-						t.Error(r)
+			for _, tc := range testCases {
+				route, err := client.TestRoute(context.Background(), &TestRouteRequest{
+					RoutingContext: tc,
+					FieldSelectors: []string{"ip", "port", "domain", "outbound"},
+					PublishResult:  true,
+				})
+				if err != nil {
+					return err
+				}
+				stat := &RoutingContext{ // Only a subset of stats is retrieved
+					SourceIPs:         tc.SourceIPs,
+					TargetIPs:         tc.TargetIPs,
+					SourcePort:        tc.SourcePort,
+					TargetPort:        tc.TargetPort,
+					TargetDomain:      tc.TargetDomain,
+					OutboundGroupTags: tc.OutboundGroupTags,
+					OutboundTag:       tc.OutboundTag,
+				}
+				if r := cmp.Diff(route, stat, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
+					t.Error(r)
+				}
+				select { // Check that routing result has been published to statistics channel
+				case msg, received := <-sub:
+					if route, ok := msg.(routing.Route); received && ok {
+						if r := cmp.Diff(AsProtobufMessage(nil)(route), tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
+							t.Error(r)
+						}
+					} else {
+						t.Error("unexpected failure in receiving published routing result for testcase", tc)
 					}
-				} else {
-					t.Error("unexpected failure in receiving published routing result")
+				case <-time.After(100 * time.Millisecond):
+					t.Error("unexpected failure in receiving published routing result", tc)
 				}
-			case <-time.After(100 * time.Millisecond):
-				t.Error("unexpected failure in receiving published routing result")
 			}
+			return nil
 		}
 
-		// Client passed all tests successfully
-		errCh <- nil
+		if err := testSimple(); err != nil {
+			errCh <- err
+		}
+		if err := testOptions(); err != nil {
+			errCh <- err
+		}
+		errCh <- nil // Client passed all tests successfully
 	}()
 
 	// Wait for goroutines to complete

+ 61 - 31
app/stats/channel.go

@@ -3,15 +3,15 @@
 package stats
 
 import (
+	"context"
 	"sync"
-	"time"
 
 	"v2ray.com/core/common"
 )
 
 // Channel is an implementation of stats.Channel.
 type Channel struct {
-	channel     chan interface{}
+	channel     chan channelMessage
 	subscribers []chan interface{}
 
 	// Synchronization components
@@ -19,28 +19,21 @@ type Channel struct {
 	closed chan struct{}
 
 	// Channel options
-	subscriberLimit   int           // Set to 0 as no subscriber limit
-	channelBufferSize int           // Set to 0 as no buffering
-	broadcastTimeout  time.Duration // Set to 0 as non-blocking immediate timeout
+	blocking   bool // Set blocking state if channel buffer reaches limit
+	bufferSize int  // Set to 0 as no buffering
+	subsLimit  int  // Set to 0 as no subscriber limit
 }
 
 // NewChannel creates an instance of Statistics Channel.
 func NewChannel(config *ChannelConfig) *Channel {
 	return &Channel{
-		channel:           make(chan interface{}, config.BufferSize),
-		subscriberLimit:   int(config.SubscriberLimit),
-		channelBufferSize: int(config.BufferSize),
-		broadcastTimeout:  time.Duration(config.BroadcastTimeout+1) * time.Millisecond,
+		channel:    make(chan channelMessage, config.BufferSize),
+		subsLimit:  int(config.SubscriberLimit),
+		bufferSize: int(config.BufferSize),
+		blocking:   config.Blocking,
 	}
 }
 
-// Channel returns the underlying go channel.
-func (c *Channel) Channel() chan interface{} {
-	c.access.RLock()
-	defer c.access.RUnlock()
-	return c.channel
-}
-
 // Subscribers implements stats.Channel.
 func (c *Channel) Subscribers() []chan interface{} {
 	c.access.RLock()
@@ -52,10 +45,10 @@ func (c *Channel) Subscribers() []chan interface{} {
 func (c *Channel) Subscribe() (chan interface{}, error) {
 	c.access.Lock()
 	defer c.access.Unlock()
-	if c.subscriberLimit > 0 && len(c.subscribers) >= c.subscriberLimit {
+	if c.subsLimit > 0 && len(c.subscribers) >= c.subsLimit {
 		return nil, newError("Number of subscribers has reached limit")
 	}
-	subscriber := make(chan interface{}, c.channelBufferSize)
+	subscriber := make(chan interface{}, c.bufferSize)
 	c.subscribers = append(c.subscribers, subscriber)
 	return subscriber, nil
 }
@@ -77,16 +70,17 @@ func (c *Channel) Unsubscribe(subscriber chan interface{}) error {
 }
 
 // Publish implements stats.Channel.
-func (c *Channel) Publish(message interface{}) {
+func (c *Channel) Publish(ctx context.Context, msg interface{}) {
 	select { // Early exit if channel closed
 	case <-c.closed:
 		return
 	default:
-	}
-	select { // Drop message if not successfully sent
-	case c.channel <- message:
-	default:
-		return
+		pub := channelMessage{context: ctx, message: msg}
+		if c.blocking {
+			pub.publish(c.channel)
+		} else {
+			pub.publishNonBlocking(c.channel)
+		}
 	}
 }
 
@@ -111,13 +105,12 @@ func (c *Channel) Start() error {
 		go func() {
 			for {
 				select {
-				case message := <-c.channel: // Broadcast message
-					for _, sub := range c.Subscribers() { // Concurrency-safe subscribers retreivement
-						select {
-						case sub <- message: // Successfully sent message
-						case <-time.After(c.broadcastTimeout): // Remove timeout subscriber
-							common.Must(c.Unsubscribe(sub))
-							close(sub) // Actively close subscriber as notification
+				case pub := <-c.channel: // Published message received
+					for _, sub := range c.Subscribers() { // Concurrency-safe subscribers retrievement
+						if c.blocking {
+							pub.broadcast(sub)
+						} else {
+							pub.broadcastNonBlocking(sub)
 						}
 					}
 				case <-c.closed: // Channel closed
@@ -142,3 +135,40 @@ func (c *Channel) Close() error {
 	}
 	return nil
 }
+
+// channelMessage is the published message with guaranteed delivery.
+// message is discarded only when the context is early cancelled.
+type channelMessage struct {
+	context context.Context
+	message interface{}
+}
+
+func (c channelMessage) publish(publisher chan channelMessage) {
+	select {
+	case publisher <- c:
+	case <-c.context.Done():
+	}
+}
+
+func (c channelMessage) publishNonBlocking(publisher chan channelMessage) {
+	select {
+	case publisher <- c:
+	default: // Create another goroutine to keep sending message
+		go c.publish(publisher)
+	}
+}
+
+func (c channelMessage) broadcast(subscriber chan interface{}) {
+	select {
+	case subscriber <- c.message:
+	case <-c.context.Done():
+	}
+}
+
+func (c channelMessage) broadcastNonBlocking(subscriber chan interface{}) {
+	select {
+	case subscriber <- c.message:
+	default: // Create another goroutine to keep sending message
+		go c.broadcast(subscriber)
+	}
+}

+ 141 - 86
app/stats/channel_test.go

@@ -1,6 +1,7 @@
 package stats_test
 
 import (
+	"context"
 	"fmt"
 	"testing"
 	"time"
@@ -12,8 +13,7 @@ import (
 
 func TestStatsChannel(t *testing.T) {
 	// At most 2 subscribers could be registered
-	c := NewChannel(&ChannelConfig{SubscriberLimit: 2})
-	source := c.Channel()
+	c := NewChannel(&ChannelConfig{SubscriberLimit: 2, Blocking: true})
 
 	a, err := stats.SubscribeRunnableChannel(c)
 	common.Must(err)
@@ -34,21 +34,12 @@ func TestStatsChannel(t *testing.T) {
 	stopCh := make(chan struct{})
 	errCh := make(chan string)
 
-	go func() { // Blocking publish
-		source <- 1
-		source <- 2
-		source <- "3"
-		source <- []int{4}
-		source <- nil // Dummy messsage with no subscriber receiving, will block reading goroutine
-		for i := 0; i < cap(source); i++ {
-			source <- nil // Fill source channel's buffer
-		}
-		select {
-		case source <- nil: // Source writing should be blocked here, for last message was not cleared and buffer was full
-			errCh <- fmt.Sprint("unexpected non-blocked source channel")
-		default:
-			close(stopCh)
-		}
+	go func() {
+		c.Publish(context.Background(), 1)
+		c.Publish(context.Background(), 2)
+		c.Publish(context.Background(), "3")
+		c.Publish(context.Background(), []int{4})
+		stopCh <- struct{}{}
 	}()
 
 	go func() {
@@ -64,6 +55,7 @@ func TestStatsChannel(t *testing.T) {
 		if v, ok := (<-a).([]int); !ok || v[0] != 4 {
 			errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", []int{4})
 		}
+		stopCh <- struct{}{}
 	}()
 
 	go func() {
@@ -79,14 +71,18 @@ func TestStatsChannel(t *testing.T) {
 		if v, ok := (<-b).([]int); !ok || v[0] != 4 {
 			errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", []int{4})
 		}
+		stopCh <- struct{}{}
 	}()
 
-	select {
-	case <-time.After(2 * time.Second):
-		t.Fatal("Test timeout after 2s")
-	case e := <-errCh:
-		t.Fatal(e)
-	case <-stopCh:
+	timeout := time.After(2 * time.Second)
+	for i := 0; i < 3; i++ {
+		select {
+		case <-timeout:
+			t.Fatal("Test timeout after 2s")
+		case e := <-errCh:
+			t.Fatal(e)
+		case <-stopCh:
+		}
 	}
 
 	// Test the unsubscription of channel
@@ -100,12 +96,10 @@ func TestStatsChannel(t *testing.T) {
 }
 
 func TestStatsChannelUnsubcribe(t *testing.T) {
-	c := NewChannel(&ChannelConfig{})
+	c := NewChannel(&ChannelConfig{Blocking: true})
 	common.Must(c.Start())
 	defer c.Close()
 
-	source := c.Channel()
-
 	a, err := c.Subscribe()
 	common.Must(err)
 	defer c.Unsubscribe(a)
@@ -133,9 +127,9 @@ func TestStatsChannelUnsubcribe(t *testing.T) {
 	}
 
 	go func() { // Blocking publish
-		source <- 1
+		c.Publish(context.Background(), 1)
 		<-pauseCh // Wait for `b` goroutine to resume sending message
-		source <- 2
+		c.Publish(context.Background(), 2)
 	}()
 
 	go func() {
@@ -151,7 +145,7 @@ func TestStatsChannelUnsubcribe(t *testing.T) {
 		if v, ok := (<-b).(int); !ok || v != 1 {
 			errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 1)
 		}
-		// Unsubscribe `b` while `source`'s messaging is paused
+		// Unsubscribe `b` while publishing is paused
 		c.Unsubscribe(b)
 		{ // Test `b` is not in subscribers
 			var aSet, bSet bool
@@ -167,7 +161,7 @@ func TestStatsChannelUnsubcribe(t *testing.T) {
 				errCh <- fmt.Sprint("unexpected subscribers: ", c.Subscribers())
 			}
 		}
-		// Resume `source`'s progress
+		// Resume publishing progress
 		close(pauseCh)
 		// Test `b` is neither closed nor able to receive any data
 		select {
@@ -191,78 +185,142 @@ func TestStatsChannelUnsubcribe(t *testing.T) {
 	}
 }
 
-func TestStatsChannelTimeout(t *testing.T) {
+func TestStatsChannelBlocking(t *testing.T) {
 	// Do not use buffer so as to create blocking scenario
-	c := NewChannel(&ChannelConfig{BufferSize: 0, BroadcastTimeout: 50})
+	c := NewChannel(&ChannelConfig{BufferSize: 0, Blocking: true})
 	common.Must(c.Start())
 	defer c.Close()
 
-	source := c.Channel()
-
 	a, err := c.Subscribe()
 	common.Must(err)
 	defer c.Unsubscribe(a)
 
-	b, err := c.Subscribe()
-	common.Must(err)
-	defer c.Unsubscribe(b)
-
+	pauseCh := make(chan struct{})
 	stopCh := make(chan struct{})
 	errCh := make(chan string)
 
-	go func() { // Blocking publish
-		source <- 1
-		source <- 2
+	ctx, cancel := context.WithCancel(context.Background())
+
+	// Test blocking channel publishing
+	go func() {
+		// Dummy messsage with no subscriber receiving, will block broadcasting goroutine
+		c.Publish(context.Background(), nil)
+
+		<-pauseCh
+
+		// Publishing should be blocked here, for last message was not cleared and buffer was full
+		c.Publish(context.Background(), nil)
+
+		pauseCh <- struct{}{}
+
+		// Publishing should still be blocked here
+		c.Publish(ctx, nil)
+
+		// Check publishing is done because context is canceled
+		select {
+		case <-ctx.Done():
+			if ctx.Err() != context.Canceled {
+				errCh <- fmt.Sprint("unexpected error: ", ctx.Err())
+			}
+		default:
+			errCh <- "unexpected non-blocked publishing"
+		}
+		close(stopCh)
 	}()
 
 	go func() {
-		if v, ok := (<-a).(int); !ok || v != 1 {
-			errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 1)
+		pauseCh <- struct{}{}
+
+		select {
+		case <-pauseCh:
+			errCh <- "unexpected non-blocked publishing"
+		case <-time.After(100 * time.Millisecond):
 		}
-		if v, ok := (<-a).(int); !ok || v != 2 {
-			errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 2)
+
+		// Receive first published message
+		<-a
+
+		select {
+		case <-pauseCh:
+		case <-time.After(100 * time.Millisecond):
+			errCh <- "unexpected blocking publishing"
 		}
-		{ // Test `b` is still in subscribers yet (because `a` receives 2 first)
-			var aSet, bSet bool
-			for _, s := range c.Subscribers() {
-				if s == a {
-					aSet = true
-				}
-				if s == b {
-					bSet = true
-				}
-			}
-			if !(aSet && bSet) {
-				errCh <- fmt.Sprint("unexpected subscribers: ", c.Subscribers())
+
+		// Manually cancel the context to end publishing
+		cancel()
+	}()
+
+	select {
+	case <-time.After(2 * time.Second):
+		t.Fatal("Test timeout after 2s")
+	case e := <-errCh:
+		t.Fatal(e)
+	case <-stopCh:
+	}
+}
+
+func TestStatsChannelNonBlocking(t *testing.T) {
+	// Do not use buffer so as to create blocking scenario
+	c := NewChannel(&ChannelConfig{BufferSize: 0, Blocking: false})
+	common.Must(c.Start())
+	defer c.Close()
+
+	a, err := c.Subscribe()
+	common.Must(err)
+	defer c.Unsubscribe(a)
+
+	pauseCh := make(chan struct{})
+	stopCh := make(chan struct{})
+	errCh := make(chan string)
+
+	ctx, cancel := context.WithCancel(context.Background())
+
+	// Test blocking channel publishing
+	go func() {
+		c.Publish(context.Background(), nil)
+		c.Publish(context.Background(), nil)
+		pauseCh <- struct{}{}
+		<-pauseCh
+		c.Publish(ctx, nil)
+		c.Publish(ctx, nil)
+		// Check publishing is done because context is canceled
+		select {
+		case <-ctx.Done():
+			if ctx.Err() != context.Canceled {
+				errCh <- fmt.Sprint("unexpected error: ", ctx.Err())
 			}
+		case <-time.After(100 * time.Millisecond):
+			errCh <- "unexpected non-cancelled publishing"
 		}
 	}()
 
 	go func() {
-		if v, ok := (<-b).(int); !ok || v != 1 {
-			errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 1)
+		// Check publishing won't block even if there is no subscriber receiving message
+		select {
+		case <-pauseCh:
+		case <-time.After(100 * time.Millisecond):
+			errCh <- "unexpected blocking publishing"
 		}
-		// Block `b` channel for a time longer than `source`'s timeout
-		<-time.After(200 * time.Millisecond)
-		{ // Test `b` has been unsubscribed by source
-			var aSet, bSet bool
-			for _, s := range c.Subscribers() {
-				if s == a {
-					aSet = true
-				}
-				if s == b {
-					bSet = true
-				}
-			}
-			if !(aSet && !bSet) {
-				errCh <- fmt.Sprint("unexpected subscribers: ", c.Subscribers())
-			}
+
+		// Receive first and second published message
+		<-a
+		<-a
+
+		pauseCh <- struct{}{}
+
+		// Manually cancel the context to end publishing
+		cancel()
+
+		// Check third and forth published message is cancelled and cannot receive
+		<-time.After(100 * time.Millisecond)
+		select {
+		case <-a:
+			errCh <- "unexpected non-cancelled publishing"
+		default:
 		}
-		select { // Test `b` has been closed by source
-		case v, ok := <-b:
-			if ok {
-				errCh <- fmt.Sprint("unexpected data received: ", v)
-			}
+		select {
+		case <-a:
+			errCh <- "unexpected non-cancelled publishing"
 		default:
 		}
 		close(stopCh)
@@ -279,12 +337,10 @@ func TestStatsChannelTimeout(t *testing.T) {
 
 func TestStatsChannelConcurrency(t *testing.T) {
 	// Do not use buffer so as to create blocking scenario
-	c := NewChannel(&ChannelConfig{BufferSize: 0, BroadcastTimeout: 100})
+	c := NewChannel(&ChannelConfig{BufferSize: 0, Blocking: true})
 	common.Must(c.Start())
 	defer c.Close()
 
-	source := c.Channel()
-
 	a, err := c.Subscribe()
 	common.Must(err)
 	defer c.Unsubscribe(a)
@@ -297,8 +353,8 @@ func TestStatsChannelConcurrency(t *testing.T) {
 	errCh := make(chan string)
 
 	go func() { // Blocking publish
-		source <- 1
-		source <- 2
+		c.Publish(context.Background(), 1)
+		c.Publish(context.Background(), 2)
 	}()
 
 	go func() {
@@ -311,8 +367,7 @@ func TestStatsChannelConcurrency(t *testing.T) {
 	}()
 
 	go func() {
-		// Block `b` for a time shorter than `source`'s timeout
-		// So as to ensure source channel is trying to send message to `b`.
+		// Block `b` for a time so as to ensure source channel is trying to send message to `b`.
 		<-time.After(25 * time.Millisecond)
 		// This causes concurrency scenario: unsubscribe `b` while trying to send message to it
 		c.Unsubscribe(b)

+ 24 - 25
app/stats/config.pb.go

@@ -68,9 +68,9 @@ type ChannelConfig struct {
 	sizeCache     protoimpl.SizeCache
 	unknownFields protoimpl.UnknownFields
 
-	SubscriberLimit  int32 `protobuf:"varint,1,opt,name=SubscriberLimit,proto3" json:"SubscriberLimit,omitempty"`
-	BufferSize       int32 `protobuf:"varint,2,opt,name=BufferSize,proto3" json:"BufferSize,omitempty"`
-	BroadcastTimeout int32 `protobuf:"varint,3,opt,name=BroadcastTimeout,proto3" json:"BroadcastTimeout,omitempty"`
+	Blocking        bool  `protobuf:"varint,1,opt,name=Blocking,proto3" json:"Blocking,omitempty"`
+	SubscriberLimit int32 `protobuf:"varint,2,opt,name=SubscriberLimit,proto3" json:"SubscriberLimit,omitempty"`
+	BufferSize      int32 `protobuf:"varint,3,opt,name=BufferSize,proto3" json:"BufferSize,omitempty"`
 }
 
 func (x *ChannelConfig) Reset() {
@@ -105,23 +105,23 @@ func (*ChannelConfig) Descriptor() ([]byte, []int) {
 	return file_app_stats_config_proto_rawDescGZIP(), []int{1}
 }
 
-func (x *ChannelConfig) GetSubscriberLimit() int32 {
+func (x *ChannelConfig) GetBlocking() bool {
 	if x != nil {
-		return x.SubscriberLimit
+		return x.Blocking
 	}
-	return 0
+	return false
 }
 
-func (x *ChannelConfig) GetBufferSize() int32 {
+func (x *ChannelConfig) GetSubscriberLimit() int32 {
 	if x != nil {
-		return x.BufferSize
+		return x.SubscriberLimit
 	}
 	return 0
 }
 
-func (x *ChannelConfig) GetBroadcastTimeout() int32 {
+func (x *ChannelConfig) GetBufferSize() int32 {
 	if x != nil {
-		return x.BroadcastTimeout
+		return x.BufferSize
 	}
 	return 0
 }
@@ -132,21 +132,20 @@ var file_app_stats_config_proto_rawDesc = []byte{
 	0x0a, 0x16, 0x61, 0x70, 0x70, 0x2f, 0x73, 0x74, 0x61, 0x74, 0x73, 0x2f, 0x63, 0x6f, 0x6e, 0x66,
 	0x69, 0x67, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x14, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e,
 	0x63, 0x6f, 0x72, 0x65, 0x2e, 0x61, 0x70, 0x70, 0x2e, 0x73, 0x74, 0x61, 0x74, 0x73, 0x22, 0x08,
-	0x0a, 0x06, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x85, 0x01, 0x0a, 0x0d, 0x43, 0x68, 0x61,
-	0x6e, 0x6e, 0x65, 0x6c, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x28, 0x0a, 0x0f, 0x53, 0x75,
-	0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x72, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x18, 0x01, 0x20,
-	0x01, 0x28, 0x05, 0x52, 0x0f, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x72, 0x4c,
-	0x69, 0x6d, 0x69, 0x74, 0x12, 0x1e, 0x0a, 0x0a, 0x42, 0x75, 0x66, 0x66, 0x65, 0x72, 0x53, 0x69,
-	0x7a, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0a, 0x42, 0x75, 0x66, 0x66, 0x65, 0x72,
-	0x53, 0x69, 0x7a, 0x65, 0x12, 0x2a, 0x0a, 0x10, 0x42, 0x72, 0x6f, 0x61, 0x64, 0x63, 0x61, 0x73,
-	0x74, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x10,
-	0x42, 0x72, 0x6f, 0x61, 0x64, 0x63, 0x61, 0x73, 0x74, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74,
-	0x42, 0x4d, 0x0a, 0x18, 0x63, 0x6f, 0x6d, 0x2e, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f,
-	0x72, 0x65, 0x2e, 0x61, 0x70, 0x70, 0x2e, 0x73, 0x74, 0x61, 0x74, 0x73, 0x50, 0x01, 0x5a, 0x18,
-	0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x61,
-	0x70, 0x70, 0x2f, 0x73, 0x74, 0x61, 0x74, 0x73, 0xaa, 0x02, 0x14, 0x56, 0x32, 0x52, 0x61, 0x79,
-	0x2e, 0x43, 0x6f, 0x72, 0x65, 0x2e, 0x41, 0x70, 0x70, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x73, 0x62,
-	0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
+	0x0a, 0x06, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x75, 0x0a, 0x0d, 0x43, 0x68, 0x61, 0x6e,
+	0x6e, 0x65, 0x6c, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x42, 0x6c, 0x6f,
+	0x63, 0x6b, 0x69, 0x6e, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x42, 0x6c, 0x6f,
+	0x63, 0x6b, 0x69, 0x6e, 0x67, 0x12, 0x28, 0x0a, 0x0f, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69,
+	0x62, 0x65, 0x72, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0f,
+	0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x72, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x12,
+	0x1e, 0x0a, 0x0a, 0x42, 0x75, 0x66, 0x66, 0x65, 0x72, 0x53, 0x69, 0x7a, 0x65, 0x18, 0x03, 0x20,
+	0x01, 0x28, 0x05, 0x52, 0x0a, 0x42, 0x75, 0x66, 0x66, 0x65, 0x72, 0x53, 0x69, 0x7a, 0x65, 0x42,
+	0x4d, 0x0a, 0x18, 0x63, 0x6f, 0x6d, 0x2e, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72,
+	0x65, 0x2e, 0x61, 0x70, 0x70, 0x2e, 0x73, 0x74, 0x61, 0x74, 0x73, 0x50, 0x01, 0x5a, 0x18, 0x76,
+	0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x61, 0x70,
+	0x70, 0x2f, 0x73, 0x74, 0x61, 0x74, 0x73, 0xaa, 0x02, 0x14, 0x56, 0x32, 0x52, 0x61, 0x79, 0x2e,
+	0x43, 0x6f, 0x72, 0x65, 0x2e, 0x41, 0x70, 0x70, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x73, 0x62, 0x06,
+	0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
 }
 
 var (

+ 3 - 3
app/stats/config.proto

@@ -11,7 +11,7 @@ message Config {
 }
 
 message ChannelConfig {
-    int32 SubscriberLimit = 1;
-    int32 BufferSize = 2;
-    int32 BroadcastTimeout = 3;
+  bool  Blocking = 1;
+  int32 SubscriberLimit = 2;
+  int32 BufferSize = 3;
 }

+ 1 - 1
app/stats/stats.go

@@ -94,7 +94,7 @@ func (m *Manager) RegisterChannel(name string) (stats.Channel, error) {
 		return nil, newError("Channel ", name, " already registered.")
 	}
 	newError("create new channel ", name).AtDebug().WriteToLog()
-	c := NewChannel(&ChannelConfig{BufferSize: 16, BroadcastTimeout: 100})
+	c := NewChannel(&ChannelConfig{BufferSize: 64, Blocking: false})
 	m.channels[name] = c
 	if m.running {
 		return c, c.Start()

+ 4 - 2
features/stats/stats.go

@@ -3,6 +3,8 @@ package stats
 //go:generate errorgen
 
 import (
+	"context"
+
 	"v2ray.com/core/common"
 	"v2ray.com/core/features"
 )
@@ -25,8 +27,8 @@ type Counter interface {
 type Channel interface {
 	// Channel is a runnable unit.
 	common.Runnable
-	// Publish broadcasts a message through the channel.
-	Publish(interface{})
+	// Publish broadcasts a message through the channel with a controlling context.
+	Publish(context.Context, interface{})
 	// SubscriberCount returns the number of the subscribers.
 	Subscribers() []chan interface{}
 	// Subscribe registers for listening to channel stream and returns a new listener channel.