diff --git a/ip_range/ip_range.go b/ip_range/ip_range.go index 21e6a60..51018b4 100644 --- a/ip_range/ip_range.go +++ b/ip_range/ip_range.go @@ -2,10 +2,8 @@ package iprange import ( "fmt" - "net" + "net/netip" "strings" - - "github.com/charlienet/go-mixed/bytesconv" ) type IpRange struct { @@ -13,40 +11,32 @@ type IpRange struct { } type ipSegment interface { - Contains(net.IP) bool + Contains(netip.Addr) bool } type singleIp struct { - ip net.IP + ip netip.Addr } -func (i *singleIp) Contains(ip net.IP) bool { - return i.ip.Equal(ip) +func (i *singleIp) Contains(ip netip.Addr) bool { + return i.ip.Compare(ip) == 0 } -type cidrSegments struct { - cidr *net.IPNet +type prefixSegments struct { + prefix netip.Prefix } -func (i *cidrSegments) Contains(ip net.IP) bool { - return i.cidr.Contains(ip) +func (i *prefixSegments) Contains(ip netip.Addr) bool { + return i.prefix.Contains(ip) } type rangeSegment struct { - start rangeIP - end rangeIP + start netip.Addr + end netip.Addr } -type rangeIP struct { - Hight uint64 - Lower uint64 -} - -func (r *rangeSegment) Contains(ip net.IP) bool { - ih, _ := bytesconv.BigEndian.BytesToUInt64(ip[:8]) - i, _ := bytesconv.BigEndian.BytesToUInt64(ip[8:]) - - return ih >= r.start.Hight && ih <= r.end.Hight && i >= r.start.Lower && i <= r.end.Lower +func (r *rangeSegment) Contains(ip netip.Addr) bool { + return ip.Compare(r.start) >= 0 && ip.Compare(r.end) <= 0 } // IP范围判断,支持以下规则: @@ -55,6 +45,7 @@ func (r *rangeSegment) Contains(ip net.IP) bool { // 掩码模式,如 192.168.2.0/24 func NewRange(ip ...string) (*IpRange, error) { seg := make([]ipSegment, 0, len(ip)) + for _, i := range ip { if s, err := createSegment(i); err != nil { return nil, err @@ -67,13 +58,13 @@ func NewRange(ip ...string) (*IpRange, error) { } func (r *IpRange) Contains(ip string) bool { - nip := net.ParseIP(ip) - if nip == nil { + addr, err := netip.ParseAddr(ip) + if err != nil { return false } for _, v := range r.segments { - if v.Contains(nip) { + if v.Contains(addr) { return true } } @@ -89,34 +80,30 @@ func createSegment(ip string) (ipSegment, error) { return nil, fmt.Errorf("IP范围定义错误:%s", ip) } - start := net.ParseIP(ips[0]) - end := net.ParseIP(ips[1]) - if start == nil { - return nil, fmt.Errorf("IP范围起始地址格式错误:%s", ips[0]) + start, err := netip.ParseAddr(ips[0]) + if err != nil { + return nil, err } - if end == nil { - return nil, fmt.Errorf("IP范围结束地址格式错误:%s", ips[0]) + end, err := netip.ParseAddr(ips[1]) + if err != nil { + return nil, err } - sh, _ := bytesconv.BigEndian.BytesToUInt64(start[:8]) - s, _ := bytesconv.BigEndian.BytesToUInt64(start[8:]) - eh, _ := bytesconv.BigEndian.BytesToUInt64(end[:8]) - e, _ := bytesconv.BigEndian.BytesToUInt64(end[8:]) - - return &rangeSegment{start: rangeIP{ - Hight: sh, Lower: s}, - end: rangeIP{Hight: eh, Lower: e}}, nil + return &rangeSegment{ + start: start, + end: end, + }, nil case strings.Contains(ip, "/"): - if _, cidr, err := net.ParseCIDR(ip); err != nil { + if prefix, err := netip.ParsePrefix(ip); err != nil { return nil, err } else { - return &cidrSegments{cidr: cidr}, nil + return &prefixSegments{prefix: prefix}, nil } default: - i := net.ParseIP(ip) - if i == nil { + i, err := netip.ParseAddr(ip) + if err != nil { return nil, fmt.Errorf("格式错误, 不是有效的IP地址:%s", ip) } diff --git a/ip_range/ip_range_test.go b/ip_range/ip_range_test.go index 34bc535..e4e325c 100644 --- a/ip_range/ip_range_test.go +++ b/ip_range/ip_range_test.go @@ -1,6 +1,7 @@ package iprange import ( + "net/netip" "testing" "github.com/stretchr/testify/assert" @@ -29,7 +30,7 @@ func TestSingleIp(t *testing.T) { assert.False(t, r.Contains("192.168.0.123")) } -func TestCIDR(t *testing.T) { +func TestPrefix(t *testing.T) { r, err := NewRange("192.168.2.0/24") if err != nil { t.Fatal(err) @@ -40,6 +41,18 @@ func TestCIDR(t *testing.T) { assert.False(t, r.Contains("192.168.3.162")) } +func TestPrefix2(t *testing.T) { + r, err := NewRange("192.168.15.0/21") + if err != nil { + t.Fatal(err) + } + + assert.True(t, r.Contains("192.168.8.10")) + assert.True(t, r.Contains("192.168.14.162")) + assert.False(t, r.Contains("192.168.3.162")) + assert.False(t, r.Contains("192.168.2.162")) +} + func TestRange(t *testing.T) { r, err := NewRange("192.168.2.20-192.168.2.30") if err != nil { @@ -52,5 +65,29 @@ func TestRange(t *testing.T) { assert.False(t, r.Contains("192.168.2.10")) assert.False(t, r.Contains("192.168.2.31")) - +} + +func TestLocalhost(t *testing.T) { + r, err := NewRange("::1") + if err != nil { + t.Fatal(err) + } + + assert.True(t, r.Contains("::1")) +} + +func TestNetIP(t *testing.T) { + addr, err := netip.ParseAddr("192.168.2.10") + if err != nil { + t.Fatal(err) + } + + t.Log(netip.MustParseAddr("192.168.2.4").Compare(addr)) + t.Log(netip.MustParseAddr("192.168.2.10").Compare(addr)) + t.Log(netip.MustParseAddr("192.168.2.11").Compare(addr)) + + prefix := netip.MustParsePrefix("192.168.2.0/24") + + t.Log(prefix.Contains(netip.MustParseAddr("192.168.2.53"))) + t.Log(prefix.Contains(netip.MustParseAddr("192.168.3.53"))) }