Browse Source

change net.IP to net.Address

Darien Raymond 6 years ago
parent
commit
fc1e660c27
2 changed files with 18 additions and 15 deletions
  1. 14 11
      app/dns/hosts.go
  2. 4 4
      app/dns/udpns.go

+ 14 - 11
app/dns/hosts.go

@@ -9,7 +9,7 @@ import (
 
 
 // StaticHosts represents static domain-ip mapping in DNS server.
 // StaticHosts represents static domain-ip mapping in DNS server.
 type StaticHosts struct {
 type StaticHosts struct {
-	ips      [][]net.IP
+	ips      [][]net.Address
 	matchers *strmatcher.MatcherGroup
 	matchers *strmatcher.MatcherGroup
 }
 }
 
 
@@ -36,7 +36,7 @@ func toStrMatcher(t DomainMatchingType, domain string) (strmatcher.Matcher, erro
 func NewStaticHosts(hosts []*Config_HostMapping, legacy map[string]*net.IPOrDomain) (*StaticHosts, error) {
 func NewStaticHosts(hosts []*Config_HostMapping, legacy map[string]*net.IPOrDomain) (*StaticHosts, error) {
 	g := new(strmatcher.MatcherGroup)
 	g := new(strmatcher.MatcherGroup)
 	sh := &StaticHosts{
 	sh := &StaticHosts{
-		ips:      make([][]net.IP, len(hosts)+len(legacy)+16),
+		ips:      make([][]net.Address, len(hosts)+len(legacy)+16),
 		matchers: g,
 		matchers: g,
 	}
 	}
 
 
@@ -50,10 +50,10 @@ func NewStaticHosts(hosts []*Config_HostMapping, legacy map[string]*net.IPOrDoma
 
 
 			address := ip.AsAddress()
 			address := ip.AsAddress()
 			if address.Family().IsDomain() {
 			if address.Family().IsDomain() {
-				return nil, newError("ignoring domain address in static hosts: ", address.Domain()).AtWarning()
+				return nil, newError("invalid domain address in static hosts: ", address.Domain()).AtWarning()
 			}
 			}
 
 
-			sh.ips[id] = []net.IP{address.IP()}
+			sh.ips[id] = []net.Address{address}
 		}
 		}
 	}
 	}
 
 
@@ -63,9 +63,13 @@ func NewStaticHosts(hosts []*Config_HostMapping, legacy map[string]*net.IPOrDoma
 			return nil, newError("failed to create domain matcher").Base(err)
 			return nil, newError("failed to create domain matcher").Base(err)
 		}
 		}
 		id := g.Add(matcher)
 		id := g.Add(matcher)
-		ips := make([]net.IP, len(mapping.Ip))
-		for idx, ip := range mapping.Ip {
-			ips[idx] = net.IP(ip)
+		ips := make([]net.Address, 0, len(mapping.Ip))
+		for _, ip := range mapping.Ip {
+			addr := net.IPAddress(ip)
+			if addr == nil {
+				return nil, newError("invalid IP address in static hosts: ", ip).AtWarning()
+			}
+			ips = append(ips, addr)
 		}
 		}
 		sh.ips[id] = ips
 		sh.ips[id] = ips
 	}
 	}
@@ -73,12 +77,11 @@ func NewStaticHosts(hosts []*Config_HostMapping, legacy map[string]*net.IPOrDoma
 	return sh, nil
 	return sh, nil
 }
 }
 
 
-func filterIP(ips []net.IP, option IPOption) []net.IP {
+func filterIP(ips []net.Address, option IPOption) []net.IP {
 	filtered := make([]net.IP, 0, len(ips))
 	filtered := make([]net.IP, 0, len(ips))
 	for _, ip := range ips {
 	for _, ip := range ips {
-		parsed := net.IPAddress(ip)
-		if (parsed.Family().IsIPv4() && option.IPv4Enable) || (parsed.Family().IsIPv6() && option.IPv6Enable) {
-			filtered = append(filtered, parsed.IP())
+		if (ip.Family().IsIPv4() && option.IPv4Enable) || (ip.Family().IsIPv6() && option.IPv6Enable) {
+			filtered = append(filtered, ip.IP())
 		}
 		}
 	}
 	}
 	if len(filtered) == 0 {
 	if len(filtered) == 0 {

+ 4 - 4
app/dns/udpns.go

@@ -20,7 +20,7 @@ import (
 )
 )
 
 
 type IPRecord struct {
 type IPRecord struct {
-	IP     net.IP
+	IP     net.Address
 	Expire time.Time
 	Expire time.Time
 }
 }
 
 
@@ -149,7 +149,7 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, payload *buf.Buf
 				break
 				break
 			}
 			}
 			ips = append(ips, IPRecord{
 			ips = append(ips, IPRecord{
-				IP:     net.IP(ans.A[:]),
+				IP:     net.IPAddress(ans.A[:]),
 				Expire: now.Add(time.Duration(ttl) * time.Second),
 				Expire: now.Add(time.Duration(ttl) * time.Second),
 			})
 			})
 		case dnsmessage.TypeAAAA:
 		case dnsmessage.TypeAAAA:
@@ -159,7 +159,7 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, payload *buf.Buf
 				break
 				break
 			}
 			}
 			ips = append(ips, IPRecord{
 			ips = append(ips, IPRecord{
-				IP:     net.IP(ans.AAAA[:]),
+				IP:     net.IPAddress(ans.AAAA[:]),
 				Expire: now.Add(time.Duration(ttl) * time.Second),
 				Expire: now.Add(time.Duration(ttl) * time.Second),
 			})
 			})
 		default:
 		default:
@@ -323,7 +323,7 @@ func (s *ClassicNameServer) findIPsForDomain(domain string, option IPOption) []n
 	s.RUnlock()
 	s.RUnlock()
 
 
 	if found && len(records) > 0 {
 	if found && len(records) > 0 {
-		var ips []net.IP
+		var ips []net.Address
 		now := time.Now()
 		now := time.Now()
 		for _, rec := range records {
 		for _, rec := range records {
 			if rec.Expire.After(now) {
 			if rec.Expire.After(now) {