|
@@ -1,7 +1,9 @@
|
|
|
package dns
|
|
package dns
|
|
|
|
|
|
|
|
import (
|
|
import (
|
|
|
|
|
+ "bytes"
|
|
|
"context"
|
|
"context"
|
|
|
|
|
+ "encoding/binary"
|
|
|
"net/url"
|
|
"net/url"
|
|
|
"sync"
|
|
"sync"
|
|
|
"sync/atomic"
|
|
"sync/atomic"
|
|
@@ -189,13 +191,18 @@ func (s *QUICNameServer) sendQuery(ctx context.Context, domain string, clientIP
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ dnsReqBuf := buf.New()
|
|
|
|
|
+ binary.Write(dnsReqBuf, binary.BigEndian, uint16(b.Len()))
|
|
|
|
|
+ dnsReqBuf.Write(b.Bytes())
|
|
|
|
|
+ b.Release()
|
|
|
|
|
+
|
|
|
conn, err := s.openStream(dnsCtx)
|
|
conn, err := s.openStream(dnsCtx)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
newError("failed to open quic connection").Base(err).AtError().WriteToLog()
|
|
newError("failed to open quic connection").Base(err).AtError().WriteToLog()
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- _, err = conn.Write(b.Bytes())
|
|
|
|
|
|
|
+ _, err = conn.Write(dnsReqBuf.Bytes())
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
newError("failed to send query").Base(err).AtError().WriteToLog()
|
|
newError("failed to send query").Base(err).AtError().WriteToLog()
|
|
|
return
|
|
return
|
|
@@ -205,9 +212,21 @@ func (s *QUICNameServer) sendQuery(ctx context.Context, domain string, clientIP
|
|
|
|
|
|
|
|
respBuf := buf.New()
|
|
respBuf := buf.New()
|
|
|
defer respBuf.Release()
|
|
defer respBuf.Release()
|
|
|
- n, err := respBuf.ReadFrom(conn)
|
|
|
|
|
|
|
+ n, err := respBuf.ReadFullFrom(conn, 2)
|
|
|
|
|
+ if err != nil && n == 0 {
|
|
|
|
|
+ newError("failed to read response length").Base(err).AtError().WriteToLog()
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+ var length int16
|
|
|
|
|
+ err = binary.Read(bytes.NewReader(respBuf.Bytes()), binary.BigEndian, &length)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ newError("failed to parse response length").Base(err).AtError().WriteToLog()
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+ respBuf.Clear()
|
|
|
|
|
+ n, err = respBuf.ReadFullFrom(conn, int32(length))
|
|
|
if err != nil && n == 0 {
|
|
if err != nil && n == 0 {
|
|
|
- newError("failed to read response").Base(err).AtError().WriteToLog()
|
|
|
|
|
|
|
+ newError("failed to read response length").Base(err).AtError().WriteToLog()
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
|
|
|