ipnet.go 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. package net
  2. import (
  3. "net"
  4. )
  // onesCount maps each prefix-shaped mask byte (0x80, 0xC0, ..., 0xFF) to its
  // number of leading one bits. It starts empty and is filled in by init; any
  // byte not present (0x00, or a non-contiguous byte such as 0x0F) reads as 0.
  var onesCount = map[byte]byte{}
  // IPNet holds a set of IPv4 networks and answers membership queries through
  // Contains. Create it with NewIPNet, which allocates the cache.
  type IPNet struct {
  	// cache maps an IPv4 address packed into a uint32 (first byte most
  	// significant, see ipToUint32) to the prefix length recorded for it.
  	cache map[uint32]byte
  }
  11. func NewIPNet() *IPNet {
  12. return &IPNet{
  13. cache: make(map[uint32]byte, 1024),
  14. }
  15. }
  // ipToUint32 folds the bytes of ip into a uint32, most significant byte
  // first. Earlier bytes are shifted out, so only the last four survive: a
  // 16-byte IPv4-in-IPv6 address yields the same value as its 4-byte form,
  // and an empty ip yields 0.
  func ipToUint32(ip net.IP) uint32 {
  	var acc uint32
  	for i := 0; i < len(ip); i++ {
  		acc = acc<<8 | uint32(ip[i])
  	}
  	return acc
  }
  24. func ipMaskToByte(mask net.IPMask) byte {
  25. value := byte(0)
  26. for _, b := range []byte(mask) {
  27. value += onesCount[b]
  28. }
  29. return value
  30. }
  31. func (v *IPNet) Add(ipNet *net.IPNet) {
  32. ipv4 := ipNet.IP.To4()
  33. if ipv4 == nil {
  34. // For now, we don't support IPv6
  35. return
  36. }
  37. mask := ipMaskToByte(ipNet.Mask)
  38. v.AddIP(ipv4, mask)
  39. }
  40. func (v *IPNet) AddIP(ip []byte, mask byte) {
  41. k := ipToUint32(ip)
  42. existing, found := v.cache[k]
  43. if !found || existing > mask {
  44. v.cache[k] = mask
  45. }
  46. }
  47. func (v *IPNet) Contains(ip net.IP) bool {
  48. ipv4 := ip.To4()
  49. if ipv4 == nil {
  50. return false
  51. }
  52. originalValue := ipToUint32(ipv4)
  53. if entry, found := v.cache[originalValue]; found {
  54. if entry == 32 {
  55. return true
  56. }
  57. }
  58. mask := uint32(0)
  59. for maskbit := byte(1); maskbit <= 32; maskbit++ {
  60. mask += 1 << uint32(32-maskbit)
  61. maskedValue := originalValue & mask
  62. if entry, found := v.cache[maskedValue]; found {
  63. if entry == maskbit {
  64. return true
  65. }
  66. }
  67. }
  68. return false
  69. }
  70. func (v *IPNet) IsEmpty() bool {
  71. return len(v.cache) == 0
  72. }
  73. func init() {
  74. value := byte(0)
  75. for mask := byte(1); mask <= 8; mask++ {
  76. value += 1 << byte(8-mask)
  77. onesCount[value] = mask
  78. }
  79. }