Forráskód Böngészése

fix race condition in transport

Darien Raymond 8 éve
szülő
commit
7a97d73737

+ 9 - 2
transport/internet/internal/pool.go

@@ -31,7 +31,7 @@ func (ec *ExpiringConnection) Expired() bool {
 
 // Pool is a connection pool.
 type Pool struct {
-	sync.Mutex
+	sync.RWMutex
 	connsByDest  map[ConnectionID][]*ExpiringConnection
 	cleanupToken *signal.Semaphore
 }
@@ -74,10 +74,17 @@ func (p *Pool) Get(id ConnectionID) net.Conn {
 	return conn.conn
 }
 
+func (p *Pool) isEmpty() bool {
+	p.RLock()
+	defer p.RUnlock()
+
+	return len(p.connsByDest) == 0
+}
+
 func (p *Pool) cleanup() {
 	defer p.cleanupToken.Signal()
 
-	for len(p.connsByDest) > 0 {
+	for !p.isEmpty() {
 		time.Sleep(time.Second * 5)
 		expiredConns := make([]net.Conn, 0, 16)
 		p.Lock()

+ 3 - 2
transport/internet/kcp/listener.go

@@ -246,13 +246,14 @@ func (v *Listener) Accept() (internet.Connection, error) {
 
 // Close stops listening on the UDP address. Already Accepted connections are not closed.
 func (v *Listener) Close() error {
+
+	v.Lock()
+	defer v.Unlock()
 	select {
 	case <-v.closed:
 		return ErrClosedListener
 	default:
 	}
-	v.Lock()
-	defer v.Unlock()
 
 	close(v.closed)
 	close(v.awaitingConns)

+ 14 - 8
transport/internet/websocket/hub.go

@@ -52,7 +52,7 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
 
 type Listener struct {
 	sync.Mutex
-	acccepting    bool
+	closed        chan bool
 	awaitingConns chan *ConnectionWithError
 	listener      net.Listener
 	tlsConfig     *tls.Config
@@ -67,7 +67,7 @@ func ListenWS(address v2net.Address, port v2net.Port, options internet.ListenOpt
 	wsSettings := networkSettings.(*Config)
 
 	l := &Listener{
-		acccepting:    true,
+		closed:        make(chan bool),
 		awaitingConns: make(chan *ConnectionWithError, 32),
 		config:        wsSettings,
 	}
@@ -130,8 +130,10 @@ func converttovws(w http.ResponseWriter, r *http.Request) (*connection, error) {
 }
 
 func (ln *Listener) Accept() (internet.Connection, error) {
-	for ln.acccepting {
+	for {
 		select {
+		case <-ln.closed:
+			return nil, ErrClosedListener
 		case connErr, open := <-ln.awaitingConns:
 			if !open {
 				return nil, ErrClosedListener
@@ -143,14 +145,15 @@ func (ln *Listener) Accept() (internet.Connection, error) {
 		case <-time.After(time.Second * 2):
 		}
 	}
-	return nil, ErrClosedListener
 }
 
 func (ln *Listener) Put(id internal.ConnectionID, conn net.Conn) {
 	ln.Lock()
 	defer ln.Unlock()
-	if !ln.acccepting {
+	select {
+	case <-ln.closed:
 		return
+	default:
 	}
 	select {
 	case ln.awaitingConns <- &ConnectionWithError{conn: conn}:
@@ -166,10 +169,13 @@ func (ln *Listener) Addr() net.Addr {
 func (ln *Listener) Close() error {
 	ln.Lock()
 	defer ln.Unlock()
-	ln.acccepting = false
-
+	select {
+	case <-ln.closed:
+		return ErrClosedListener
+	default:
+	}
+	close(ln.closed)
 	ln.listener.Close()
-
 	close(ln.awaitingConns)
 	for connErr := range ln.awaitingConns {
 		if connErr.conn != nil {