command_test.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
  1. package command_test
  2. import (
  3. "context"
  4. "github.com/v2fly/v2ray-core/v4/app/router/routercommon"
  5. "testing"
  6. "time"
  7. "github.com/golang/mock/gomock"
  8. "github.com/google/go-cmp/cmp"
  9. "github.com/google/go-cmp/cmp/cmpopts"
  10. "google.golang.org/grpc"
  11. "google.golang.org/grpc/test/bufconn"
  12. "github.com/v2fly/v2ray-core/v4/app/router"
  13. . "github.com/v2fly/v2ray-core/v4/app/router/command"
  14. "github.com/v2fly/v2ray-core/v4/app/stats"
  15. "github.com/v2fly/v2ray-core/v4/common"
  16. "github.com/v2fly/v2ray-core/v4/common/net"
  17. "github.com/v2fly/v2ray-core/v4/features/routing"
  18. "github.com/v2fly/v2ray-core/v4/testing/mocks"
  19. )
  20. func TestServiceSubscribeRoutingStats(t *testing.T) {
  21. c := stats.NewChannel(&stats.ChannelConfig{
  22. SubscriberLimit: 1,
  23. BufferSize: 0,
  24. Blocking: true,
  25. })
  26. common.Must(c.Start())
  27. defer c.Close()
  28. lis := bufconn.Listen(1024 * 1024)
  29. bufDialer := func(context.Context, string) (net.Conn, error) {
  30. return lis.Dial()
  31. }
  32. testCases := []*RoutingContext{
  33. {InboundTag: "in", OutboundTag: "out"},
  34. {TargetIPs: [][]byte{{1, 2, 3, 4}}, TargetPort: 8080, OutboundTag: "out"},
  35. {TargetDomain: "example.com", TargetPort: 443, OutboundTag: "out"},
  36. {SourcePort: 9999, TargetPort: 9999, OutboundTag: "out"},
  37. {Network: net.Network_UDP, OutboundGroupTags: []string{"outergroup", "innergroup"}, OutboundTag: "out"},
  38. {Protocol: "bittorrent", OutboundTag: "blocked"},
  39. {User: "example@v2fly.org", OutboundTag: "out"},
  40. {SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"},
  41. }
  42. errCh := make(chan error)
  43. nextPub := make(chan struct{})
  44. // Server goroutine
  45. go func() {
  46. server := grpc.NewServer()
  47. RegisterRoutingServiceServer(server, NewRoutingServer(nil, c))
  48. errCh <- server.Serve(lis)
  49. }()
  50. // Publisher goroutine
  51. go func() {
  52. publishTestCases := func() error {
  53. ctx, cancel := context.WithTimeout(context.Background(), time.Second)
  54. defer cancel()
  55. for { // Wait until there's one subscriber in routing stats channel
  56. if len(c.Subscribers()) > 0 {
  57. break
  58. }
  59. if ctx.Err() != nil {
  60. return ctx.Err()
  61. }
  62. }
  63. for _, tc := range testCases {
  64. c.Publish(context.Background(), AsRoutingRoute(tc))
  65. time.Sleep(time.Millisecond)
  66. }
  67. return nil
  68. }
  69. if err := publishTestCases(); err != nil {
  70. errCh <- err
  71. }
  72. // Wait for next round of publishing
  73. <-nextPub
  74. if err := publishTestCases(); err != nil {
  75. errCh <- err
  76. }
  77. }()
  78. // Client goroutine
  79. go func() {
  80. defer lis.Close()
  81. conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
  82. if err != nil {
  83. errCh <- err
  84. return
  85. }
  86. defer conn.Close()
  87. client := NewRoutingServiceClient(conn)
  88. // Test retrieving all fields
  89. testRetrievingAllFields := func() error {
  90. streamCtx, streamClose := context.WithCancel(context.Background())
  91. // Test the unsubscription of stream works well
  92. defer func() {
  93. streamClose()
  94. timeOutCtx, timeout := context.WithTimeout(context.Background(), time.Second)
  95. defer timeout()
  96. for { // Wait until there's no subscriber in routing stats channel
  97. if len(c.Subscribers()) == 0 {
  98. break
  99. }
  100. if timeOutCtx.Err() != nil {
  101. t.Error("unexpected subscribers not decreased in channel", timeOutCtx.Err())
  102. }
  103. }
  104. }()
  105. stream, err := client.SubscribeRoutingStats(streamCtx, &SubscribeRoutingStatsRequest{})
  106. if err != nil {
  107. return err
  108. }
  109. for _, tc := range testCases {
  110. msg, err := stream.Recv()
  111. if err != nil {
  112. return err
  113. }
  114. if r := cmp.Diff(msg, tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
  115. t.Error(r)
  116. }
  117. }
  118. // Test that double subscription will fail
  119. errStream, err := client.SubscribeRoutingStats(context.Background(), &SubscribeRoutingStatsRequest{
  120. FieldSelectors: []string{"ip", "port", "domain", "outbound"},
  121. })
  122. if err != nil {
  123. return err
  124. }
  125. if _, err := errStream.Recv(); err == nil {
  126. t.Error("unexpected successful subscription")
  127. }
  128. return nil
  129. }
  130. // Test retrieving only a subset of fields
  131. testRetrievingSubsetOfFields := func() error {
  132. streamCtx, streamClose := context.WithCancel(context.Background())
  133. defer streamClose()
  134. stream, err := client.SubscribeRoutingStats(streamCtx, &SubscribeRoutingStatsRequest{
  135. FieldSelectors: []string{"ip", "port", "domain", "outbound"},
  136. })
  137. if err != nil {
  138. return err
  139. }
  140. // Send nextPub signal to start next round of publishing
  141. close(nextPub)
  142. for _, tc := range testCases {
  143. msg, err := stream.Recv()
  144. if err != nil {
  145. return err
  146. }
  147. stat := &RoutingContext{ // Only a subset of stats is retrieved
  148. SourceIPs: tc.SourceIPs,
  149. TargetIPs: tc.TargetIPs,
  150. SourcePort: tc.SourcePort,
  151. TargetPort: tc.TargetPort,
  152. TargetDomain: tc.TargetDomain,
  153. OutboundGroupTags: tc.OutboundGroupTags,
  154. OutboundTag: tc.OutboundTag,
  155. }
  156. if r := cmp.Diff(msg, stat, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
  157. t.Error(r)
  158. }
  159. }
  160. return nil
  161. }
  162. if err := testRetrievingAllFields(); err != nil {
  163. errCh <- err
  164. }
  165. if err := testRetrievingSubsetOfFields(); err != nil {
  166. errCh <- err
  167. }
  168. errCh <- nil // Client passed all tests successfully
  169. }()
  170. // Wait for goroutines to complete
  171. select {
  172. case <-time.After(2 * time.Second):
  173. t.Fatal("Test timeout after 2s")
  174. case err := <-errCh:
  175. if err != nil {
  176. t.Fatal(err)
  177. }
  178. }
  179. }
  180. func TestSerivceTestRoute(t *testing.T) {
  181. c := stats.NewChannel(&stats.ChannelConfig{
  182. SubscriberLimit: 1,
  183. BufferSize: 16,
  184. Blocking: true,
  185. })
  186. common.Must(c.Start())
  187. defer c.Close()
  188. r := new(router.Router)
  189. mockCtl := gomock.NewController(t)
  190. defer mockCtl.Finish()
  191. common.Must(r.Init(context.TODO(), &router.Config{
  192. Rule: []*router.RoutingRule{
  193. {
  194. InboundTag: []string{"in"},
  195. TargetTag: &router.RoutingRule_Tag{Tag: "out"},
  196. },
  197. {
  198. Protocol: []string{"bittorrent"},
  199. TargetTag: &router.RoutingRule_Tag{Tag: "blocked"},
  200. },
  201. {
  202. PortList: &net.PortList{Range: []*net.PortRange{{From: 8080, To: 8080}}},
  203. TargetTag: &router.RoutingRule_Tag{Tag: "out"},
  204. },
  205. {
  206. SourcePortList: &net.PortList{Range: []*net.PortRange{{From: 9999, To: 9999}}},
  207. TargetTag: &router.RoutingRule_Tag{Tag: "out"},
  208. },
  209. {
  210. Domain: []*routercommon.Domain{{Type: routercommon.Domain_RootDomain, Value: "com"}},
  211. TargetTag: &router.RoutingRule_Tag{Tag: "out"},
  212. },
  213. {
  214. SourceGeoip: []*routercommon.GeoIP{{CountryCode: "private", Cidr: []*routercommon.CIDR{{Ip: []byte{127, 0, 0, 0}, Prefix: 8}}}},
  215. TargetTag: &router.RoutingRule_Tag{Tag: "out"},
  216. },
  217. {
  218. UserEmail: []string{"example@v2fly.org"},
  219. TargetTag: &router.RoutingRule_Tag{Tag: "out"},
  220. },
  221. {
  222. Networks: []net.Network{net.Network_UDP, net.Network_TCP},
  223. TargetTag: &router.RoutingRule_Tag{Tag: "out"},
  224. },
  225. },
  226. }, mocks.NewDNSClient(mockCtl), mocks.NewOutboundManager(mockCtl), nil))
  227. lis := bufconn.Listen(1024 * 1024)
  228. bufDialer := func(context.Context, string) (net.Conn, error) {
  229. return lis.Dial()
  230. }
  231. errCh := make(chan error)
  232. // Server goroutine
  233. go func() {
  234. server := grpc.NewServer()
  235. RegisterRoutingServiceServer(server, NewRoutingServer(r, c))
  236. errCh <- server.Serve(lis)
  237. }()
  238. // Client goroutine
  239. go func() {
  240. defer lis.Close()
  241. conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
  242. if err != nil {
  243. errCh <- err
  244. }
  245. defer conn.Close()
  246. client := NewRoutingServiceClient(conn)
  247. testCases := []*RoutingContext{
  248. {InboundTag: "in", OutboundTag: "out"},
  249. {TargetIPs: [][]byte{{1, 2, 3, 4}}, TargetPort: 8080, OutboundTag: "out"},
  250. {TargetDomain: "example.com", TargetPort: 443, OutboundTag: "out"},
  251. {SourcePort: 9999, TargetPort: 9999, OutboundTag: "out"},
  252. {Network: net.Network_UDP, Protocol: "bittorrent", OutboundTag: "blocked"},
  253. {User: "example@v2fly.org", OutboundTag: "out"},
  254. {SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"},
  255. }
  256. // Test simple TestRoute
  257. testSimple := func() error {
  258. for _, tc := range testCases {
  259. route, err := client.TestRoute(context.Background(), &TestRouteRequest{RoutingContext: tc})
  260. if err != nil {
  261. return err
  262. }
  263. if r := cmp.Diff(route, tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
  264. t.Error(r)
  265. }
  266. }
  267. return nil
  268. }
  269. // Test TestRoute with special options
  270. testOptions := func() error {
  271. sub, err := c.Subscribe()
  272. if err != nil {
  273. return err
  274. }
  275. for _, tc := range testCases {
  276. route, err := client.TestRoute(context.Background(), &TestRouteRequest{
  277. RoutingContext: tc,
  278. FieldSelectors: []string{"ip", "port", "domain", "outbound"},
  279. PublishResult: true,
  280. })
  281. if err != nil {
  282. return err
  283. }
  284. stat := &RoutingContext{ // Only a subset of stats is retrieved
  285. SourceIPs: tc.SourceIPs,
  286. TargetIPs: tc.TargetIPs,
  287. SourcePort: tc.SourcePort,
  288. TargetPort: tc.TargetPort,
  289. TargetDomain: tc.TargetDomain,
  290. OutboundGroupTags: tc.OutboundGroupTags,
  291. OutboundTag: tc.OutboundTag,
  292. }
  293. if r := cmp.Diff(route, stat, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
  294. t.Error(r)
  295. }
  296. select { // Check that routing result has been published to statistics channel
  297. case msg, received := <-sub:
  298. if route, ok := msg.(routing.Route); received && ok {
  299. if r := cmp.Diff(AsProtobufMessage(nil)(route), tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
  300. t.Error(r)
  301. }
  302. } else {
  303. t.Error("unexpected failure in receiving published routing result for testcase", tc)
  304. }
  305. case <-time.After(100 * time.Millisecond):
  306. t.Error("unexpected failure in receiving published routing result", tc)
  307. }
  308. }
  309. return nil
  310. }
  311. if err := testSimple(); err != nil {
  312. errCh <- err
  313. }
  314. if err := testOptions(); err != nil {
  315. errCh <- err
  316. }
  317. errCh <- nil // Client passed all tests successfully
  318. }()
  319. // Wait for goroutines to complete
  320. select {
  321. case <-time.After(2 * time.Second):
  322. t.Fatal("Test timeout after 2s")
  323. case err := <-errCh:
  324. if err != nil {
  325. t.Fatal(err)
  326. }
  327. }
  328. }