Browse Source

fix: use sync.Map in request to packet conn server

Shelikhoo 1 year ago
parent
commit
3e7dc35562
1 changed files with 18 additions and 7 deletions
  1. 18 7
      transport/internet/request/assembler/packetconn/req2packet.go

+ 18 - 7
transport/internet/request/assembler/packetconn/req2packet.go

@@ -5,6 +5,7 @@ import (
 	"context"
 	"crypto/rand"
 	"io"
+	"sync"
 	"time"
 
 	"github.com/golang-collections/go-datastructures/queue"
@@ -102,7 +103,7 @@ copyFromChan:
 	waitTimer.Stop()
 	go func() {
 		reader, writer := io.Pipe()
-                defer writer.Close()
+		defer writer.Close()
 		streamingRespOpt := &pipedStreamingRespOption{writer}
 		go func() {
 			for {
@@ -176,7 +177,7 @@ func (r *requestToPacketConnClientSession) Close() error {
 
 func newRequestToPacketConnServer(ctx context.Context, config *ServerConfig) *requestToPacketConnServer {
 	return &requestToPacketConnServer{
-		sessionMap: make(map[string]*requestToPacketConnServerSession),
+		sessionMap: sync.Map{},
 		ctx:        ctx,
 		config:     config,
 	}
@@ -185,7 +186,7 @@ func newRequestToPacketConnServer(ctx context.Context, config *ServerConfig) *re
 type requestToPacketConnServer struct {
 	packetSessionReceiver request.SessionReceiver
 
-	sessionMap map[string]*requestToPacketConnServerSession
+	sessionMap sync.Map
 
 	ctx    context.Context
 	config *ServerConfig
@@ -203,7 +204,15 @@ func (r *requestToPacketConnServer) OnRoundTrip(ctx context.Context, req request
 		return request.Response{}, newError("nil session id")
 	}
 	sessionID := string(SessionID)
-	session, found := r.sessionMap[sessionID]
+	var session *requestToPacketConnServerSession
+	sessionAny, found := r.sessionMap.Load(sessionID)
+	if found {
+		var ok bool
+		session, ok = sessionAny.(*requestToPacketConnServerSession)
+		if !ok {
+			return request.Response{}, newError("failed to cast session")
+		}
+	}
 	if !found {
 		ctxWithFinish, finish := context.WithCancel(ctx)
 		session = &requestToPacketConnServerSession{
@@ -218,8 +227,10 @@ func (r *requestToPacketConnServer) OnRoundTrip(ctx context.Context, req request
 			maxWriteDuration:               int(r.config.MaxWriteDurationMs),
 			maxSimultaneousWriteConnection: int(r.config.MaxSimultaneousWriteConnection),
 		}
-		r.sessionMap[sessionID] = session
-		err = r.packetSessionReceiver.OnNewSession(ctx, session)
+		_, loaded := r.sessionMap.LoadOrStore(sessionID, session)
+		if !loaded {
+			err = r.packetSessionReceiver.OnNewSession(ctx, session)
+		}
 	}
 	if err != nil {
 		return request.Response{}, err
@@ -228,7 +239,7 @@ func (r *requestToPacketConnServer) OnRoundTrip(ctx context.Context, req request
 }
 
 func (r *requestToPacketConnServer) removeSessionID(sessionID []byte) {
-	delete(r.sessionMap, string(sessionID))
+	r.sessionMap.Delete(string(sessionID))
 }
 
 type requestToPacketConnServerSession struct {