Browse Source

Feat: core.ToContext(ctx, v) for ctx initialization (#852)

rurirei 4 years ago
parent
commit
aa40b8b835
3 changed files with 31 additions and 8 deletions
  1. 16 0
      context.go
  2. 12 1
      context_test.go
  3. 3 7
      functions.go

+ 16 - 0
context.go

@@ -27,3 +27,19 @@ func MustFromContext(ctx context.Context) *Instance {
 	}
 	return v
 }
+
+// ToContext returns ctx from the given context, or creates an Instance if the context doesn't find that.
+func ToContext(ctx context.Context, v *Instance) context.Context {
+	if FromContext(ctx) != v {
+		ctx = context.WithValue(ctx, v2rayKey, v)
+	}
+	return ctx
+}
+
+// MustToContext returns ctx from the given context, or panics if not found that.
+func MustToContext(ctx context.Context, v *Instance) context.Context {
+	if c := ToContext(ctx, v); c != ctx {
+		panic("V is not in context.")
+	}
+	return ctx
+}

+ 12 - 1
context_test.go

@@ -7,7 +7,7 @@ import (
 	. "github.com/v2fly/v2ray-core/v4"
 )
 
-func TestContextPanic(t *testing.T) {
+func TestFromContextPanic(t *testing.T) {
 	defer func() {
 		r := recover()
 		if r == nil {
@@ -17,3 +17,14 @@ func TestContextPanic(t *testing.T) {
 
 	MustFromContext(context.Background())
 }
+
+func TestToContextPanic(t *testing.T) {
+	defer func() {
+		r := recover()
+		if r == nil {
+			t.Error("expect panic, but nil")
+		}
+	}()
+
+	MustToContext(context.Background(), &Instance{})
+}

+ 3 - 7
functions.go

@@ -16,7 +16,7 @@ import (
 func CreateObject(v *Instance, config interface{}) (interface{}, error) {
 	var ctx context.Context
 	if v != nil {
-		ctx = context.WithValue(v.ctx, v2rayKey, v)
+		ctx = ToContext(v.ctx, v)
 	}
 	return common.CreateObject(ctx, config)
 }
@@ -47,9 +47,7 @@ func StartInstance(configFormat string, configBytes []byte) (*Instance, error) {
 //
 // v2ray:api:stable
 func Dial(ctx context.Context, v *Instance, dest net.Destination) (net.Conn, error) {
-	if FromContext(ctx) == nil {
-		ctx = context.WithValue(ctx, v2rayKey, v)
-	}
+	ctx = ToContext(ctx, v)
 
 	dispatcher := v.GetFeature(routing.DispatcherType())
 	if dispatcher == nil {
@@ -76,9 +74,7 @@ func Dial(ctx context.Context, v *Instance, dest net.Destination) (net.Conn, err
 //
 // v2ray:api:beta
 func DialUDP(ctx context.Context, v *Instance) (net.PacketConn, error) {
-	if FromContext(ctx) == nil {
-		ctx = context.WithValue(ctx, v2rayKey, v)
-	}
+	ctx = ToContext(ctx, v)
 
 	dispatcher := v.GetFeature(routing.DispatcherType())
 	if dispatcher == nil {