Skip to content

Commit a05b58c

Browse files
committed
Ensure COM threading apartment in SMB APIs
1 parent 4dcae10 commit a05b58c

File tree

1 file changed

+30
-23
lines changed

1 file changed

+30
-23
lines changed

pkg/os/smb/api.go

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,22 @@ func New(requirePrivacy bool) *SmbAPI {
2828
}
2929

3030
func (*SmbAPI) IsSmbMapped(remotePath string) (bool, error) {
31-
inst, err := cim.QuerySmbGlobalMappingByRemotePath(remotePath)
32-
if err != nil {
33-
return false, cim.IgnoreNotFound(err)
34-
}
35-
36-
status, err := cim.GetSmbGlobalMappingStatus(inst)
37-
if err != nil {
38-
return false, err
39-
}
40-
41-
return status == cim.SmbMappingStatusOK, nil
31+
var isMapped bool
32+
err := cim.WithCOMThread(func() error {
33+
inst, err := cim.QuerySmbGlobalMappingByRemotePath(remotePath)
34+
if err != nil {
35+
return err
36+
}
37+
38+
status, err := cim.GetSmbGlobalMappingStatus(inst)
39+
if err != nil {
40+
return err
41+
}
42+
43+
isMapped = status == cim.SmbMappingStatusOK
44+
return nil
45+
})
46+
return isMapped, cim.IgnoreNotFound(err)
4247
}
4348

4449
// NewSmbLink - creates a directory symbolic link to the remote share.
@@ -62,19 +67,21 @@ func (*SmbAPI) NewSmbLink(remotePath, localPath string) error {
6267
}
6368

6469
func (api *SmbAPI) NewSmbGlobalMapping(remotePath, username, password string) error {
65-
result, err := cim.NewSmbGlobalMapping(remotePath, username, password, api.RequirePrivacy)
66-
if err != nil {
67-
return fmt.Errorf("NewSmbGlobalMapping failed. result: %d, err: %v", result, err)
68-
}
69-
70-
return nil
70+
return cim.WithCOMThread(func() error {
71+
result, err := cim.NewSmbGlobalMapping(remotePath, username, password, api.RequirePrivacy)
72+
if err != nil {
73+
return fmt.Errorf("NewSmbGlobalMapping failed. result: %d, err: %v", result, err)
74+
}
75+
return nil
76+
})
7177
}
7278

7379
func (*SmbAPI) RemoveSmbGlobalMapping(remotePath string) error {
74-
err := cim.RemoveSmbGlobalMappingByRemotePath(remotePath)
75-
if err != nil {
76-
return fmt.Errorf("error remove smb mapping '%s'. err: %v", remotePath, err)
77-
}
78-
79-
return nil
80+
return cim.WithCOMThread(func() error {
81+
err := cim.RemoveSmbGlobalMappingByRemotePath(remotePath)
82+
if err != nil {
83+
return fmt.Errorf("error remove smb mapping '%s'. err: %v", remotePath, err)
84+
}
85+
return nil
86+
})
8087
}

0 commit comments

Comments
 (0)