ipnet.go 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. package net
  2. import (
  3. "math/bits"
  4. "net"
  5. )
  6. type IPNetTable struct {
  7. cache map[uint32]byte
  8. }
  9. func NewIPNetTable() *IPNetTable {
  10. return &IPNetTable{
  11. cache: make(map[uint32]byte, 1024),
  12. }
  13. }
  14. func ipToUint32(ip IP) uint32 {
  15. value := uint32(0)
  16. for _, b := range []byte(ip) {
  17. value <<= 8
  18. value += uint32(b)
  19. }
  20. return value
  21. }
  22. func ipMaskToByte(mask net.IPMask) byte {
  23. value := byte(0)
  24. for _, b := range []byte(mask) {
  25. value += byte(bits.OnesCount8(b))
  26. }
  27. return value
  28. }
  29. func (n *IPNetTable) Add(ipNet *net.IPNet) {
  30. ipv4 := ipNet.IP.To4()
  31. if ipv4 == nil {
  32. // For now, we don't support IPv6
  33. return
  34. }
  35. mask := ipMaskToByte(ipNet.Mask)
  36. n.AddIP(ipv4, mask)
  37. }
  38. func (n *IPNetTable) AddIP(ip []byte, mask byte) {
  39. k := ipToUint32(ip)
  40. k = (k >> (32 - mask)) << (32 - mask) // normalize ip
  41. existing, found := n.cache[k]
  42. if !found || existing > mask {
  43. n.cache[k] = mask
  44. }
  45. }
  46. func (n *IPNetTable) Contains(ip net.IP) bool {
  47. ipv4 := ip.To4()
  48. if ipv4 == nil {
  49. return false
  50. }
  51. originalValue := ipToUint32(ipv4)
  52. if entry, found := n.cache[originalValue]; found {
  53. if entry == 32 {
  54. return true
  55. }
  56. }
  57. mask := uint32(0)
  58. for maskbit := byte(1); maskbit <= 32; maskbit++ {
  59. mask += 1 << uint32(32-maskbit)
  60. maskedValue := originalValue & mask
  61. if entry, found := n.cache[maskedValue]; found {
  62. if entry == maskbit {
  63. return true
  64. }
  65. }
  66. }
  67. return false
  68. }
  69. func (n *IPNetTable) IsEmpty() bool {
  70. return len(n.cache) == 0
  71. }