diff --git a/pkg/controllers/netpol/network_policy_controller.go b/pkg/controllers/netpol/network_policy_controller.go index e7ab15a6b..57f44206e 100644 --- a/pkg/controllers/netpol/network_policy_controller.go +++ b/pkg/controllers/netpol/network_policy_controller.go @@ -684,17 +684,28 @@ func (npc *NetworkPolicyController) cleanupStaleIPSets(activePolicyIPSets map[st }() } - for _, ipsets := range npc.ipSetHandlers { + for ipFamily, ipsets := range npc.ipSetHandlers { cleanupPolicyIPSets := make([]*utils.Set, 0) if err := ipsets.Save(); err != nil { klog.Fatalf("failed to initialize ipsets command executor due to %s", err.Error()) } - for _, set := range ipsets.Sets() { - if strings.HasPrefix(set.Name, kubeSourceIPSetPrefix) || - strings.HasPrefix(set.Name, kubeDestinationIPSetPrefix) { - if _, ok := activePolicyIPSets[set.Name]; !ok { - cleanupPolicyIPSets = append(cleanupPolicyIPSets, set) + if ipFamily == v1core.IPv6Protocol { + for _, set := range ipsets.Sets() { + if strings.HasPrefix(set.Name, fmt.Sprintf("%s:%s", utils.FamillyInet6, kubeSourceIPSetPrefix)) || + strings.HasPrefix(set.Name, fmt.Sprintf("%s:%s", utils.FamillyInet6, kubeDestinationIPSetPrefix)) { + if _, ok := activePolicyIPSets[set.Name]; !ok { + cleanupPolicyIPSets = append(cleanupPolicyIPSets, set) + } + } + } + } else { + for _, set := range ipsets.Sets() { + if strings.HasPrefix(set.Name, kubeSourceIPSetPrefix) || + strings.HasPrefix(set.Name, kubeDestinationIPSetPrefix) { + if _, ok := activePolicyIPSets[set.Name]; !ok { + cleanupPolicyIPSets = append(cleanupPolicyIPSets, set) + } } } } diff --git a/pkg/controllers/netpol/policy.go b/pkg/controllers/netpol/policy.go index 737979165..b6237bf66 100644 --- a/pkg/controllers/netpol/policy.go +++ b/pkg/controllers/netpol/policy.go @@ -474,10 +474,10 @@ func (npc *NetworkPolicyController) appendRuleToPolicyChain(policyChainName, com args = append(args, "-m", "comment", "--comment", "\""+comment+"\"") } if srcIPSetName != "" { - args = append(args, "-m", "set", "--match-set", srcIPSetName, "src") + args = append(args, "-m", "set", "--match-set", ipSetName(srcIPSetName, ipFamily), "src") } if dstIPSetName != "" { - args = append(args, "-m", "set", "--match-set", dstIPSetName, "dst") + args = append(args, "-m", "set", "--match-set", ipSetName(dstIPSetName, ipFamily), "dst") } if protocol != "" { args = append(args, "-p", protocol) @@ -904,13 +904,13 @@ func networkPolicyChainName(namespace, policyName string, version string, ipFami func policySourcePodIPSetName(namespace, policyName string, ipFamily api.IPFamily) string { hash := sha256.Sum256([]byte(namespace + policyName + string(ipFamily))) encoded := base32.StdEncoding.EncodeToString(hash[:]) - return kubeSourceIPSetPrefix + encoded[:16] + return ipSetName(kubeSourceIPSetPrefix+encoded[:16], ipFamily) } func policyDestinationPodIPSetName(namespace, policyName string, ipFamily api.IPFamily) string { hash := sha256.Sum256([]byte(namespace + policyName + string(ipFamily))) encoded := base32.StdEncoding.EncodeToString(hash[:]) - return kubeDestinationIPSetPrefix + encoded[:16] + return ipSetName(kubeDestinationIPSetPrefix+encoded[:16], ipFamily) } func policyIndexedSourcePodIPSetName( @@ -918,7 +918,7 @@ func policyIndexedSourcePodIPSetName( hash := sha256.Sum256([]byte(namespace + policyName + "ingressrule" + strconv.Itoa(ingressRuleNo) + string(ipFamily) + "pod")) encoded := base32.StdEncoding.EncodeToString(hash[:]) - return kubeSourceIPSetPrefix + encoded[:16] + return ipSetName(kubeSourceIPSetPrefix+encoded[:16], ipFamily) } func policyIndexedDestinationPodIPSetName( @@ -926,7 +926,7 @@ func policyIndexedDestinationPodIPSetName( hash := sha256.Sum256([]byte(namespace + policyName + "egressrule" + strconv.Itoa(egressRuleNo) + string(ipFamily) + "pod")) encoded := base32.StdEncoding.EncodeToString(hash[:]) - return kubeDestinationIPSetPrefix + encoded[:16] + return ipSetName(kubeDestinationIPSetPrefix+encoded[:16], ipFamily) } func policyIndexedSourceIPBlockIPSetName( @@ -934,7 +934,7 @@ func policyIndexedSourceIPBlockIPSetName( hash := sha256.Sum256([]byte(namespace + policyName + "ingressrule" + strconv.Itoa(ingressRuleNo) + string(ipFamily) + "ipblock")) encoded := base32.StdEncoding.EncodeToString(hash[:]) - return kubeSourceIPSetPrefix + encoded[:16] + return ipSetName(kubeSourceIPSetPrefix+encoded[:16], ipFamily) } func policyIndexedDestinationIPBlockIPSetName( @@ -942,7 +942,7 @@ func policyIndexedDestinationIPBlockIPSetName( hash := sha256.Sum256([]byte(namespace + policyName + "egressrule" + strconv.Itoa(egressRuleNo) + string(ipFamily) + "ipblock")) encoded := base32.StdEncoding.EncodeToString(hash[:]) - return kubeDestinationIPSetPrefix + encoded[:16] + return ipSetName(kubeDestinationIPSetPrefix+encoded[:16], ipFamily) } func policyIndexedIngressNamedPortIPSetName( @@ -950,7 +950,7 @@ func policyIndexedIngressNamedPortIPSetName( hash := sha256.Sum256([]byte(namespace + policyName + "ingressrule" + strconv.Itoa(ingressRuleNo) + strconv.Itoa(namedPortNo) + string(ipFamily) + "namedport")) encoded := base32.StdEncoding.EncodeToString(hash[:]) - return kubeDestinationIPSetPrefix + encoded[:16] + return ipSetName(kubeDestinationIPSetPrefix+encoded[:16], ipFamily) } func policyIndexedEgressNamedPortIPSetName( @@ -958,7 +958,7 @@ func policyIndexedEgressNamedPortIPSetName( hash := sha256.Sum256([]byte(namespace + policyName + "egressrule" + strconv.Itoa(egressRuleNo) + strconv.Itoa(namedPortNo) + string(ipFamily) + "namedport")) encoded := base32.StdEncoding.EncodeToString(hash[:]) - return kubeDestinationIPSetPrefix + encoded[:16] + return ipSetName(kubeDestinationIPSetPrefix+encoded[:16], ipFamily) } func policyRulePortsHasNamedPort(npPorts []networking.NetworkPolicyPort) bool { diff --git a/pkg/controllers/netpol/policy_test.go b/pkg/controllers/netpol/policy_test.go new file mode 100644 index 000000000..666eb40eb --- /dev/null +++ b/pkg/controllers/netpol/policy_test.go @@ -0,0 +1,87 @@ +package netpol + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + v1 "k8s.io/api/core/v1" +) + +func testNamePrefix(t *testing.T, testString string, isIPv6 bool) { + if isIPv6 { + assert.Truef(t, strings.HasPrefix(testString, "inet6:"), "%s is IPv6 and should begin with inet6:", testString) + } +} + +func Test_policySourcePodIPSetName(t *testing.T) { + t.Run("Check IPv4 and IPv6 names are correct", func(t *testing.T) { + setName := policySourcePodIPSetName("foo", "bar", v1.IPv4Protocol) + testNamePrefix(t, setName, false) + setName = policySourcePodIPSetName("foo", "bar", v1.IPv6Protocol) + testNamePrefix(t, setName, true) + }) +} + +func Test_policyDestinationPodIPSetName(t *testing.T) { + t.Run("Check IPv4 and IPv6 names are correct", func(t *testing.T) { + setName := policyDestinationPodIPSetName("foo", "bar", v1.IPv4Protocol) + testNamePrefix(t, setName, false) + setName = policyDestinationPodIPSetName("foo", "bar", v1.IPv6Protocol) + testNamePrefix(t, setName, true) + }) +} + +func Test_policyIndexedSourcePodIPSetName(t *testing.T) { + t.Run("Check IPv4 and IPv6 names are correct", func(t *testing.T) { + setName := policyIndexedSourcePodIPSetName("foo", "bar", 1, v1.IPv4Protocol) + testNamePrefix(t, setName, false) + setName = policyIndexedSourcePodIPSetName("foo", "bar", 1, v1.IPv6Protocol) + testNamePrefix(t, setName, true) + }) +} + +func Test_policyIndexedDestinationPodIPSetName(t *testing.T) { + t.Run("Check IPv4 and IPv6 names are correct", func(t *testing.T) { + setName := policyIndexedDestinationPodIPSetName("foo", "bar", 1, v1.IPv4Protocol) + testNamePrefix(t, setName, false) + setName = policyIndexedDestinationPodIPSetName("foo", "bar", 1, v1.IPv6Protocol) + testNamePrefix(t, setName, true) + }) +} + +func Test_policyIndexedSourceIPBlockIPSetName(t *testing.T) { + t.Run("Check IPv4 and IPv6 names are correct", func(t *testing.T) { + setName := policyIndexedSourceIPBlockIPSetName("foo", "bar", 1, v1.IPv4Protocol) + testNamePrefix(t, setName, false) + setName = policyIndexedSourceIPBlockIPSetName("foo", "bar", 1, v1.IPv6Protocol) + testNamePrefix(t, setName, true) + }) +} + +func Test_policyIndexedDestinationIPBlockIPSetName(t *testing.T) { + t.Run("Check IPv4 and IPv6 names are correct", func(t *testing.T) { + setName := policyIndexedDestinationIPBlockIPSetName("foo", "bar", 1, v1.IPv4Protocol) + testNamePrefix(t, setName, false) + setName = policyIndexedDestinationIPBlockIPSetName("foo", "bar", 1, v1.IPv6Protocol) + testNamePrefix(t, setName, true) + }) +} + +func Test_policyIndexedIngressNamedPortIPSetName(t *testing.T) { + t.Run("Check IPv4 and IPv6 names are correct", func(t *testing.T) { + setName := policyIndexedIngressNamedPortIPSetName("foo", "bar", 1, 1, v1.IPv4Protocol) + testNamePrefix(t, setName, false) + setName = policyIndexedIngressNamedPortIPSetName("foo", "bar", 1, 1, v1.IPv6Protocol) + testNamePrefix(t, setName, true) + }) +} + +func Test_policyIndexedEgressNamedPortIPSetName(t *testing.T) { + t.Run("Check IPv4 and IPv6 names are correct", func(t *testing.T) { + setName := policyIndexedEgressNamedPortIPSetName("foo", "bar", 1, 1, v1.IPv4Protocol) + testNamePrefix(t, setName, false) + setName = policyIndexedEgressNamedPortIPSetName("foo", "bar", 1, 1, v1.IPv6Protocol) + testNamePrefix(t, setName, true) + }) +} diff --git a/pkg/controllers/netpol/utils.go b/pkg/controllers/netpol/utils.go index 328ffae1a..ebdc39473 100644 --- a/pkg/controllers/netpol/utils.go +++ b/pkg/controllers/netpol/utils.go @@ -166,3 +166,11 @@ func getPodIPForFamily(pod podInfo, ipFamily api.IPFamily) (string, error) { return "", fmt.Errorf("did not recognize IP Family for pod: %s:%s family: %s", pod.namespace, pod.name, ipFamily) } + +func ipSetName(setName string, ipFamily api.IPFamily) string { + if ipFamily == api.IPv4Protocol { + return utils.IPSetName(setName, false) + } else { + return utils.IPSetName(setName, true) + } +} diff --git a/pkg/utils/ipset.go b/pkg/utils/ipset.go index 9cb69c8c2..0b8ee33f9 100644 --- a/pkg/utils/ipset.go +++ b/pkg/utils/ipset.go @@ -427,13 +427,18 @@ func (set *Set) IsActive() (bool, error) { return true, nil } -func (ipset *IPSet) Name(setName string) string { - if ipset.isIpv6 && !strings.HasPrefix(setName, IPv6SetPrefix+":") { +// IPSetName returns the proper set name for this component based upon whether or not it is an IPv6 set +func IPSetName(setName string, isIPv6 bool) string { + if isIPv6 && !strings.HasPrefix(setName, IPv6SetPrefix+":") { return fmt.Sprintf("%s:%s", IPv6SetPrefix, setName) } return setName } +func (ipset *IPSet) Name(setName string) string { + return IPSetName(setName, ipset.isIpv6) +} + func (set *Set) name() string { return set.Parent.Name(set.Name) }