Browse Source

cancel failed grpc connection

Shelikhoo 4 years ago
parent
commit
e00d80eac4
1 changed files with 13 additions and 4 deletions
  1. 13 4
      transport/internet/grpc/dial.go

+ 13 - 4
transport/internet/grpc/dial.go

@@ -36,6 +36,8 @@ func init() {
 	common.Must(internet.RegisterTransportDialer(protocolName, Dial))
 }
 
+type dialerCanceller func()
+
 var (
 	globalDialerMap    map[net.Destination]*grpc.ClientConn
 	globalDialerAccess sync.Mutex
@@ -51,19 +53,20 @@ func dialgRPC(ctx context.Context, dest net.Destination, streamSettings *interne
 		dialOption = grpc.WithTransportCredentials(credentials.NewTLS(config.GetTLSConfig()))
 	}
 
-	conn, err := getGrpcClient(ctx, dest, dialOption)
+	conn, canceller, err := getGrpcClient(ctx, dest, dialOption)
 	if err != nil {
 		return nil, newError("Cannot dial grpc").Base(err)
 	}
 	client := encoding.NewGunServiceClient(conn)
 	gunService, err := client.(encoding.GunServiceClientX).TunCustomName(ctx, grpcSettings.ServiceName)
 	if err != nil {
+		canceller()
 		return nil, newError("Cannot dial grpc").Base(err)
 	}
 	return encoding.NewGunConn(gunService, nil), nil
 }
 
-func getGrpcClient(ctx context.Context, dest net.Destination, dialOption grpc.DialOption) (*grpc.ClientConn, error) {
+func getGrpcClient(ctx context.Context, dest net.Destination, dialOption grpc.DialOption) (*grpc.ClientConn, dialerCanceller, error) {
 	globalDialerAccess.Lock()
 	defer globalDialerAccess.Unlock()
 
@@ -71,9 +74,15 @@ func getGrpcClient(ctx context.Context, dest net.Destination, dialOption grpc.Di
 		globalDialerMap = make(map[net.Destination]*grpc.ClientConn)
 	}
 
+	canceller := func() {
+		globalDialerAccess.Lock()
+		defer globalDialerAccess.Unlock()
+		delete(globalDialerMap, dest)
+	}
+
 	// TODO Should support chain proxy to the same destination
 	if client, found := globalDialerMap[dest]; found && client.GetState() != connectivity.Shutdown {
-		return client, nil
+		return client, canceller, nil
 	}
 
 	conn, err := grpc.Dial(
@@ -106,5 +115,5 @@ func getGrpcClient(ctx context.Context, dest net.Destination, dialOption grpc.Di
 		}),
 	)
 	globalDialerMap[dest] = conn
-	return conn, err
+	return conn, canceller, err
 }