Browse Source

Fix flaky TestServiceSubscribeRoutingStats

yuhan6665 4 years ago
parent
commit
ab6811ed58
1 changed files with 86 additions and 15 deletions
  1. 86 15
      app/router/command/command_test.go

+ 86 - 15
app/router/command/command_test.go

@@ -45,7 +45,6 @@ func TestServiceSubscribeRoutingStats(t *testing.T) {
 		{SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"},
 		{SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"},
 	}
 	}
 	errCh := make(chan error)
 	errCh := make(chan error)
-	nextPub := make(chan struct{})
 
 
 	// Server goroutine
 	// Server goroutine
 	go func() {
 	go func() {
@@ -77,13 +76,6 @@ func TestServiceSubscribeRoutingStats(t *testing.T) {
 		if err := publishTestCases(); err != nil {
 		if err := publishTestCases(); err != nil {
 			errCh <- err
 			errCh <- err
 		}
 		}
-
-		// Wait for next round of publishing
-		<-nextPub
-
-		if err := publishTestCases(); err != nil {
-			errCh <- err
-		}
 	}()
 	}()
 
 
 	// Client goroutine
 	// Client goroutine
@@ -145,6 +137,92 @@ func TestServiceSubscribeRoutingStats(t *testing.T) {
 			return nil
 			return nil
 		}
 		}
 
 
+		if err := testRetrievingAllFields(); err != nil {
+			errCh <- err
+		}
+		errCh <- nil // Client passed all tests successfully
+	}()
+
+	// Wait for goroutines to complete
+	select {
+	case <-time.After(2 * time.Second):
+		t.Fatal("Test timeout after 2s")
+	case err := <-errCh:
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+}
+
+func TestServiceSubscribeSubsetOfFields(t *testing.T) {
+	c := stats.NewChannel(&stats.ChannelConfig{
+		SubscriberLimit: 1,
+		BufferSize:      0,
+		Blocking:        true,
+	})
+	common.Must(c.Start())
+	defer c.Close()
+
+	lis := bufconn.Listen(1024 * 1024)
+	bufDialer := func(context.Context, string) (net.Conn, error) {
+		return lis.Dial()
+	}
+
+	testCases := []*RoutingContext{
+		{InboundTag: "in", OutboundTag: "out"},
+		{TargetIPs: [][]byte{{1, 2, 3, 4}}, TargetPort: 8080, OutboundTag: "out"},
+		{TargetDomain: "example.com", TargetPort: 443, OutboundTag: "out"},
+		{SourcePort: 9999, TargetPort: 9999, OutboundTag: "out"},
+		{Network: net.Network_UDP, OutboundGroupTags: []string{"outergroup", "innergroup"}, OutboundTag: "out"},
+		{Protocol: "bittorrent", OutboundTag: "blocked"},
+		{User: "example@example.com", OutboundTag: "out"},
+		{SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"},
+	}
+	errCh := make(chan error)
+
+	// Server goroutine
+	go func() {
+		server := grpc.NewServer()
+		RegisterRoutingServiceServer(server, NewRoutingServer(nil, c))
+		errCh <- server.Serve(lis)
+	}()
+
+	// Publisher goroutine
+	go func() {
+		publishTestCases := func() error {
+			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+			defer cancel()
+			for { // Wait until there's one subscriber in routing stats channel
+				if len(c.Subscribers()) > 0 {
+					break
+				}
+				if ctx.Err() != nil {
+					return ctx.Err()
+				}
+			}
+			for _, tc := range testCases {
+				c.Publish(context.Background(), AsRoutingRoute(tc))
+				time.Sleep(time.Millisecond)
+			}
+			return nil
+		}
+
+		if err := publishTestCases(); err != nil {
+			errCh <- err
+		}
+	}()
+
+	// Client goroutine
+	go func() {
+		defer lis.Close()
+		conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
+		if err != nil {
+			errCh <- err
+			return
+		}
+		defer conn.Close()
+		client := NewRoutingServiceClient(conn)
+
 		// Test retrieving only a subset of fields
 		// Test retrieving only a subset of fields
 		testRetrievingSubsetOfFields := func() error {
 		testRetrievingSubsetOfFields := func() error {
 			streamCtx, streamClose := context.WithCancel(context.Background())
 			streamCtx, streamClose := context.WithCancel(context.Background())
@@ -156,9 +234,6 @@ func TestServiceSubscribeRoutingStats(t *testing.T) {
 				return err
 				return err
 			}
 			}
 
 
-			// Send nextPub signal to start next round of publishing
-			close(nextPub)
-
 			for _, tc := range testCases {
 			for _, tc := range testCases {
 				msg, err := stream.Recv()
 				msg, err := stream.Recv()
 				if err != nil {
 				if err != nil {
@@ -180,10 +255,6 @@ func TestServiceSubscribeRoutingStats(t *testing.T) {
 
 
 			return nil
 			return nil
 		}
 		}
-
-		if err := testRetrievingAllFields(); err != nil {
-			errCh <- err
-		}
 		if err := testRetrievingSubsetOfFields(); err != nil {
 		if err := testRetrievingSubsetOfFields(); err != nil {
 			errCh <- err
 			errCh <- err
 		}
 		}