Ver Fonte

fix lock usage in server list

Darien Raymond há 8 anos atrás
pai
commit
0d92dce5eb
2 ficheiros alterados com 28 adições e 29 exclusões
  1. 26 27
      common/protocol/server_picker.go
  2. 2 2
      common/protocol/server_spec.go

+ 26 - 27
common/protocol/server_picker.go

@@ -13,32 +13,32 @@ func NewServerList() *ServerList {
 	return &ServerList{}
 }
 
-func (v *ServerList) AddServer(server *ServerSpec) {
-	v.Lock()
-	defer v.Unlock()
+func (sl *ServerList) AddServer(server *ServerSpec) {
+	sl.Lock()
+	defer sl.Unlock()
 
-	v.servers = append(v.servers, server)
+	sl.servers = append(sl.servers, server)
 }
 
-func (v *ServerList) Size() uint32 {
-	v.RLock()
-	defer v.RUnlock()
+func (sl *ServerList) Size() uint32 {
+	sl.RLock()
+	defer sl.RUnlock()
 
-	return uint32(len(v.servers))
+	return uint32(len(sl.servers))
 }
 
-func (v *ServerList) GetServer(idx uint32) *ServerSpec {
-	v.RLock()
-	defer v.RUnlock()
+func (sl *ServerList) GetServer(idx uint32) *ServerSpec {
+	sl.Lock()
+	defer sl.Unlock()
 
 	for {
-		if idx >= uint32(len(v.servers)) {
+		if idx >= uint32(len(sl.servers)) {
 			return nil
 		}
 
-		server := v.servers[idx]
+		server := sl.servers[idx]
 		if !server.IsValid() {
-			v.RemoveServer(idx)
+			sl.removeServer(idx)
 			continue
 		}
 
@@ -46,11 +46,10 @@ func (v *ServerList) GetServer(idx uint32) *ServerSpec {
 	}
 }
 
-// Private: Visible for testing.
-func (v *ServerList) RemoveServer(idx uint32) {
-	n := len(v.servers)
-	v.servers[idx] = v.servers[n-1]
-	v.servers = v.servers[:n-1]
+func (sl *ServerList) removeServer(idx uint32) {
+	n := len(sl.servers)
+	sl.servers[idx] = sl.servers[n-1]
+	sl.servers = sl.servers[:n-1]
 }
 
 type ServerPicker interface {
@@ -70,21 +69,21 @@ func NewRoundRobinServerPicker(serverlist *ServerList) *RoundRobinServerPicker {
 	}
 }
 
-func (v *RoundRobinServerPicker) PickServer() *ServerSpec {
-	v.Lock()
-	defer v.Unlock()
+func (p *RoundRobinServerPicker) PickServer() *ServerSpec {
+	p.Lock()
+	defer p.Unlock()
 
-	next := v.nextIndex
-	server := v.serverlist.GetServer(next)
+	next := p.nextIndex
+	server := p.serverlist.GetServer(next)
 	if server == nil {
 		next = 0
-		server = v.serverlist.GetServer(0)
+		server = p.serverlist.GetServer(0)
 	}
 	next++
-	if next >= v.serverlist.Size() {
+	if next >= p.serverlist.Size() {
 		next = 0
 	}
-	v.nextIndex = next
+	p.nextIndex = next
 
 	return server
 }

+ 2 - 2
common/protocol/server_spec.go

@@ -19,11 +19,11 @@ func AlwaysValid() ValidationStrategy {
 	return AlwaysValidStrategy{}
 }
 
-func (v AlwaysValidStrategy) IsValid() bool {
+func (AlwaysValidStrategy) IsValid() bool {
 	return true
 }
 
-func (v AlwaysValidStrategy) Invalidate() {}
+func (AlwaysValidStrategy) Invalidate() {}
 
 type TimeoutValidStrategy struct {
 	until time.Time