Browse Source

fix udp dispatcher

Darien Raymond 8 years ago
parent
commit
9978cf07e6
1 changed files with 10 additions and 11 deletions
  1. 10 11
      transport/internet/udp/dispatcher.go

+ 10 - 11
transport/internet/udp/dispatcher.go

@@ -15,57 +15,56 @@ type ResponseCallback func(payload *buf.Buffer)
 
 type Dispatcher struct {
 	sync.RWMutex
-	conns      map[string]ray.InboundRay
+	conns      map[v2net.Destination]ray.InboundRay
 	dispatcher dispatcher.Interface
 }
 
 func NewDispatcher(dispatcher dispatcher.Interface) *Dispatcher {
 	return &Dispatcher{
-		conns:      make(map[string]ray.InboundRay),
+		conns:      make(map[v2net.Destination]ray.InboundRay),
 		dispatcher: dispatcher,
 	}
 }
 
-func (v *Dispatcher) RemoveRay(name string) {
+func (v *Dispatcher) RemoveRay(dest v2net.Destination) {
 	v.Lock()
 	defer v.Unlock()
-	if conn, found := v.conns[name]; found {
+	if conn, found := v.conns[dest]; found {
 		conn.InboundInput().Close()
 		conn.InboundOutput().Close()
-		delete(v.conns, name)
+		delete(v.conns, dest)
 	}
 }
 
 func (v *Dispatcher) getInboundRay(ctx context.Context, dest v2net.Destination) (ray.InboundRay, bool) {
-	destString := dest.String()
 	v.Lock()
 	defer v.Unlock()
 
-	if entry, found := v.conns[destString]; found {
+	if entry, found := v.conns[dest]; found {
 		return entry, true
 	}
 
 	log.Info("UDP|Server: establishing new connection for ", dest)
 	inboundRay, _ := v.dispatcher.Dispatch(ctx, dest)
+	v.conns[dest] = inboundRay
 	return inboundRay, false
 }
 
 func (v *Dispatcher) Dispatch(ctx context.Context, destination v2net.Destination, payload *buf.Buffer, callback ResponseCallback) {
 	// TODO: Add user to destString
-	destString := destination.String()
-	log.Debug("UDP|Server: Dispatch request: ", destString)
+	log.Debug("UDP|Server: Dispatch request: ", destination)
 
 	inboundRay, existing := v.getInboundRay(ctx, destination)
 	outputStream := inboundRay.InboundInput()
 	if outputStream != nil {
 		if err := outputStream.Write(payload); err != nil {
-			v.RemoveRay(destString)
+			v.RemoveRay(destination)
 		}
 	}
 	if !existing {
 		go func() {
 			handleInput(inboundRay.InboundOutput(), callback)
-			v.RemoveRay(destString)
+			v.RemoveRay(destination)
 		}()
 	}
 }