Skip to content

Commit 24f2ecd

Browse files
authored
Merge pull request #387 from laozc/wmi-process-smb-api-create
feat: Use WMI to create SMB Global Mapping to reduce PowerShell overhead
2 parents a9bd679 + 238c7fb commit 24f2ecd

File tree

5 files changed

+140
-29
lines changed

5 files changed

+140
-29
lines changed

pkg/cim/smb.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//go:build windows
2+
// +build windows
3+
4+
package cim
5+
6+
import (
7+
"github.com/microsoft/wmi/pkg/base/query"
8+
cim "github.com/microsoft/wmi/pkg/wmiinstance"
9+
)
10+
11+
// Refer to https://learn.microsoft.com/en-us/previous-versions/windows/desktop/smb/msft-smbmapping
12+
const (
13+
SmbMappingStatusOK int32 = iota
14+
SmbMappingStatusPaused
15+
SmbMappingStatusDisconnected
16+
SmbMappingStatusNetworkError
17+
SmbMappingStatusConnecting
18+
SmbMappingStatusReconnecting
19+
SmbMappingStatusUnavailable
20+
)
21+
22+
// QuerySmbGlobalMappingByRemotePath retrieves the SMB global mapping from its remote path.
23+
//
24+
// The equivalent WMI query is:
25+
//
26+
// SELECT [selectors] FROM MSFT_SmbGlobalMapping
27+
//
28+
// Refer to https://pkg.go.dev/github.com/microsoft/wmi/server2019/root/microsoft/windows/smb#MSFT_SmbGlobalMapping
29+
// for the WMI class definition.
30+
func QuerySmbGlobalMappingByRemotePath(remotePath string) (*cim.WmiInstance, error) {
31+
smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", remotePath)
32+
instances, err := QueryInstances(WMINamespaceSmb, smbQuery)
33+
if err != nil {
34+
return nil, err
35+
}
36+
37+
return instances[0], err
38+
}
39+
40+
// RemoveSmbGlobalMappingByRemotePath removes a SMB global mapping matching to the remote path.
41+
//
42+
// Refer to https://pkg.go.dev/github.com/microsoft/wmi/server2019/root/microsoft/windows/smb#MSFT_SmbGlobalMapping
43+
// for the WMI class definition.
44+
func RemoveSmbGlobalMappingByRemotePath(remotePath string) error {
45+
smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", remotePath)
46+
instances, err := QueryInstances(WMINamespaceSmb, smbQuery)
47+
if err != nil {
48+
return err
49+
}
50+
51+
_, err = instances[0].InvokeMethod("Remove", true)
52+
return err
53+
}

pkg/cim/wmi.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
const (
1919
WMINamespaceRoot = "Root\\CimV2"
2020
WMINamespaceStorage = "Root\\Microsoft\\Windows\\Storage"
21+
WMINamespaceSmb = "Root\\Microsoft\\Windows\\Smb"
2122
)
2223

2324
type InstanceHandler func(instance *cim.WmiInstance) (bool, error)

pkg/os/smb/api.go

Lines changed: 70 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,15 @@ package smb
33
import (
44
"fmt"
55
"strings"
6+
"syscall"
67

8+
"github.com/kubernetes-csi/csi-proxy/pkg/cim"
79
"github.com/kubernetes-csi/csi-proxy/pkg/utils"
10+
"golang.org/x/sys/windows"
11+
)
12+
13+
const (
14+
credentialDelimiter = ":"
815
)
916

1017
type API interface {
@@ -26,18 +33,52 @@ func New(requirePrivacy bool) *SmbAPI {
2633
}
2734
}
2835

