Browse Source

dns protocol package

Darien Raymond 6 years ago
parent
commit
1c830472b9
2 changed files with 79 additions and 14 deletions
  1. 2 14
      app/dns/udpns.go
  2. 77 0
      common/protocol/dns/io.go

+ 2 - 14
app/dns/udpns.go

@@ -8,10 +8,10 @@ import (
 	"time"
 
 	"golang.org/x/net/dns/dnsmessage"
-
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/net"
+	"v2ray.com/core/common/protocol/dns"
 	"v2ray.com/core/common/session"
 	"v2ray.com/core/common/signal/pubsub"
 	"v2ray.com/core/common/task"
@@ -293,25 +293,13 @@ func (s *ClassicNameServer) buildMsgs(domain string, option IPOption) []*dnsmess
 	return msgs
 }
 
-func msgToBuffer2(msg *dnsmessage.Message) (*buf.Buffer, error) {
-	buffer := buf.New()
-	rawBytes := buffer.Extend(buf.Size)
-	packed, err := msg.AppendPack(rawBytes[:0])
-	if err != nil {
-		buffer.Release()
-		return nil, err
-	}
-	buffer.Resize(0, int32(len(packed)))
-	return buffer, nil
-}
-
 func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, option IPOption) {
 	newError("querying DNS for: ", domain).AtDebug().WriteToLog(session.ExportIDToError(ctx))
 
 	msgs := s.buildMsgs(domain, option)
 
 	for _, msg := range msgs {
-		b, err := msgToBuffer2(msg)
+		b, err := dns.PackMessage(msg)
 		common.Must(err)
 		s.udpServer.Dispatch(context.Background(), s.address, b)
 	}

+ 77 - 0
common/protocol/dns/io.go

@@ -0,0 +1,77 @@
+package dns
+
+import (
+	"sync"
+
+	"golang.org/x/net/dns/dnsmessage"
+	"v2ray.com/core/common"
+	"v2ray.com/core/common/buf"
+)
+
+func PackMessage(msg *dnsmessage.Message) (*buf.Buffer, error) {
+	buffer := buf.New()
+	rawBytes := buffer.Extend(buf.Size)
+	packed, err := msg.AppendPack(rawBytes[:0])
+	if err != nil {
+		buffer.Release()
+		return nil, err
+	}
+	buffer.Resize(0, int32(len(packed)))
+	return buffer, nil
+}
+
+type MessageReader interface {
+	ReadMessage() (*buf.Buffer, error)
+}
+
+type UDPReader struct {
+	buf.Reader
+
+	access sync.Mutex
+	cache  buf.MultiBuffer
+}
+
+func (r *UDPReader) readCache() *buf.Buffer {
+	r.access.Lock()
+	defer r.access.Unlock()
+
+	mb, b := buf.SplitFirst(r.cache)
+	r.cache = mb
+	return b
+}
+
+func (r *UDPReader) refill() error {
+	mb, err := r.Reader.ReadMultiBuffer()
+	if err != nil {
+		return err
+	}
+	r.access.Lock()
+	r.cache = mb
+	r.access.Unlock()
+	return nil
+}
+
+// ReadMessage implements MessageReader.
+func (r *UDPReader) ReadMessage() (*buf.Buffer, error) {
+	for {
+		b := r.readCache()
+		if b != nil {
+			return b, nil
+		}
+		if err := r.refill(); err != nil {
+			return nil, err
+		}
+	}
+}
+
+// Close implements common.Closable.
+func (r *UDPReader) Close() error {
+	defer func() {
+		r.access.Lock()
+		buf.ReleaseMulti(r.cache)
+		r.cache = nil
+		r.access.Unlock()
+	}()
+
+	return common.Close(r.Reader)
+}