Browse Source

Fix: length of DNS over QUIC (#1888)

秋のかえで 3 years ago
parent
commit
7e1f461b74
1 changed files with 22 additions and 3 deletions
  1. 22 3
      app/dns/nameserver_quic.go

+ 22 - 3
app/dns/nameserver_quic.go

@@ -1,7 +1,9 @@
 package dns
 package dns
 
 
 import (
 import (
+	"bytes"
 	"context"
 	"context"
+	"encoding/binary"
 	"net/url"
 	"net/url"
 	"sync"
 	"sync"
 	"sync/atomic"
 	"sync/atomic"
@@ -189,13 +191,18 @@ func (s *QUICNameServer) sendQuery(ctx context.Context, domain string, clientIP
 				return
 				return
 			}
 			}
 
 
+			dnsReqBuf := buf.New()
+			binary.Write(dnsReqBuf, binary.BigEndian, uint16(b.Len()))
+			dnsReqBuf.Write(b.Bytes())
+			b.Release()
+
 			conn, err := s.openStream(dnsCtx)
 			conn, err := s.openStream(dnsCtx)
 			if err != nil {
 			if err != nil {
 				newError("failed to open quic connection").Base(err).AtError().WriteToLog()
 				newError("failed to open quic connection").Base(err).AtError().WriteToLog()
 				return
 				return
 			}
 			}
 
 
-			_, err = conn.Write(b.Bytes())
+			_, err = conn.Write(dnsReqBuf.Bytes())
 			if err != nil {
 			if err != nil {
 				newError("failed to send query").Base(err).AtError().WriteToLog()
 				newError("failed to send query").Base(err).AtError().WriteToLog()
 				return
 				return
@@ -205,9 +212,21 @@ func (s *QUICNameServer) sendQuery(ctx context.Context, domain string, clientIP
 
 
 			respBuf := buf.New()
 			respBuf := buf.New()
 			defer respBuf.Release()
 			defer respBuf.Release()
-			n, err := respBuf.ReadFrom(conn)
+			n, err := respBuf.ReadFullFrom(conn, 2)
+			if err != nil && n == 0 {
+				newError("failed to read response length").Base(err).AtError().WriteToLog()
+				return
+			}
+			var length int16
+			err = binary.Read(bytes.NewReader(respBuf.Bytes()), binary.BigEndian, &length)
+			if err != nil {
+				newError("failed to parse response length").Base(err).AtError().WriteToLog()
+				return
+			}
+			respBuf.Clear()
+			n, err = respBuf.ReadFullFrom(conn, int32(length))
 			if err != nil && n == 0 {
 			if err != nil && n == 0 {
-				newError("failed to read response").Base(err).AtError().WriteToLog()
+				newError("failed to read response length").Base(err).AtError().WriteToLog()
 				return
 				return
 			}
 			}