Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions client/internal/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -453,8 +453,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return fmt.Errorf("up wg interface: %w", err)
}



// if inbound conns are blocked there is no need to create the ACL manager
if e.firewall != nil && !e.config.BlockInbound {
e.acl = acl.NewDefaultManager(e.firewall)
Expand Down Expand Up @@ -760,7 +758,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
}

nm := update.GetNetworkMap()
if nm == nil {
if nm == nil || update.SkipNetworkMapUpdate {
return nil
}

Expand Down
79 changes: 79 additions & 0 deletions client/internal/engine_sync_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package internal

import (
"context"
"testing"

"golang.zx2c4.com/wireguard/wgctrl/wgtypes"

"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/shared/management/client"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
)

// Ensures handleSync exits early when SkipNetworkMapUpdate is true
func TestEngine_HandleSync_SkipNetworkMapUpdate(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
}

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

engine := NewEngine(ctx, cancel, nil, &client.MockClient{}, nil, &EngineConfig{
WgIfaceName: "utun199",
WgAddr: "100.70.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx

// Precondition
if engine.networkSerial != 0 {
t.Fatalf("unexpected initial serial: %d", engine.networkSerial)
}

resp := &mgmtProto.SyncResponse{
NetworkMap: &mgmtProto.NetworkMap{Serial: 42},
SkipNetworkMapUpdate: true,
}

if err := engine.handleSync(resp); err != nil {
t.Fatalf("handleSync returned error: %v", err)
}

if engine.networkSerial != 0 {
t.Fatalf("networkSerial changed despite SkipNetworkMapUpdate; got %d, want 0", engine.networkSerial)
}
}

// Ensures handleSync exits early when NetworkMap is nil
func TestEngine_HandleSync_NilNetworkMap(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
}

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

engine := NewEngine(ctx, cancel, nil, &client.MockClient{}, nil, &EngineConfig{
WgIfaceName: "utun198",
WgAddr: "100.70.0.2/24",
WgPrivateKey: key,
WgPort: 33101,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx

resp := &mgmtProto.SyncResponse{NetworkMap: nil}

if err := engine.handleSync(resp); err != nil {
t.Fatalf("handleSync returned error: %v", err)
}
}


15 changes: 15 additions & 0 deletions client/system/info.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,21 @@ func (i *Info) SetFlags(
i.LazyConnectionEnabled = lazyConnectionEnabled
}

func (i *Info) CopyFlagsFrom(other *Info) {
i.SetFlags(
other.RosenpassEnabled,
other.RosenpassPermissive,
&other.ServerSSHAllowed,
other.DisableClientRoutes,
other.DisableServerRoutes,
other.DisableDNS,
other.DisableFirewall,
other.BlockLANAccess,
other.BlockInbound,
other.LazyConnectionEnabled,
)
}

// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
func extractUserAgent(ctx context.Context) string {
md, hasMeta := metadata.FromOutgoingContext(ctx)
Expand Down
67 changes: 63 additions & 4 deletions client/system/info_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,72 @@
package system

import (
"context"
"testing"
"context"
"testing"

"github.com/stretchr/testify/assert"
"google.golang.org/grpc/metadata"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/metadata"
)

func TestInfo_CopyFlagsFrom(t *testing.T) {
origin := &Info{}
serverSSHAllowed := true
origin.SetFlags(
true, // RosenpassEnabled
false, // RosenpassPermissive
&serverSSHAllowed,
true, // DisableClientRoutes
false, // DisableServerRoutes
true, // DisableDNS
false, // DisableFirewall
true, // BlockLANAccess
false, // BlockInbound
true, // LazyConnectionEnabled
)

got := &Info{}
got.CopyFlagsFrom(origin)

if got.RosenpassEnabled != true {
t.Fatalf("RosenpassEnabled not copied: got %v", got.RosenpassEnabled)
}
if got.RosenpassPermissive != false {
t.Fatalf("RosenpassPermissive not copied: got %v", got.RosenpassPermissive)
}
if got.ServerSSHAllowed != true {
t.Fatalf("ServerSSHAllowed not copied: got %v", got.ServerSSHAllowed)
}
if got.DisableClientRoutes != true {
t.Fatalf("DisableClientRoutes not copied: got %v", got.DisableClientRoutes)
}
if got.DisableServerRoutes != false {
t.Fatalf("DisableServerRoutes not copied: got %v", got.DisableServerRoutes)
}
if got.DisableDNS != true {
t.Fatalf("DisableDNS not copied: got %v", got.DisableDNS)
}
if got.DisableFirewall != false {
t.Fatalf("DisableFirewall not copied: got %v", got.DisableFirewall)
}
if got.BlockLANAccess != true {
t.Fatalf("BlockLANAccess not copied: got %v", got.BlockLANAccess)
}
if got.BlockInbound != false {
t.Fatalf("BlockInbound not copied: got %v", got.BlockInbound)
}
if got.LazyConnectionEnabled != true {
t.Fatalf("LazyConnectionEnabled not copied: got %v", got.LazyConnectionEnabled)
}

// ensure CopyFlagsFrom does not touch unrelated fields
origin.Hostname = "host-a"
got.Hostname = "host-b"
got.CopyFlagsFrom(origin)
if got.Hostname != "host-b" {
t.Fatalf("CopyFlagsFrom should not overwrite non-flag fields, got Hostname=%q", got.Hostname)
}
}

func Test_LocalWTVersion(t *testing.T) {
got := GetInfo(context.TODO())
want := "development"
Expand Down
37 changes: 36 additions & 1 deletion shared/management/client/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ type GrpcClient struct {
conn *grpc.ClientConn
connStateCallback ConnStateNotifier
connStateCallbackLock sync.RWMutex
// lastNetworkMapSerial stores last seen network map serial to optimize sync
lastNetworkMapSerial uint64
lastNetworkMapSerialMu sync.Mutex
}

// NewClient creates a new client to Management service
Expand Down Expand Up @@ -216,11 +219,23 @@ func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, err
return nil, fmt.Errorf("invalid msg, required network map")
}

// update last seen serial
c.setLastNetworkMapSerial(decryptedResp.GetNetworkMap().GetSerial())

return decryptedResp.GetNetworkMap(), nil
}

func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info) (proto.ManagementService_SyncClient, error) {
req := &proto.SyncRequest{Meta: infoToMetaData(sysInfo)}
// Always compute latest system info to ensure up-to-date PeerSystemMeta on first and subsequent syncs
recomputed := system.GetInfo(c.ctx)
if sysInfo != nil {
recomputed.CopyFlagsFrom(sysInfo)
// carry over posture files if any were computed
if len(sysInfo.Files) > 0 {
recomputed.Files = sysInfo.Files
}
}
req := &proto.SyncRequest{Meta: infoToMetaData(recomputed), NetworkMapSerial: c.getLastNetworkMapSerial()}

myPrivateKey := c.key
myPublicKey := myPrivateKey.PublicKey()
Expand Down Expand Up @@ -258,6 +273,11 @@ func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, se
return err
}

// track latest network map serial if present
if decryptedResp.GetNetworkMap() != nil {
c.setLastNetworkMapSerial(decryptedResp.GetNetworkMap().GetSerial())
}

if err := msgHandler(decryptedResp); err != nil {
log.Errorf("failed handling an update message received from Management Service: %v", err.Error())
}
Expand Down Expand Up @@ -582,3 +602,18 @@ func infoToMetaData(info *system.Info) *proto.PeerSystemMeta {
},
}
}

