ipnet.go 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. package net
  2. import (
  3. "math/bits"
  4. "net"
  5. )
  // IPNet is a set of IPv4 networks that can be queried for membership.
  // The map below is not guarded by a lock, so concurrent use that includes
  // writes needs external synchronization.
  type IPNet struct {
  	// cache maps an IPv4 address, folded into a uint32 by ipToUint32,
  	// to the prefix length in bits that was recorded for it.
  	cache map[uint32]byte
  }
  9. func NewIPNet() *IPNet {
  10. return &IPNet{
  11. cache: make(map[uint32]byte, 1024),
  12. }
  13. }
  14. func ipToUint32(ip net.IP) uint32 {
  15. value := uint32(0)
  16. for _, b := range []byte(ip) {
  17. value <<= 8
  18. value += uint32(b)
  19. }
  20. return value
  21. }
  // ipMaskToByte returns the number of set bits in mask, truncated to a byte.
  // For a canonical (contiguous) IPv4 mask this equals its prefix length.
  func ipMaskToByte(mask net.IPMask) byte {
  	total := 0
  	for i := range mask {
  		total += bits.OnesCount8(mask[i])
  	}
  	// Converting at the end wraps modulo 256, just like summing into a byte.
  	return byte(total)
  }
  // Add records the network ipNet in n.
  //
  // Only IPv4 networks are supported. Nothing is recorded for a nil ipNet, an
  // IPv6 network, or a network whose mask is not a contiguous IPv4 prefix
  // (after normalization). An IPv4 network may carry a 16-byte mask. As in
  // net.IPNet.Contains, only the last four bytes of such a mask are applied.
  func (n *IPNet) Add(ipNet *net.IPNet) {
  	if ipNet == nil {
  		return
  	}
  	ipv4 := ipNet.IP.To4()
  	if ipv4 == nil {
  		// For now, we don't support IPv6
  		return
  	}
  	mask := ipNet.Mask
  	if len(mask) == net.IPv6len {
  		// Without this, e.g. CIDRMask(120, 128) would count as prefix 120
  		// and the entry could never match an IPv4 address.
  		mask = mask[12:]
  	}
  	if _, size := mask.Size(); size != 8*net.IPv4len {
  		// Size reports 0 for non-contiguous or odd-length masks, which no
  		// prefix length can represent.
  		return
  	}
  	n.AddIP(ipv4, ipMaskToByte(mask))
  }
  // AddIP records the IPv4 network ip/mask, where mask is a prefix length in
  // bits (0-32).
  //
  // Host bits beyond the prefix are cleared, so the entry is keyed by its
  // network address. Contains looks entries up by that masked key. Nothing
  // is recorded if ip is not an IPv4 address (4-byte or IPv4-mapped 16-byte
  // form) or if mask exceeds 32. When the same network address is added with
  // several prefix lengths, only the shortest one is kept, because it covers
  // the longer ones. A zero-value IPNet is usable; its map is allocated on the
  // first insertion.
  func (n *IPNet) AddIP(ip []byte, mask byte) {
  	ipv4 := net.IP(ip).To4()
  	if ipv4 == nil || mask > 32 {
  		return
  	}
  	if n.cache == nil {
  		n.cache = make(map[uint32]byte)
  	}
  	// Shifting a uint32 by 32 yields 0 in Go, so mask == 0 clears every bit.
  	k := ipToUint32(ipv4) & (^uint32(0) << (32 - uint(mask)))
  	if existing, found := n.cache[k]; !found || existing > mask {
  		n.cache[k] = mask
  	}
  }
  // Contains reports whether ip lies inside any network recorded in n.
  // Addresses that are not IPv4 never match.
  //
  // For every prefix length p from 32 down to 0, ip is masked to p bits and
  // looked up. It matches when the entry stored under that key has a prefix
  // length of at most p: ip then shares at least the stored prefix with the
  // key. Probing p == 0 lets a 0.0.0.0/0 entry match every IPv4 address.
  func (n *IPNet) Contains(ip net.IP) bool {
  	ipv4 := ip.To4()
  	if ipv4 == nil {
  		return false
  	}
  	addr := ipToUint32(ipv4)
  	// Start at /32 so an exact host entry is found with a single lookup.
  	for prefix := 32; prefix >= 0; prefix-- {
  		// A shift count of 32 yields 0 for uint32, i.e. the /0 mask.
  		mask := ^uint32(0) << uint(32-prefix)
  		if entry, found := n.cache[addr&mask]; found && int(entry) <= prefix {
  			return true
  		}
  	}
  	return false
  }
  68. func (n *IPNet) IsEmpty() bool {
  69. return len(n.cache) == 0
  70. }