Skip to content

Commit d6cd0f5

Browse files
committed
Move WMI related functions to cim package
1 parent 13f6143 commit d6cd0f5

File tree

2 files changed

+42
-30
lines changed

2 files changed

+42
-30
lines changed

pkg/cim/smb.go

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
package cim
55

66
import (
7+
"strings"
8+
79
"github.com/microsoft/wmi/pkg/base/query"
810
cim "github.com/microsoft/wmi/pkg/wmiinstance"
911
)
@@ -17,8 +19,24 @@ const (
1719
SmbMappingStatusConnecting
1820
SmbMappingStatusReconnecting
1921
SmbMappingStatusUnavailable
22+
23+
credentialDelimiter = ":"
2024
)
2125

26+
// escapeQueryParameter escapes a parameter for WMI Queries
27+
func escapeQueryParameter(s string) string {
28+
s = strings.ReplaceAll(s, "'", "''")
29+
s = strings.ReplaceAll(s, "\\", "\\\\")
30+
return s
31+
}
32+
33+
func escapeUserName(userName string) string {
34+
// refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L169-L170
35+
userName = strings.ReplaceAll(userName, "\\", "\\\\")
36+
userName = strings.ReplaceAll(userName, credentialDelimiter, "\\"+credentialDelimiter)
37+
return userName
38+
}
39+
2240
// QuerySmbGlobalMappingByRemotePath retrieves the SMB global mapping from its remote path.
2341
//
2442
// The equivalent WMI query is:
@@ -28,7 +46,7 @@ const (
2846
// Refer to https://pkg.go.dev/github.com/microsoft/wmi/server2019/root/microsoft/windows/smb#MSFT_SmbGlobalMapping
2947
// for the WMI class definition.
3048
func QuerySmbGlobalMappingByRemotePath(remotePath string) (*cim.WmiInstance, error) {
31-
smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", remotePath)
49+
smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", escapeQueryParameter(remotePath))
3250
instances, err := QueryInstances(WMINamespaceSmb, smbQuery)
3351
if err != nil {
3452
return nil, err
@@ -42,7 +60,7 @@ func QuerySmbGlobalMappingByRemotePath(remotePath string) (*cim.WmiInstance, err
4260
// Refer to https://pkg.go.dev/github.com/microsoft/wmi/server2019/root/microsoft/windows/smb#MSFT_SmbGlobalMapping
4361
// for the WMI class definition.
4462
func RemoveSmbGlobalMappingByRemotePath(remotePath string) error {
45-
smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", remotePath)
63+
smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", escapeQueryParameter(remotePath))
4664
instances, err := QueryInstances(WMINamespaceSmb, smbQuery)
4765
if err != nil {
4866
return err
@@ -51,3 +69,22 @@ func RemoveSmbGlobalMappingByRemotePath(remotePath string) error {
5169
_, err = instances[0].InvokeMethod("Remove", true)
5270
return err
5371
}
72+
73+
// NewSmbGlobalMapping creates a new SMB global mapping to the remote path.
74+
//
75+
// Refer to https://pkg.go.dev/github.com/microsoft/wmi/server2019/root/microsoft/windows/smb#MSFT_SmbGlobalMapping
76+
// for the WMI class definition.
77+
func NewSmbGlobalMapping(remotePath, username, password string, requirePrivacy bool) (int, error) {
78+
params := map[string]interface{}{
79+
"RemotePath": remotePath,
80+
"RequirePrivacy": requirePrivacy,
81+
}
82+
if username != "" {
83+
// refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L166-L178
84+
// on how SMB credential is handled in PowerShell
85+
params["Credential"] = escapeUserName(username) + credentialDelimiter + password
86+
}
87+
88+
result, _, err := InvokeCimMethod(WMINamespaceSmb, "MSFT_SmbGlobalMapping", "Create", params)
89+
return result, err
90+
}

pkg/os/smb/api.go

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,6 @@ import (
1010
"golang.org/x/sys/windows"
1111
)
1212

13-
const (
14-
credentialDelimiter = ":"
15-
)
16-
1713
type API interface {
1814
IsSmbMapped(remotePath string) (bool, error)
1915
NewSmbLink(remotePath, localPath string) error
@@ -33,17 +29,6 @@ func New(requirePrivacy bool) *SmbAPI {
3329
}
3430
}
3531

36-
func remotePathForQuery(remotePath string) string {
37-
return strings.ReplaceAll(remotePath, "\\", "\\\\")
38-
}
39-
40-
func escapeUserName(userName string) string {
41-
// refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L169-L170
42-
escaped := strings.ReplaceAll(userName, "\\", "\\\\")
43-
escaped = strings.ReplaceAll(escaped, credentialDelimiter, "\\"+credentialDelimiter)
44-
return escaped
45-
}
46-
4732
func createSymlink(link, target string, isDir bool) error {
4833
linkPtr, err := syscall.UTF16PtrFromString(link)
4934
if err != nil {
@@ -68,7 +53,7 @@ func createSymlink(link, target string, isDir bool) error {
6853
}
6954

7055
func (*SmbAPI) IsSmbMapped(remotePath string) (bool, error) {
71-
inst, err := cim.QuerySmbGlobalMappingByRemotePath(remotePathForQuery(remotePath))
56+
inst, err := cim.QuerySmbGlobalMappingByRemotePath(remotePath)
7257
if err != nil {
7358
return false, cim.IgnoreNotFound(err)
7459
}
@@ -106,17 +91,7 @@ func (*SmbAPI) NewSmbLink(remotePath, localPath string) error {
10691
}
10792

10893
func (api *SmbAPI) NewSmbGlobalMapping(remotePath, username, password string) error {
109-
params := map[string]interface{}{
110-
"RemotePath": remotePath,
111-
"RequirePrivacy": api.RequirePrivacy,
112-
}
113-
if username != "" {
114-
// refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L166-L178
115-
// on how SMB credential is handled in PowerShell
116-
params["Credential"] = escapeUserName(username) + credentialDelimiter + password
117-
}
118-
119-
result, _, err := cim.InvokeCimMethod(cim.WMINamespaceSmb, "MSFT_SmbGlobalMapping", "Create", params)
94+
result, err := cim.NewSmbGlobalMapping(remotePath, username, password, api.RequirePrivacy)
12095
if err != nil {
12196
return fmt.Errorf("NewSmbGlobalMapping failed. result: %d, err: %v", result, err)
12297
}
@@ -125,7 +100,7 @@ func (api *SmbAPI) NewSmbGlobalMapping(remotePath, username, password string) er
125100
}
126101

127102
func (*SmbAPI) RemoveSmbGlobalMapping(remotePath string) error {
128-
err := cim.RemoveSmbGlobalMappingByRemotePath(remotePathForQuery(remotePath))
103+
err := cim.RemoveSmbGlobalMappingByRemotePath(remotePath)
129104
if err != nil {
130105
return fmt.Errorf("error remove smb mapping '%s'. err: %v", remotePath, err)
131106
}

0 commit comments

Comments
 (0)