소스 검색

better locking in udp server

Darien Raymond 8 년 전
부모
커밋
43cc81e5a8
1개의 변경된 파일20개의 추가작업 그리고 15개의 파일을 삭제
  1. 20 15
      transport/internet/udp/udp_server.go

+ 20 - 15
transport/internet/udp/udp_server.go

@@ -130,29 +130,34 @@ func (v *Server) locateExistingAndDispatch(name string, payload *buf.Buffer) boo
 	return false
 }
 
+func (v *Server) getInboundRay(dest string, session *proxy.SessionInfo) (*TimedInboundRay, bool) {
+	v.Lock()
+	defer v.Unlock()
+
+	if entry, found := v.conns[dest]; found {
+		return entry, true
+	}
+
+	log.Info("UDP|Server: establishing new connection for ", dest)
+	inboundRay := v.packetDispatcher.DispatchToOutbound(session)
+	return NewTimedInboundRay(dest, inboundRay, v), false
+}
+
 func (v *Server) Dispatch(session *proxy.SessionInfo, payload *buf.Buffer, callback ResponseCallback) {
 	source := session.Source
 	destination := session.Destination
 
 	// TODO: Add user to destString
 	destString := source.String() + "-" + destination.String()
-	log.Debug("UDP Server: Dispatch request: ", destString)
-	if v.locateExistingAndDispatch(destString, payload) {
-		return
-	}
-
-	log.Info("UDP Server: establishing new connection for ", destString)
-	inboundRay := v.packetDispatcher.DispatchToOutbound(session)
-	timedInboundRay := NewTimedInboundRay(destString, inboundRay, v)
-	outputStream := timedInboundRay.InboundInput()
+	log.Debug("UDP|Server: Dispatch request: ", destString)
+	inboundRay, existing := v.getInboundRay(destString, session)
+	outputStream := inboundRay.InboundInput()
 	if outputStream != nil {
 		outputStream.Write(payload)
 	}
-
-	v.Lock()
-	v.conns[destString] = timedInboundRay
-	v.Unlock()
-	go v.handleConnection(timedInboundRay, source, callback)
+	if !existing {
+		go v.handleConnection(inboundRay, source, callback)
+	}
 }
 
 func (v *Server) handleConnection(inboundRay *TimedInboundRay, source v2net.Destination, callback ResponseCallback) {
@@ -161,7 +166,7 @@ func (v *Server) handleConnection(inboundRay *TimedInboundRay, source v2net.Dest
 		if inputStream == nil {
 			break
 		}
-		data, err := inboundRay.InboundOutput().Read()
+		data, err := inputStream.Read()
 		if err != nil {
 			break
 		}