// setLastNetworkMapSerial updates the cached last seen network map serial in a 32-bit safe manner
func (c *GrpcClient) setLastNetworkMapSerial(serial uint64) {
c.lastNetworkMapSerialMu.Lock()
c.lastNetworkMapSerial = serial
c.lastNetworkMapSerialMu.Unlock()
}

// getLastNetworkMapSerial returns the cached last seen network map serial in a 32-bit safe manner
func (c *GrpcClient) getLastNetworkMapSerial() uint64 {
c.lastNetworkMapSerialMu.Lock()
v := c.lastNetworkMapSerial
c.lastNetworkMapSerialMu.Unlock()
return v
}
26 changes: 26 additions & 0 deletions shared/management/client/grpc_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package client

import (
"testing"
)

func TestGrpcClient_LastNetworkMapSerial_SetGet(t *testing.T) {
c := &GrpcClient{}

if got := c.getLastNetworkMapSerial(); got != 0 {
t.Fatalf("initial serial should be 0, got %d", got)
}

c.setLastNetworkMapSerial(123)
if got := c.getLastNetworkMapSerial(); got != 123 {
t.Fatalf("serial after set should be 123, got %d", got)
}

// overwrite should work
c.setLastNetworkMapSerial(5)
if got := c.getLastNetworkMapSerial(); got != 5 {
t.Fatalf("serial after overwrite should be 5, got %d", got)
}
}


Loading
Loading