Browse Source

fix: make sure the ctx is propagated to connections

Shelikhoo 4 years ago
parent
commit
2e26cf6587
2 changed files with 7 additions and 7 deletions
  1. 6 6
      app/dns/nameserver_quic.go
  2. 1 1
      app/dns/nameserver_udp.go

+ 6 - 6
app/dns/nameserver_quic.go

@@ -330,7 +330,7 @@ func isActive(s quic.Session) bool {
 	}
 }
 
-func (s *QUICNameServer) getSession() (quic.Session, error) {
+func (s *QUICNameServer) getSession(ctx context.Context) (quic.Session, error) {
 	var session quic.Session
 	s.RLock()
 	session = s.session
@@ -348,14 +348,14 @@ func (s *QUICNameServer) getSession() (quic.Session, error) {
 	defer s.Unlock()
 
 	var err error
-	session, err = s.openSession()
+	session, err = s.openSession(ctx)
 	if err != nil {
 		// This does not look too nice, but QUIC (or maybe quic-go)
 		// doesn't seem stable enough.
 		// Maybe retransmissions aren't fully implemented in quic-go?
 		// Anyways, the simple solution is to make a second try when
 		// it fails to open the QUIC session.
-		session, err = s.openSession()
+		session, err = s.openSession(ctx)
 		if err != nil {
 			return nil, err
 		}
@@ -364,13 +364,13 @@ func (s *QUICNameServer) getSession() (quic.Session, error) {
 	return session, nil
 }
 
-func (s *QUICNameServer) openSession() (quic.Session, error) {
+func (s *QUICNameServer) openSession(ctx context.Context) (quic.Session, error) {
 	tlsConfig := tls.Config{}
 	quicConfig := &quic.Config{
 		HandshakeIdleTimeout: handshakeIdleTimeout,
 	}
 
-	session, err := quic.DialAddrContext(context.Background(), s.destination.NetAddr(), tlsConfig.GetTLSConfig(tls.WithNextProto("http/1.1", http2.NextProtoTLS, NextProtoDQ)), quicConfig)
+	session, err := quic.DialAddrContext(ctx, s.destination.NetAddr(), tlsConfig.GetTLSConfig(tls.WithNextProto("http/1.1", http2.NextProtoTLS, NextProtoDQ)), quicConfig)
 	if err != nil {
 		return nil, err
 	}
@@ -379,7 +379,7 @@ func (s *QUICNameServer) openSession() (quic.Session, error) {
 }
 
 func (s *QUICNameServer) openStream(ctx context.Context) (quic.Stream, error) {
-	session, err := s.getSession()
+	session, err := s.getSession(ctx)
 	if err != nil {
 		return nil, err
 	}

+ 1 - 1
app/dns/nameserver_udp.go

@@ -192,7 +192,7 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, client
 	for _, req := range reqs {
 		s.addPendingRequest(req)
 		b, _ := dns.PackMessage(req.msg)
-		udpCtx := context.Background()
+		udpCtx := ctx
 		if inbound := session.InboundFromContext(ctx); inbound != nil {
 			udpCtx = session.ContextWithInbound(udpCtx, inbound)
 		}