Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions control.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package nebula

import (
"context"
"iter"
"net/netip"
"os"
"os/signal"
Expand Down Expand Up @@ -120,6 +121,15 @@ func (c *Control) ListHostmapHosts(pendingMap bool) []ControlHostInfo {
}
}

// ListHostmapHostsIter returns an iter with details about the actual or pending (handshaking) hostmap by vpn ip
func (c *Control) ListHostmapHostsIter(pendingMap bool) iter.Seq[*ControlHostInfo] {
if pendingMap {
return listHostMapHostsIter(c.f.handshakeManager)
} else {
return listHostMapHostsIter(c.f.hostMap)
}
}

// ListHostmapIndexes returns details about the actual or pending (handshaking) hostmap by local index id
func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
if pendingMap {
Expand All @@ -129,6 +139,15 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
}
}

// ListHostmapIndexesIter returns an iter with details about the actual or pending (handshaking) hostmap by local index id
func (c *Control) ListHostmapIndexesIter(pendingMap bool) iter.Seq[*ControlHostInfo] {
if pendingMap {
return listHostMapIndexesIter(c.f.handshakeManager)
} else {
return listHostMapIndexesIter(c.f.hostMap)
}
}

// GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
_, found := c.f.myVpnAddrsTable.Lookup(vpnIp)
Expand Down Expand Up @@ -306,6 +325,19 @@ func listHostMapHosts(hl controlHostLister) []ControlHostInfo {
return hosts
}

func listHostMapHostsIter(hl controlHostLister) iter.Seq[*ControlHostInfo] {
pr := hl.GetPreferredRanges()

return iter.Seq[*ControlHostInfo](func(yield func(*ControlHostInfo) bool) {
hl.ForEachVpnAddr(func(hostinfo *HostInfo) {
host := copyHostInfo(hostinfo, pr)
if !yield(&host) {
return // Stop iteration early if yield returns false
}
})
})
}

func listHostMapIndexes(hl controlHostLister) []ControlHostInfo {
hosts := make([]ControlHostInfo, 0)
pr := hl.GetPreferredRanges()
Expand All @@ -314,3 +346,16 @@ func listHostMapIndexes(hl controlHostLister) []ControlHostInfo {
})
return hosts
}

func listHostMapIndexesIter(hl controlHostLister) iter.Seq[*ControlHostInfo] {
pr := hl.GetPreferredRanges()

return iter.Seq[*ControlHostInfo](func(yield func(*ControlHostInfo) bool) {
hl.ForEachIndex(func(hostinfo *HostInfo) {
host := copyHostInfo(hostinfo, pr)
if !yield(&host) {
return // Stop iteration early if yield returns false
}
})
})
}
97 changes: 97 additions & 0 deletions control_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"net"
"net/netip"
"reflect"
"sort"
"testing"

"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -110,6 +111,102 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
})
}

func TestListHostMapHostsIter(t *testing.T) {
l := logrus.New()
hm := newHostMap(l)
hm.preferredRanges.Store(&[]netip.Prefix{})

hosts := []struct {
vpnIp netip.Addr
remoteAddr netip.AddrPort
localIndexId uint32
remoteIndexId uint32
}{
{vpnIp: netip.MustParseAddr("0.0.0.2"), remoteAddr: netip.MustParseAddrPort("0.0.0.101:4445"), localIndexId: 202, remoteIndexId: 201},
{vpnIp: netip.MustParseAddr("0.0.0.3"), remoteAddr: netip.MustParseAddrPort("0.0.0.102:4446"), localIndexId: 203, remoteIndexId: 202},
{vpnIp: netip.MustParseAddr("0.0.0.4"), remoteAddr: netip.MustParseAddrPort("0.0.0.103:4447"), localIndexId: 204, remoteIndexId: 203},
}

for _, h := range hosts {
hm.unlockedAddHostInfo(&HostInfo{
remote: h.remoteAddr,
ConnectionState: &ConnectionState{
peerCert: nil,
},
localIndexId: h.localIndexId,
remoteIndexId: h.remoteIndexId,
vpnAddrs: []netip.Addr{h.vpnIp},
}, &Interface{})
}

iter := listHostMapHostsIter(hm)
var results []ControlHostInfo

for h := range iter {
results = append(results, *h)
}

sort.Slice(results, func(i, j int) bool {
return results[i].VpnAddrs[0].Less(results[j].VpnAddrs[0])
})

assert.Equal(t, len(hosts), len(results), "expected number of hosts in iterator")
for i, h := range hosts {
assert.Equal(t, h.vpnIp, results[i].VpnAddrs[0])
assert.Equal(t, h.localIndexId, results[i].LocalIndex)
assert.Equal(t, h.remoteIndexId, results[i].RemoteIndex)
assert.Equal(t, h.remoteAddr, results[i].CurrentRemote)
}
}

func TestListHostMapIndexesIter(t *testing.T) {
l := logrus.New()
hm := newHostMap(l)
hm.preferredRanges.Store(&[]netip.Prefix{})

hosts := []struct {
vpnIp netip.Addr
remoteAddr netip.AddrPort
localIndexId uint32
remoteIndexId uint32
}{
{vpnIp: netip.MustParseAddr("0.0.0.2"), remoteAddr: netip.MustParseAddrPort("0.0.0.101:4445"), localIndexId: 202, remoteIndexId: 201},
{vpnIp: netip.MustParseAddr("0.0.0.3"), remoteAddr: netip.MustParseAddrPort("0.0.0.102:4446"), localIndexId: 203, remoteIndexId: 202},
{vpnIp: netip.MustParseAddr("0.0.0.4"), remoteAddr: netip.MustParseAddrPort("0.0.0.103:4447"), localIndexId: 204, remoteIndexId: 203},
}

for _, h := range hosts {
hm.unlockedAddHostInfo(&HostInfo{
remote: h.remoteAddr,
ConnectionState: &ConnectionState{
peerCert: nil,
},
localIndexId: h.localIndexId,
remoteIndexId: h.remoteIndexId,
vpnAddrs: []netip.Addr{h.vpnIp},
}, &Interface{})
}

iter := listHostMapIndexesIter(hm)
var results []ControlHostInfo

for h := range iter {
results = append(results, *h)
}

sort.Slice(results, func(i, j int) bool {
return results[i].VpnAddrs[0].Less(results[j].VpnAddrs[0])
})

assert.Equal(t, len(hosts), len(results), "expected number of hosts in iterator")
for i, h := range hosts {
assert.Equal(t, h.vpnIp, results[i].VpnAddrs[0])
assert.Equal(t, h.localIndexId, results[i].LocalIndex)
assert.Equal(t, h.remoteIndexId, results[i].RemoteIndex)
assert.Equal(t, h.remoteAddr, results[i].CurrentRemote)
}
}

func assertFields(t *testing.T, expected []string, actualStruct interface{}) {
val := reflect.ValueOf(actualStruct).Elem()
fields := make([]string, val.NumField())
Expand Down