36+
func remotePathForQuery(remotePath string) string {
37+
return strings.ReplaceAll(remotePath, "\\", "\\\\")
38+
}
39+
40+
func escapeUserName(userName string) string {
41+
// refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L169-L170
42+
escaped := strings.ReplaceAll(userName, "\\", "\\\\")
43+
escaped = strings.ReplaceAll(escaped, credentialDelimiter, "\\"+credentialDelimiter)
44+
return escaped
45+
}
46+
47+
func createSymlink(link, target string, isDir bool) error {
48+
linkPtr, err := syscall.UTF16PtrFromString(link)
49+
if err != nil {
50+
return err
51+
}
52+
targetPtr, err := syscall.UTF16PtrFromString(target)
53+
if err != nil {
54+
return err
55+
}
56+
57+
var flags uint32
58+
if isDir {
59+
flags = windows.SYMBOLIC_LINK_FLAG_DIRECTORY
60+
}
61+
62+
err = windows.CreateSymbolicLink(
63+
linkPtr,
64+
targetPtr,
65+
flags,
66+
)
67+
return err
68+
}
69+
2970
func (*SmbAPI) IsSmbMapped(remotePath string) (bool, error) {
30-
cmdLine := `$(Get-SmbGlobalMapping -RemotePath $Env:smbremotepath -ErrorAction Stop).Status `
31-
cmdEnv := fmt.Sprintf("smbremotepath=%s", remotePath)
32-
out, err := utils.RunPowershellCmd(cmdLine, cmdEnv)
71+
inst, err := cim.QuerySmbGlobalMappingByRemotePath(remotePathForQuery(remotePath))
3372
if err != nil {
34-
return false, fmt.Errorf("error checking smb mapping. cmd %s, output: %s, err: %v", remotePath, string(out), err)
73+
return false, cim.IgnoreNotFound(err)
3574
}
3675

37-
if len(out) == 0 || !strings.EqualFold(strings.TrimSpace(string(out)), "OK") {
38-
return false, nil
76+
status, err := inst.GetProperty("Status")
77+
if err != nil {
78+
return false, err
3979
}
40-
return true, nil
80+
81+
return status.(int32) == cim.SmbMappingStatusOK, nil
4182
}
4283

4384
// NewSmbLink - creates a directory symbolic link to the remote share.
@@ -48,42 +89,46 @@ func (*SmbAPI) IsSmbMapped(remotePath string) (bool, error) {
4889
// alpha to merge the paths.
4990
// TODO (for beta release): Merge the link paths - os.Symlink and Powershell link path.
5091
func (*SmbAPI) NewSmbLink(remotePath, localPath string) error {
51-
5292
if !strings.HasSuffix(remotePath, "\\") {
5393
// Golang has issues resolving paths mapped to file shares if they do not end in a trailing \
5494
// so add one if needed.
5595
remotePath = remotePath + "\\"
5696
}
97+
longRemotePath := utils.EnsureLongPath(remotePath)
98+
longLocalPath := utils.EnsureLongPath(localPath)
5799

58-
cmdLine := `New-Item -ItemType SymbolicLink $Env:smblocalPath -Target $Env:smbremotepath`
59-
output, err := utils.RunPowershellCmd(cmdLine, fmt.Sprintf("smbremotepath=%s", remotePath), fmt.Sprintf("smblocalpath=%s", localPath))
100+
err := createSymlink(longLocalPath, longRemotePath, true)
60101
if err != nil {
61-
return fmt.Errorf("error linking %s to %s. output: %s, err: %v", remotePath, localPath, string(output), err)
102+
return fmt.Errorf("error linking %s to %s. err: %v", remotePath, localPath, err)
62103
}
63104

64105
return nil
65106
}
66107

67108
func (api *SmbAPI) NewSmbGlobalMapping(remotePath, username, password string) error {
68-
// use PowerShell Environment Variables to store user input string to prevent command line injection
69-
// https://docs.microsoft.com/en-us/powershell/module/microsoft.powershell.core/about/about_environment_variables?view=powershell-5.1
70-
cmdLine := fmt.Sprintf(`$PWord = ConvertTo-SecureString -String $Env:smbpassword -AsPlainText -Force`+
71-
`;$Credential = New-Object -TypeName System.Management.Automation.PSCredential -ArgumentList $Env:smbuser, $PWord`+
72-
`;New-SmbGlobalMapping -RemotePath $Env:smbremotepath -Credential $Credential -RequirePrivacy $%t`, api.RequirePrivacy)
73-
74-
if output, err := utils.RunPowershellCmd(cmdLine,
75-
fmt.Sprintf("smbuser=%s", username),
76-
fmt.Sprintf("smbpassword=%s", password),
77-
fmt.Sprintf("smbremotepath=%s", remotePath)); err != nil {
78-
return fmt.Errorf("NewSmbGlobalMapping failed. output: %q, err: %v", string(output), err)
109+
params := map[string]interface{}{
110+
"RemotePath": remotePath,
111+
"RequirePrivacy": api.RequirePrivacy,
112+
}
113+
if username != "" {
114+
// refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L166-L178
115+
// on how SMB credential is handled in PowerShell
116+
params["Credential"] = escapeUserName(username) + credentialDelimiter + password
79117
}
118+
119+
result, _, err := cim.InvokeCimMethod(cim.WMINamespaceSmb, "MSFT_SmbGlobalMapping", "Create", params)
120+
if err != nil {
121+
return fmt.Errorf("NewSmbGlobalMapping failed. result: %d, err: %v", result, err)
122+
}
123+
80124
return nil
81125
}
82126

83127
func (*SmbAPI) RemoveSmbGlobalMapping(remotePath string) error {
84-
cmd := `Remove-SmbGlobalMapping -RemotePath $Env:smbremotepath -Force`
85-
if output, err := utils.RunPowershellCmd(cmd, fmt.Sprintf("smbremotepath=%s", remotePath)); err != nil {
86-
return fmt.Errorf("UnmountSmbShare failed. output: %q, err: %v", string(output), err)
128+
err := cim.RemoveSmbGlobalMappingByRemotePath(remotePathForQuery(remotePath))
129+
if err != nil {
130+
return fmt.Errorf("error remove smb mapping '%s'. err: %v", remotePath, err)
87131
}
132+
88133
return nil
89134
}

pkg/os/volume/api.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010

1111
"github.com/go-ole/go-ole"
1212
"github.com/kubernetes-csi/csi-proxy/pkg/cim"
13+
"github.com/kubernetes-csi/csi-proxy/pkg/utils"
1314
wmierrors "github.com/microsoft/wmi/pkg/errors"
1415
"github.com/pkg/errors"
1516
"golang.org/x/sys/windows"
@@ -57,8 +58,6 @@ var (
5758
// PS C:\disks> (Get-Disk -Number 1 | Get-Partition | Get-Volume).UniqueId
5859
// \\?\Volume{452e318a-5cde-421e-9831-b9853c521012}\
5960
VolumeRegexp = regexp.MustCompile(`Volume\{[\w-]*\}`)
60-
// longPathPrefix is the prefix of Windows long path
61-
longPathPrefix = "\\\\?\\"
6261

6362
notMountedFolder = errors.New("not a mounted folder")
6463
)
@@ -337,7 +336,7 @@ func getTarget(mount string) (string, error) {
337336
if err != nil {
338337
return "", err
339338
}
340-
targetPath := longPathPrefix + windows.UTF16PtrToString(&outPathBuffer[0])
339+
targetPath := utils.EnsureLongPath(windows.UTF16PtrToString(&outPathBuffer[0]))
341340
if !strings.HasSuffix(targetPath, "\\") {
342341
targetPath += "\\"
343342
}

pkg/utils/utils.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,24 @@ package utils
33
import (
44
"os"
55
"os/exec"
6+
"strings"
67

78
"k8s.io/klog/v2"
89
)
910

10-
const MaxPathLengthWindows = 260
11+
const (
12+
MaxPathLengthWindows = 260
13+
14+
// LongPathPrefix is the prefix of Windows long path
15+
LongPathPrefix = `\\?\`
16+
)
17+
18+
func EnsureLongPath(path string) string {
19+
if !strings.HasPrefix(path, LongPathPrefix) {
20+
path = LongPathPrefix + path
21+
}
22+
return path
23+
}
1124

1225
func RunPowershellCmd(command string, envs ...string) ([]byte, error) {
1326
cmd := exec.Command("powershell", "-Mta", "-NoProfile", "-Command", command)

0 commit comments

Comments
 (0)