diff --git a/docs/IMPLEMENTATION.md b/docs/IMPLEMENTATION.md
index 3326f2f2..3468e3c4 100644
--- a/docs/IMPLEMENTATION.md
+++ b/docs/IMPLEMENTATION.md
@@ -121,6 +121,20 @@ func CallMethod(disp *ole.IDispatch, name string, params ...interface{}) (result
}
```
+### Association
+
+Association can be used to retrieve all instances that are associated with
+a particular source instance.
+
+There are a few Association classes in WMI.
+
+For example, association class [MSFT_PartitionToVolume](https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partitiontovolume)
+can be used to retrieve a volume (`MSFT_Volume`) from a partition (`MSFT_Partition`), and vice versa.
+
+```go
+collection, err := part.GetAssociated("MSFT_PartitionToVolume", "MSFT_Volume", "Volume", "Partition")
+```
+
## Debug with PowerShell
@@ -181,6 +195,13 @@ PS C:\Users\Administrator> $vol.FileSystem
NTFS
```
+### Association
+
+```powershell
+PS C:\Users\Administrator> $partition = (Get-CimInstance -Namespace root\Microsoft\Windows\Storage -ClassName MSFT_Partition -Filter "DiskNumber = 0")[0]
+PS C:\Users\Administrator> Get-CimAssociatedInstance -InputObject $partition -Association MSFT_PartitionToVolume
+```
+
### Call Class Method
You may get Class Methods for a single CIM class using `$class.CimClassMethods`.
diff --git a/integrationtests/iscsi_ps_scripts.go b/integrationtests/iscsi_ps_scripts.go
index 89bec253..202e390b 100644
--- a/integrationtests/iscsi_ps_scripts.go
+++ b/integrationtests/iscsi_ps_scripts.go
@@ -42,14 +42,14 @@ $ProgressPreference = "SilentlyContinue"
$targetName = "%s"
# Get local IPv4 (e.g. 10.30.1.15, not 127.0.0.1)
-$address = $(Get-NetIPAddress | Where-Object { $_.InterfaceAlias -eq "Ethernet" -and $_.AddressFamily -eq "IPv4" }).IPAddress
+$address = $(Get-NetIPAddress | Where-Object { $_.InterfaceAlias -eq "%s" -and $_.AddressFamily -eq "IPv4" }).IPAddress
# Create virtual disk in RAM
-New-IscsiVirtualDisk -Path "ramdisk:scratch-${targetName}.vhdx" -Size 100MB | Out-Null
+New-IscsiVirtualDisk -Path "ramdisk:scratch-${targetName}.vhdx" -Size 100MB -ComputerName $env:computername | Out-Null
# Create a target that allows all initiator IQNs and map a disk to the new target
-$target = New-IscsiServerTarget -TargetName $targetName -InitiatorIds @("Iqn:*")
-Add-IscsiVirtualDiskTargetMapping -TargetName $targetName -DevicePath "ramdisk:scratch-${targetName}.vhdx" | Out-Null
+$target = New-IscsiServerTarget -TargetName $targetName -InitiatorIds @("Iqn:*") -ComputerName $env:computername
+Add-IscsiVirtualDiskTargetMapping -TargetName $targetName -DevicePath "ramdisk:scratch-${targetName}.vhdx" -ComputerName $env:computername | Out-Null
$output = @{
"iqn" = "$($target.TargetIqn)"
@@ -68,7 +68,7 @@ $username = "%s"
$password = "%s"
$securestring = ConvertTo-SecureString -String $password -AsPlainText -Force
$chap = New-Object -TypeName System.Management.Automation.PSCredential -ArgumentList ($username, $securestring)
-Set-IscsiServerTarget -TargetName $targetName -EnableChap $true -Chap $chap
+Set-IscsiServerTarget -TargetName $targetName -EnableChap $true -Chap $chap -ComputerName $env:computername
`
func setChap(targetName string, username string, password string) error {
@@ -92,7 +92,7 @@ $securestring = ConvertTo-SecureString -String $password -AsPlainText -Force
# Windows initiator does not uses the username for mutual authentication
$chap = New-Object -TypeName System.Management.Automation.PSCredential -ArgumentList ($username, $securestring)
-Set-IscsiServerTarget -TargetName $targetName -EnableReverseChap $true -ReverseChap $chap
+Set-IscsiServerTarget -TargetName $targetName -EnableReverseChap $true -ReverseChap $chap -ComputerName $env:computername
`
func setReverseChap(targetName string, password string) error {
@@ -131,8 +131,8 @@ Get-IscsiTarget | Disconnect-IscsiTarget -Confirm:$false
Get-IscsiTargetPortal | Remove-IscsiTargetPortal -confirm:$false
# Clean target
-Get-IscsiServerTarget | Remove-IscsiServerTarget
-Get-IscsiVirtualDisk | Remove-IscsiVirtualDisk
+Get-IscsiServerTarget -ComputerName $env:computername | Remove-IscsiServerTarget
+Get-IscsiVirtualDisk -ComputerName $env:computername | Remove-IscsiVirtualDisk
# Stop iSCSI initiator
Get-Service "MsiSCSI" | Stop-Service
@@ -173,7 +173,12 @@ func runPowershellScript(script string) (string, error) {
}
func setupEnv(targetName string) (*IscsiSetupConfig, error) {
- script := fmt.Sprintf(IscsiEnvironmentSetupScript, targetName)
+ ethernetName := "Ethernet"
+ if val, ok := os.LookupEnv("ETHERNET_NAME"); ok {
+ ethernetName = val
+ }
+
+ script := fmt.Sprintf(IscsiEnvironmentSetupScript, targetName, ethernetName)
out, err := runPowershellScript(script)
if err != nil {
return nil, fmt.Errorf("failed setting up environment. err=%v", err)
diff --git a/pkg/cim/disk.go b/pkg/cim/disk.go
index 0b03c2ac..58c8f376 100644
--- a/pkg/cim/disk.go
+++ b/pkg/cim/disk.go
@@ -23,6 +23,17 @@ const (
// GPTPartitionTypeMicrosoftReserved is the GUID for Microsoft Reserved Partition (MSR)
// Reserved by Windows for system use
GPTPartitionTypeMicrosoftReserved = "{e3c9e316-0b5c-4db8-817d-f92df00215ae}"
+
+ // ErrorCodeCreatePartitionAccessPathAlreadyInUse is the error code (42002) returned when the driver letter failed to assign after partition created
+ ErrorCodeCreatePartitionAccessPathAlreadyInUse = 42002
+)
+
+var (
+ DiskSelectorListForDiskNumberAndLocation = []string{"Number", "Location"}
+ DiskSelectorListForPartitionStyle = []string{"PartitionStyle"}
+ DiskSelectorListForPathAndSerialNumber = []string{"Path", "SerialNumber"}
+ DiskSelectorListForIsOffline = []string{"IsOffline"}
+ DiskSelectorListForSize = []string{"Size"}
)
// QueryDiskByNumber retrieves disk information for a specific disk identified by its number.
@@ -76,3 +87,104 @@ func ListDisks(selectorList []string) ([]*storage.MSFT_Disk, error) {
return disks, nil
}
+
+// InitializeDisk initializes a RAW disk with a particular partition style.
+//
+// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/initialize-msft-disk
+// for the WMI method definition.
+func InitializeDisk(disk *storage.MSFT_Disk, partitionStyle int) (int, error) {
+ result, err := disk.InvokeMethodWithReturn("Initialize", int32(partitionStyle))
+ return int(result), err
+}
+
+// RefreshDisk Refreshes the cached disk layout information.
+//
+// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-disk-refresh
+// for the WMI method definition.
+func RefreshDisk(disk *storage.MSFT_Disk) (int, string, error) {
+ var status string
+ result, err := disk.InvokeMethodWithReturn("Refresh", &status)
+ return int(result), status, err
+}
+
+// CreatePartition creates a partition on a disk.
+//
+// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/createpartition-msft-disk
+// for the WMI method definition.
+func CreatePartition(disk *storage.MSFT_Disk, params ...interface{}) (int, error) {
+ result, err := disk.InvokeMethodWithReturn("CreatePartition", params...)
+ return int(result), err
+}
+
+// SetDiskState takes a disk online or offline.
+//
+// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-disk-online and
+// https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-disk-offline
+// for the WMI method definition.
+func SetDiskState(disk *storage.MSFT_Disk, online bool) (int, string, error) {
+ method := "Offline"
+ if online {
+ method = "Online"
+ }
+
+ var status string
+ result, err := disk.InvokeMethodWithReturn(method, &status)
+ return int(result), status, err
+}
+
+// RescanDisks rescans all changes by updating the internal cache of software objects (that is, Disks, Partitions, Volumes)
+// for the storage setting.
+//
+// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-storagesetting-updatehoststoragecache
+// for the WMI method definition.
+func RescanDisks() (int, error) {
+ result, _, err := InvokeCimMethod(WMINamespaceStorage, "MSFT_StorageSetting", "UpdateHostStorageCache", nil)
+ return result, err
+}
+
+// GetDiskNumber returns the number of a disk.
+func GetDiskNumber(disk *storage.MSFT_Disk) (uint32, error) {
+ number, err := disk.GetProperty("Number")
+ if err != nil {
+ return 0, err
+ }
+ return uint32(number.(int32)), err
+}
+
+// GetDiskLocation returns the location of a disk.
+func GetDiskLocation(disk *storage.MSFT_Disk) (string, error) {
+ return disk.GetPropertyLocation()
+}
+
+// GetDiskPartitionStyle returns the partition style of a disk.
+func GetDiskPartitionStyle(disk *storage.MSFT_Disk) (int32, error) {
+ retValue, err := disk.GetProperty("PartitionStyle")
+ if err != nil {
+ return 0, err
+ }
+ return retValue.(int32), err
+}
+
+// IsDiskOffline returns whether a disk is offline.
+func IsDiskOffline(disk *storage.MSFT_Disk) (bool, error) {
+ return disk.GetPropertyIsOffline()
+}
+
+// GetDiskSize returns the size of a disk.
+func GetDiskSize(disk *storage.MSFT_Disk) (int64, error) {
+ sz, err := disk.GetProperty("Size")
+ if err != nil {
+ return -1, err
+ }
+ return strconv.ParseInt(sz.(string), 10, 64)
+}
+
+// GetDiskPath returns the path of a disk.
+func GetDiskPath(disk *storage.MSFT_Disk) (string, error) {
+ return disk.GetPropertyPath()
+}
+
+// GetDiskSerialNumber returns the serial number of a disk.
+func GetDiskSerialNumber(disk *storage.MSFT_Disk) (string, error) {
+ return disk.GetPropertySerialNumber()
+}
diff --git a/pkg/cim/iscsi.go b/pkg/cim/iscsi.go
new file mode 100644
index 00000000..3a4e2541
--- /dev/null
+++ b/pkg/cim/iscsi.go
@@ -0,0 +1,362 @@
+//go:build windows
+// +build windows
+
+package cim
+
+import (
+ "fmt"
+ "strconv"
+
+ "github.com/microsoft/wmi/pkg/base/query"
+ "github.com/microsoft/wmi/server2019/root/microsoft/windows/storage"
+)
+
+var (
+ ISCSITargetPortalDefaultSelectorList = []string{"TargetPortalAddress", "TargetPortalPortNumber"}
+)
+
+// ListISCSITargetPortals retrieves a list of iSCSI target portals.
+//
+// The equivalent WMI query is:
+//
+// SELECT [selectors] FROM MSFT_IscsiTargetPortal
+//
+// Refer to https://learn.microsoft.com/en-us/previous-versions/windows/desktop/iscsidisc/msft-iscsitargetportal
+// for the WMI class definition.
+func ListISCSITargetPortals(selectorList []string) ([]*storage.MSFT_iSCSITargetPortal, error) {
+ q := query.NewWmiQueryWithSelectList("MSFT_IscsiTargetPortal", selectorList)
+ instances, err := QueryInstances(WMINamespaceStorage, q)
+ if IgnoreNotFound(err) != nil {
+ return nil, err
+ }
+
+ var targetPortals []*storage.MSFT_iSCSITargetPortal
+ for _, instance := range instances {
+ portal, err := storage.NewMSFT_iSCSITargetPortalEx1(instance)
+ if err != nil {
+ return nil, fmt.Errorf("failed to query iSCSI target portal %v. error: %v", instance, err)
+ }
+
+ targetPortals = append(targetPortals, portal)
+ }
+
+ return targetPortals, nil
+}
+
+// QueryISCSITargetPortal retrieves information about a specific iSCSI target portal
+// identified by its network address and port number.
+//
+// The equivalent WMI query is:
+//
+// SELECT [selectors] FROM MSFT_IscsiTargetPortal
+// WHERE TargetPortalAddress = '
'
+// AND TargetPortalPortNumber = ''
+//
+// Refer to https://learn.microsoft.com/en-us/previous-versions/windows/desktop/iscsidisc/msft-iscsitargetportal
+// for the WMI class definition.
+func QueryISCSITargetPortal(address string, port uint32, selectorList []string) (*storage.MSFT_iSCSITargetPortal, error) {
+ portalQuery := query.NewWmiQueryWithSelectList(
+ "MSFT_iSCSITargetPortal", selectorList,
+ "TargetPortalAddress", address,
+ "TargetPortalPortNumber", strconv.Itoa(int(port)))
+ instances, err := QueryInstances(WMINamespaceStorage, portalQuery)
+ if err != nil {
+ return nil, err
+ }
+
+ targetPortal, err := storage.NewMSFT_iSCSITargetPortalEx1(instances[0])
+ if err != nil {
+ return nil, fmt.Errorf("failed to query iSCSI target portal at (%s:%d). error: %v", address, port, err)
+ }
+
+ return targetPortal, nil
+}
+
+// ListISCSITargetsByTargetPortalAddressAndPort retrieves ISCSI targets by address and port of an iSCSI target portal.
+func ListISCSITargetsByTargetPortalAddressAndPort(address string, port uint32, selectorList []string) ([]*storage.MSFT_iSCSITarget, error) {
+ instance, err := QueryISCSITargetPortal(address, port, selectorList)
+ if err != nil {
+ return nil, err
+ }
+
+ targets, err := ListISCSITargetsByTargetPortal([]*storage.MSFT_iSCSITargetPortal{instance})
+ return targets, err
+}
+
+// NewISCSITargetPortal creates a new iSCSI target portal.
+//
+// Refer to https://learn.microsoft.com/en-us/previous-versions/windows/desktop/iscsidisc/msft-iscsitargetportal-new
+// for the WMI method definition.
+func NewISCSITargetPortal(targetPortalAddress string,
+ targetPortalPortNumber uint32,
+ initiatorInstanceName *string,
+ initiatorPortalAddress *string,
+ isHeaderDigest *bool,
+ isDataDigest *bool) (*storage.MSFT_iSCSITargetPortal, error) {
+ params := map[string]interface{}{
+ "TargetPortalAddress": targetPortalAddress,
+ "TargetPortalPortNumber": targetPortalPortNumber,
+ }
+ if initiatorInstanceName != nil {
+ params["InitiatorInstanceName"] = *initiatorInstanceName
+ }
+ if initiatorPortalAddress != nil {
+ params["InitiatorPortalAddress"] = *initiatorPortalAddress
+ }
+ if isHeaderDigest != nil {
+ params["IsHeaderDigest"] = *isHeaderDigest
+ }
+ if isDataDigest != nil {
+ params["IsDataDigest"] = *isDataDigest
+ }
+ result, _, err := InvokeCimMethod(WMINamespaceStorage, "MSFT_iSCSITargetPortal", "New", params)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create iSCSI target portal with %v. result: %d, error: %v", params, result, err)
+ }
+
+ return QueryISCSITargetPortal(targetPortalAddress, targetPortalPortNumber, nil)
+}
+
+// ParseISCSITargetPortal retrieves the portal address and port number of an iSCSI target portal.
+func ParseISCSITargetPortal(instance *storage.MSFT_iSCSITargetPortal) (string, uint32, error) {
+ portalAddress, err := instance.GetPropertyTargetPortalAddress()
+ if err != nil {
+ return "", 0, fmt.Errorf("failed parsing target portal address %v. err: %w", instance, err)
+ }
+
+ portalPort, err := instance.GetProperty("TargetPortalPortNumber")
+ if err != nil {
+ return "", 0, fmt.Errorf("failed parsing target portal port number %v. err: %w", instance, err)
+ }
+
+ return portalAddress, uint32(portalPort.(int32)), nil
+}
+
+// RemoveISCSITargetPortal removes an iSCSI target portal.
+//
+// Refer to https://learn.microsoft.com/en-us/previous-versions/windows/desktop/iscsidisc/msft-iscsitargetportal-remove
+// for the WMI method definition.
+func RemoveISCSITargetPortal(instance *storage.MSFT_iSCSITargetPortal) (int, error) {
+ address, port, err := ParseISCSITargetPortal(instance)
+ if err != nil {
+ return 0, fmt.Errorf("failed to parse target portal %v. error: %v", instance, err)
+ }
+
+ result, err := instance.InvokeMethodWithReturn("Remove",
+ nil,
+ nil,
+ int(port),
+ address,
+ )
+ return int(result), err
+}
+
+// ListISCSITargetsByTargetPortal retrieves all iSCSI targets from the specified iSCSI target portal
+// using MSFT_iSCSITargetToiSCSITargetPortal association.
+//
+// WMI association MSFT_iSCSITargetToiSCSITargetPortal:
+//
+// iSCSITarget | iSCSITargetPortal
+// ----------- | -----------------
+// MSFT_iSCSITarget (NodeAddress = "iqn.1991-05.com.microsoft:win-8e2evaq9q...) | MSFT_iSCSITargetPortal (TargetPortalAdd...
+//
+// Refer to https://learn.microsoft.com/en-us/previous-versions/windows/desktop/iscsidisc/msft-iscsitarget
+// for the WMI class definition.
+func ListISCSITargetsByTargetPortal(portals []*storage.MSFT_iSCSITargetPortal) ([]*storage.MSFT_iSCSITarget, error) {
+ var targets []*storage.MSFT_iSCSITarget
+ for _, portal := range portals {
+ collection, err := portal.GetAssociated("MSFT_iSCSITargetToiSCSITargetPortal", "MSFT_iSCSITarget", "iSCSITarget", "iSCSITargetPortal")
+ if err != nil {
+ return nil, fmt.Errorf("failed to query associated iSCSITarget for %v. error: %v", portal, err)
+ }
+
+ for _, instance := range collection {
+ target, err := storage.NewMSFT_iSCSITargetEx1(instance)
+ if err != nil {
+ return nil, fmt.Errorf("failed to query iSCSI target %v. error: %v", instance, err)
+ }
+
+ targets = append(targets, target)
+ }
+ }
+
+ return targets, nil
+}
+
+// QueryISCSITarget retrieves the iSCSI target from the specified portal address, portal and node address.
+func QueryISCSITarget(address string, port uint32, nodeAddress string) (*storage.MSFT_iSCSITarget, error) {
+ portal, err := QueryISCSITargetPortal(address, port, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ targets, err := ListISCSITargetsByTargetPortal([]*storage.MSFT_iSCSITargetPortal{portal})
+ if err != nil {
+ return nil, err
+ }
+
+ for _, target := range targets {
+ targetNodeAddress, err := GetISCSITargetNodeAddress(target)
+ if err != nil {
+ return nil, fmt.Errorf("failed to query iSCSI target %v. error: %v", target, err)
+ }
+
+ if targetNodeAddress == nodeAddress {
+ return target, nil
+ }
+ }
+
+ return nil, nil
+}
+
+// GetISCSITargetNodeAddress returns the node address of an iSCSI target.
+func GetISCSITargetNodeAddress(target *storage.MSFT_iSCSITarget) (string, error) {
+ nodeAddress, err := target.GetProperty("NodeAddress")
+ if err != nil {
+ return "", err
+ }
+
+ return nodeAddress.(string), err
+}
+
+// IsISCSITargetConnected returns whether the iSCSI target is connected.
+func IsISCSITargetConnected(target *storage.MSFT_iSCSITarget) (bool, error) {
+ return target.GetPropertyIsConnected()
+}
+
+// QueryISCSISessionByTarget retrieves the iSCSI session from the specified iSCSI target
+// using MSFT_iSCSITargetToiSCSISession association.
+//
+// WMI association MSFT_iSCSITargetToiSCSISession:
+//
+// iSCSISession | iSCSITarget
+// ------------ | -----------
+// MSFT_iSCSISession (SessionIdentifier = "ffffac0cacbff010-4000013700000016") | MSFT_iSCSITarget (NodeAddress = "iqn.199...
+//
+// Refer to https://learn.microsoft.com/en-us/previous-versions/windows/desktop/iscsidisc/msft-iscsisession
+// for the WMI class definition.
+func QueryISCSISessionByTarget(target *storage.MSFT_iSCSITarget) (*storage.MSFT_iSCSISession, error) {
+ collection, err := target.GetAssociated("MSFT_iSCSITargetToiSCSISession", "MSFT_iSCSISession", "iSCSISession", "iSCSITarget")
+ if err != nil {
+ return nil, fmt.Errorf("failed to query associated iSCSISession for %v. error: %v", target, err)
+ }
+
+ if len(collection) == 0 {
+ return nil, nil
+ }
+
+ session, err := storage.NewMSFT_iSCSISessionEx1(collection[0])
+ return session, err
+}
+
+// UnregisterISCSISession unregisters the iSCSI session so that it is no longer persistent.
+//
+// Refer https://learn.microsoft.com/en-us/previous-versions/windows/desktop/iscsidisc/msft-iscsisession-unregister
+// for the WMI method definition.
+func UnregisterISCSISession(session *storage.MSFT_iSCSISession) (int, error) {
+ result, err := session.InvokeMethodWithReturn("Unregister")
+ return int(result), err
+}
+
+// SetISCSISessionChapSecret sets a CHAP secret key for use with iSCSI initiator connections.
+//
+// Refer https://learn.microsoft.com/en-us/previous-versions/windows/desktop/iscsidisc/msft-iscsitarget-disconnect
+// for the WMI method definition.
+func SetISCSISessionChapSecret(mutualChapSecret string) (int, error) {
+ result, _, err := InvokeCimMethod(WMINamespaceStorage, "MSFT_iSCSISession", "SetCHAPSecret", map[string]interface{}{"ChapSecret": mutualChapSecret})
+ return result, err
+}
+
+// GetISCSISessionIdentifier returns the identifier of an iSCSI session.
+func GetISCSISessionIdentifier(session *storage.MSFT_iSCSISession) (string, error) {
+ return session.GetPropertySessionIdentifier()
+}
+
+// IsISCSISessionPersistent returns whether an iSCSI session is persistent.
+func IsISCSISessionPersistent(session *storage.MSFT_iSCSISession) (bool, error) {
+ return session.GetPropertyIsPersistent()
+}
+
+// ListDisksByTarget find all disks associated with an iSCSITarget.
+// It finds out the iSCSIConnections from MSFT_iSCSITargetToiSCSIConnection association,
+// then locate MSFT_Disk objects from MSFT_iSCSIConnectionToDisk association.
+//
+// WMI association MSFT_iSCSITargetToiSCSIConnection:
+//
+// iSCSIConnection | iSCSITarget
+// --------------- | -----------
+// MSFT_iSCSIConnection (ConnectionIdentifier = "ffffac0cacbff010-15") | MSFT_iSCSITarget (NodeAddress = "iqn.1991-05.com...
+//
+// WMI association MSFT_iSCSIConnectionToDisk:
+//
+// Disk | iSCSIConnection
+// ---- | ---------------
+// MSFT_Disk (ObjectId = "{1}\\WIN-8E2EVAQ9QSB\root/Microsoft/Win...) | MSFT_iSCSIConnection (ConnectionIdentifier = "fff...
+//
+// Refer to https://learn.microsoft.com/en-us/previous-versions/windows/desktop/iscsidisc/msft-iscsiconnection
+// for the WMI class definition.
+func ListDisksByTarget(target *storage.MSFT_iSCSITarget) ([]*storage.MSFT_Disk, error) {
+ // list connections to the given iSCSI target
+ collection, err := target.GetAssociated("MSFT_iSCSITargetToiSCSIConnection", "MSFT_iSCSIConnection", "iSCSIConnection", "iSCSITarget")
+ if err != nil {
+ return nil, fmt.Errorf("failed to query associated iSCSISession for %v. error: %v", target, err)
+ }
+
+ if len(collection) == 0 {
+ return nil, nil
+ }
+
+ var result []*storage.MSFT_Disk
+ for _, conn := range collection {
+ instances, err := conn.GetAssociated("MSFT_iSCSIConnectionToDisk", "MSFT_Disk", "Disk", "iSCSIConnection")
+ if err != nil {
+ return nil, fmt.Errorf("failed to query associated disk for %v. error: %v", target, err)
+ }
+
+ for _, instance := range instances {
+ disk, err := storage.NewMSFT_DiskEx1(instance)
+ if err != nil {
+ return nil, fmt.Errorf("failed to query associated disk %v. error: %v", instance, err)
+ }
+
+ result = append(result, disk)
+ }
+ }
+
+ return result, err
+}
+
+// ConnectISCSITarget establishes a connection to an iSCSI target with optional CHAP authentication credential.
+//
+// Refer https://learn.microsoft.com/en-us/previous-versions/windows/desktop/iscsidisc/msft-iscsitarget-connect
+// for the WMI method definition.
+func ConnectISCSITarget(portalAddress string, portalPortNumber uint32, nodeAddress string, authType string, chapUsername *string, chapSecret *string) (int, error) {
+ inParams := map[string]interface{}{
+ "NodeAddress": nodeAddress,
+ "TargetPortalAddress": portalAddress,
+ "TargetPortalPortNumber": int(portalPortNumber),
+ "AuthenticationType": authType,
+ }
+ // InitiatorPortalAddress
+ // IsDataDigest
+ // IsHeaderDigest
+ // ReportToPnP
+ if chapUsername != nil {
+ inParams["ChapUsername"] = *chapUsername
+ }
+ if chapSecret != nil {
+ inParams["ChapSecret"] = *chapSecret
+ }
+
+ result, _, err := InvokeCimMethod(WMINamespaceStorage, "MSFT_iSCSITarget", "Connect", inParams)
+ return result, err
+}
+
+// DisconnectISCSITarget disconnects the specified session between an iSCSI initiator and an iSCSI target.
+//
+// Refer https://learn.microsoft.com/en-us/previous-versions/windows/desktop/iscsidisc/msft-iscsitarget-disconnect
+// for the WMI method definition.
+func DisconnectISCSITarget(target *storage.MSFT_iSCSITarget, sessionIdentifier string) (int, error) {
+ result, err := target.InvokeMethodWithReturn("Disconnect", sessionIdentifier)
+ return int(result), err
+}
diff --git a/pkg/cim/smb.go b/pkg/cim/smb.go
index 5868d456..2850ab78 100644
--- a/pkg/cim/smb.go
+++ b/pkg/cim/smb.go
@@ -4,6 +4,8 @@
package cim
import (
+ "strings"
+
"github.com/microsoft/wmi/pkg/base/query"
cim "github.com/microsoft/wmi/pkg/wmiinstance"
)
@@ -17,8 +19,24 @@ const (
SmbMappingStatusConnecting
SmbMappingStatusReconnecting
SmbMappingStatusUnavailable
+
+ credentialDelimiter = ":"
)
+// escapeQueryParameter escapes a parameter for WMI Queries
+func escapeQueryParameter(s string) string {
+ s = strings.ReplaceAll(s, "'", "''")
+ s = strings.ReplaceAll(s, "\\", "\\\\")
+ return s
+}
+
+func escapeUserName(userName string) string {
+ // refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L169-L170
+ userName = strings.ReplaceAll(userName, "\\", "\\\\")
+ userName = strings.ReplaceAll(userName, credentialDelimiter, "\\"+credentialDelimiter)
+ return userName
+}
+
// QuerySmbGlobalMappingByRemotePath retrieves the SMB global mapping from its remote path.
//
// The equivalent WMI query is:
@@ -28,7 +46,7 @@ const (
// Refer to https://pkg.go.dev/github.com/microsoft/wmi/server2019/root/microsoft/windows/smb#MSFT_SmbGlobalMapping
// for the WMI class definition.
func QuerySmbGlobalMappingByRemotePath(remotePath string) (*cim.WmiInstance, error) {
- smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", remotePath)
+ smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", escapeQueryParameter(remotePath))
instances, err := QueryInstances(WMINamespaceSmb, smbQuery)
if err != nil {
return nil, err
@@ -37,12 +55,22 @@ func QuerySmbGlobalMappingByRemotePath(remotePath string) (*cim.WmiInstance, err
return instances[0], err
}
-// RemoveSmbGlobalMappingByRemotePath removes a SMB global mapping matching to the remote path.
+// GetSmbGlobalMappingStatus returns the status of an SMB global mapping.
+func GetSmbGlobalMappingStatus(inst *cim.WmiInstance) (int32, error) {
+ statusProp, err := inst.GetProperty("Status")
+ if err != nil {
+ return SmbMappingStatusUnavailable, err
+ }
+
+ return statusProp.(int32), nil
+}
+
+// RemoveSmbGlobalMappingByRemotePath removes an SMB global mapping matching to the remote path.
//
// Refer to https://pkg.go.dev/github.com/microsoft/wmi/server2019/root/microsoft/windows/smb#MSFT_SmbGlobalMapping
// for the WMI class definition.
func RemoveSmbGlobalMappingByRemotePath(remotePath string) error {
- smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", remotePath)
+ smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", escapeQueryParameter(remotePath))
instances, err := QueryInstances(WMINamespaceSmb, smbQuery)
if err != nil {
return err
@@ -51,3 +79,22 @@ func RemoveSmbGlobalMappingByRemotePath(remotePath string) error {
_, err = instances[0].InvokeMethod("Remove", true)
return err
}
+
+// NewSmbGlobalMapping creates a new SMB global mapping to the remote path.
+//
+// Refer to https://pkg.go.dev/github.com/microsoft/wmi/server2019/root/microsoft/windows/smb#MSFT_SmbGlobalMapping
+// for the WMI class definition.
+func NewSmbGlobalMapping(remotePath, username, password string, requirePrivacy bool) (int, error) {
+ params := map[string]interface{}{
+ "RemotePath": remotePath,
+ "RequirePrivacy": requirePrivacy,
+ }
+ if username != "" {
+ // refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L166-L178
+ // on how SMB credential is handled in PowerShell
+ params["Credential"] = escapeUserName(username) + credentialDelimiter + password
+ }
+
+ result, _, err := InvokeCimMethod(WMINamespaceSmb, "MSFT_SmbGlobalMapping", "Create", params)
+ return result, err
+}
diff --git a/pkg/cim/system.go b/pkg/cim/system.go
index 3ab32af6..80181794 100644
--- a/pkg/cim/system.go
+++ b/pkg/cim/system.go
@@ -10,6 +10,22 @@ import (
"github.com/microsoft/wmi/server2019/root/cimv2"
)
+var (
+ BIOSSelectorList = []string{"SerialNumber"}
+ ServiceSelectorList = []string{"DisplayName", "State", "StartMode"}
+)
+
+type ServiceInterface interface {
+ GetPropertyName() (string, error)
+ GetPropertyDisplayName() (string, error)
+ GetPropertyState() (string, error)
+ GetPropertyStartMode() (string, error)
+ GetDependents() ([]ServiceInterface, error)
+ StartService() (result uint32, err error)
+ StopService() (result uint32, err error)
+ Refresh() error
+}
+
// QueryBIOSElement retrieves the BIOS element.
//
// The equivalent WMI query is:
@@ -33,6 +49,11 @@ func QueryBIOSElement(selectorList []string) (*cimv2.CIM_BIOSElement, error) {
return bios, err
}
+// GetBIOSSerialNumber returns the BIOS serial number.
+func GetBIOSSerialNumber(bios *cimv2.CIM_BIOSElement) (string, error) {
+ return bios.GetPropertySerialNumber()
+}
+
// QueryServiceByName retrieves a specific service by its name.
//
// The equivalent WMI query is:
@@ -55,3 +76,60 @@ func QueryServiceByName(name string, selectorList []string) (*cimv2.Win32_Servic
return service, err
}
+
+// GetServiceName returns the name of a service.
+func GetServiceName(service ServiceInterface) (string, error) {
+ return service.GetPropertyName()
+}
+
+// GetServiceDisplayName returns the display name of a service.
+func GetServiceDisplayName(service ServiceInterface) (string, error) {
+ return service.GetPropertyDisplayName()
+}
+
+// GetServiceState returns the state of a service.
+func GetServiceState(service ServiceInterface) (string, error) {
+ return service.GetPropertyState()
+}
+
+// GetServiceStartMode returns the start mode of a service.
+func GetServiceStartMode(service ServiceInterface) (string, error) {
+ return service.GetPropertyStartMode()
+}
+
+// Win32Service wraps the WMI class Win32_Service (mainly for testing)
+type Win32Service struct {
+ *cimv2.Win32_Service
+}
+
+func (s *Win32Service) GetDependents() ([]ServiceInterface, error) {
+ collection, err := s.GetAssociated("Win32_DependentService", "Win32_Service", "Dependent", "Antecedent")
+ if err != nil {
+ return nil, err
+ }
+
+ var result []ServiceInterface
+ for _, coll := range collection {
+ service, err := cimv2.NewWin32_ServiceEx1(coll)
+ if err != nil {
+ return nil, err
+ }
+
+ result = append(result, &Win32Service{
+ service,
+ })
+ }
+ return result, nil
+}
+
+type Win32ServiceFactory struct {
+}
+
+func (impl Win32ServiceFactory) GetService(name string) (ServiceInterface, error) {
+ service, err := QueryServiceByName(name, ServiceSelectorList)
+ if err != nil {
+ return nil, err
+ }
+
+ return &Win32Service{Win32_Service: service}, nil
+}
diff --git a/pkg/cim/volume.go b/pkg/cim/volume.go
index 085289ba..0c880fe1 100644
--- a/pkg/cim/volume.go
+++ b/pkg/cim/volume.go
@@ -7,10 +7,23 @@ import (
"fmt"
"strconv"
+ "github.com/go-ole/go-ole"
"github.com/microsoft/wmi/pkg/base/query"
"github.com/microsoft/wmi/pkg/errors"
- cim "github.com/microsoft/wmi/pkg/wmiinstance"
"github.com/microsoft/wmi/server2019/root/microsoft/windows/storage"
+ "k8s.io/klog/v2"
+)
+
+const (
+ FileSystemUnknown = 0
+)
+
+var (
+ VolumeSelectorListForFileSystemType = []string{"FileSystemType"}
+ VolumeSelectorListForStats = []string{"UniqueId", "SizeRemaining", "Size"}
+ VolumeSelectorListUniqueID = []string{"UniqueId"}
+
+ PartitionSelectorListObjectID = []string{"ObjectId"}
)
// QueryVolumeByUniqueID retrieves a specific volume by its unique identifier,
@@ -79,6 +92,68 @@ func ListVolumes(selectorList []string) ([]*storage.MSFT_Volume, error) {
return volumes, nil
}
+// FormatVolume formats the specified volume.
+//
+// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/format-msft-volume
+// for the WMI method definition.
+func FormatVolume(volume *storage.MSFT_Volume, params ...interface{}) (int, error) {
+ result, err := volume.InvokeMethodWithReturn("Format", params...)
+ return int(result), err
+}
+
+// FlushVolume flushes the cached data in the volume's file system to disk.
+//
+// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-volume-flush
+// for the WMI method definition.
+func FlushVolume(volume *storage.MSFT_Volume) (int, error) {
+ result, err := volume.Flush()
+ return int(result), err
+}
+
+// GetVolumeUniqueID returns the unique ID (object ID) of a volume.
+func GetVolumeUniqueID(volume *storage.MSFT_Volume) (string, error) {
+ return volume.GetPropertyUniqueId()
+}
+
+// GetVolumeFileSystemType returns the file system type of a volume.
+func GetVolumeFileSystemType(volume *storage.MSFT_Volume) (int32, error) {
+ fsType, err := volume.GetProperty("FileSystemType")
+ if err != nil {
+ return 0, err
+ }
+ return fsType.(int32), nil
+}
+
+// GetVolumeSize returns the size of a volume.
+func GetVolumeSize(volume *storage.MSFT_Volume) (int64, error) {
+ volumeSizeVal, err := volume.GetProperty("Size")
+ if err != nil {
+ return -1, err
+ }
+
+ volumeSize, err := strconv.ParseInt(volumeSizeVal.(string), 10, 64)
+ if err != nil {
+ return -1, err
+ }
+
+ return volumeSize, err
+}
+
+// GetVolumeSizeRemaining returns the remaining size of a volume.
+func GetVolumeSizeRemaining(volume *storage.MSFT_Volume) (int64, error) {
+ volumeSizeRemainingVal, err := volume.GetProperty("SizeRemaining")
+ if err != nil {
+ return -1, err
+ }
+
+ volumeSizeRemaining, err := strconv.ParseInt(volumeSizeRemainingVal.(string), 10, 64)
+ if err != nil {
+ return -1, err
+ }
+
+ return volumeSizeRemaining, err
+}
+
// ListPartitionsOnDisk retrieves all partitions or a partition with the specified number on a disk.
//
// The equivalent WMI query is:
@@ -129,131 +204,84 @@ func ListPartitionsWithFilters(selectorList []string, filters ...*query.WmiQuery
return partitions, nil
}
-// ListPartitionToVolumeMappings builds a mapping between partition and volume with partition Object ID as the key.
-//
-// The equivalent WMI query is:
-//
-// SELECT [selectors] FROM MSFT_PartitionToVolume
-//
-// Partition | Volume
-// --------- | ------
-// MSFT_Partition (ObjectId = "{1}\\WIN-8E2EVAQ9QSB\ROOT/Microsoft/Win...) | MSFT_Volume (ObjectId = "{1}\\WIN-8E2EVAQ9QS...
-//
-// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partitiontovolume
-// for the WMI class definition.
-func ListPartitionToVolumeMappings() (map[string]string, error) {
- return ListWMIInstanceMappings(WMINamespaceStorage, "MSFT_PartitionToVolume", nil,
- mappingObjectRefIndexer("Partition", "MSFT_Partition", "ObjectId"),
- mappingObjectRefIndexer("Volume", "MSFT_Volume", "ObjectId"),
- )
-}
-
-// ListVolumeToPartitionMappings builds a mapping between volume and partition with volume Object ID as the key.
-//
-// The equivalent WMI query is:
+// FindPartitionsByVolume finds all partitions associated with the given volumes
+// using MSFT_PartitionToVolume association.
//
-// SELECT [selectors] FROM MSFT_PartitionToVolume
+// WMI association MSFT_PartitionToVolume:
//
-// Partition | Volume
-// --------- | ------
-// MSFT_Partition (ObjectId = "{1}\\WIN-8E2EVAQ9QSB\ROOT/Microsoft/Win...) | MSFT_Volume (ObjectId = "{1}\\WIN-8E2EVAQ9QS...
+// Partition | Volume
+// --------- | ------
+// MSFT_Partition (ObjectId = "{1}\\WIN-8E2EVAQ9QSB\ROOT/Microsoft/Win...) | MSFT_Volume (ObjectId = "{1}\\WIN-8E2EVAQ9QS...
//
// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partitiontovolume
// for the WMI class definition.
-func ListVolumeToPartitionMappings() (map[string]string, error) {
- return ListWMIInstanceMappings(WMINamespaceStorage, "MSFT_PartitionToVolume", nil,
- mappingObjectRefIndexer("Volume", "MSFT_Volume", "ObjectId"),
- mappingObjectRefIndexer("Partition", "MSFT_Partition", "ObjectId"),
- )
-}
-
-// FindPartitionsByVolume finds all partitions associated with the given volumes
-// using partition-to-volume mapping.
-func FindPartitionsByVolume(partitions []*storage.MSFT_Partition, volumes []*storage.MSFT_Volume) ([]*storage.MSFT_Partition, error) {
- var partitionInstances []*cim.WmiInstance
- for _, part := range partitions {
- partitionInstances = append(partitionInstances, part.WmiInstance)
- }
-
- var volumeInstances []*cim.WmiInstance
- for _, volume := range volumes {
- volumeInstances = append(volumeInstances, volume.WmiInstance)
- }
-
- partitionToVolumeMappings, err := ListPartitionToVolumeMappings()
- if err != nil {
- return nil, err
- }
-
- filtered, err := FindInstancesByObjectIDMapping(partitionInstances, volumeInstances, partitionToVolumeMappings)
- if err != nil {
- return nil, err
- }
-
+func FindPartitionsByVolume(volumes []*storage.MSFT_Volume) ([]*storage.MSFT_Partition, error) {
var result []*storage.MSFT_Partition
- for _, instance := range filtered {
- part, err := storage.NewMSFT_PartitionEx1(instance)
+ for _, vol := range volumes {
+ collection, err := vol.GetAssociated("MSFT_PartitionToVolume", "MSFT_Partition", "Partition", "Volume")
if err != nil {
- return nil, fmt.Errorf("failed to query partition %v. error: %v", instance, err)
+ return nil, fmt.Errorf("failed to query associated partition for %v. error: %v", vol, err)
}
- result = append(result, part)
+ for _, instance := range collection {
+ part, err := storage.NewMSFT_PartitionEx1(instance)
+ if err != nil {
+ return nil, fmt.Errorf("failed to query partition %v. error: %v", instance, err)
+ }
+
+ result = append(result, part)
+ }
}
return result, nil
}
// FindVolumesByPartition finds all volumes associated with the given partitions
-// using volume-to-partition mapping.
-func FindVolumesByPartition(volumes []*storage.MSFT_Volume, partitions []*storage.MSFT_Partition) ([]*storage.MSFT_Volume, error) {
- var volumeInstances []*cim.WmiInstance
- for _, volume := range volumes {
- volumeInstances = append(volumeInstances, volume.WmiInstance)
- }
-
- var partitionInstances []*cim.WmiInstance
- for _, part := range partitions {
- partitionInstances = append(partitionInstances, part.WmiInstance)
- }
-
- volumeToPartitionMappings, err := ListVolumeToPartitionMappings()
- if err != nil {
- return nil, err
- }
-
- filtered, err := FindInstancesByObjectIDMapping(volumeInstances, partitionInstances, volumeToPartitionMappings)
- if err != nil {
- return nil, err
- }
-
+// using MSFT_PartitionToVolume association.
+//
+// WMI association MSFT_PartitionToVolume:
+//
+// Partition | Volume
+// --------- | ------
+// MSFT_Partition (ObjectId = "{1}\\WIN-8E2EVAQ9QSB\ROOT/Microsoft/Win...) | MSFT_Volume (ObjectId = "{1}\\WIN-8E2EVAQ9QS...
+//
+// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partitiontovolume
+// for the WMI class definition.
+func FindVolumesByPartition(partitions []*storage.MSFT_Partition) ([]*storage.MSFT_Volume, error) {
var result []*storage.MSFT_Volume
- for _, instance := range filtered {
- volume, err := storage.NewMSFT_VolumeEx1(instance)
+ for _, part := range partitions {
+ collection, err := part.GetAssociated("MSFT_PartitionToVolume", "MSFT_Volume", "Volume", "Partition")
if err != nil {
- return nil, fmt.Errorf("failed to query volume %v. error: %v", instance, err)
+ return nil, fmt.Errorf("failed to query associated volumes for %v. error: %v", part, err)
}
- result = append(result, volume)
+ for _, instance := range collection {
+ volume, err := storage.NewMSFT_VolumeEx1(instance)
+ if err != nil {
+ return nil, fmt.Errorf("failed to query volume %v. error: %v", instance, err)
+ }
+
+ result = append(result, volume)
+ }
}
return result, nil
}
// GetPartitionByVolumeUniqueID retrieves a specific partition from a volume identified by its unique ID.
-func GetPartitionByVolumeUniqueID(volumeID string, partitionSelectorList []string) (*storage.MSFT_Partition, error) {
+func GetPartitionByVolumeUniqueID(volumeID string) (*storage.MSFT_Partition, error) {
volume, err := QueryVolumeByUniqueID(volumeID, []string{"ObjectId"})
if err != nil {
return nil, err
}
- partitions, err := ListPartitionsWithFilters(partitionSelectorList)
+ result, err := FindPartitionsByVolume([]*storage.MSFT_Volume{volume})
if err != nil {
return nil, err
}
- result, err := FindPartitionsByVolume(partitions, []*storage.MSFT_Volume{volume})
- if err != nil {
- return nil, err
+ if len(result) == 0 {
+ return nil, errors.NotFound
}
return result[0], nil
@@ -269,12 +297,7 @@ func GetVolumeByDriveLetter(driveLetter string, partitionSelectorList []string)
return nil, err
}
- volumes, err := ListVolumes(partitionSelectorList)
- if err != nil {
- return nil, err
- }
-
- result, err := FindVolumesByPartition(volumes, partitions)
+ result, err := FindVolumesByPartition(partitions)
if err != nil {
return nil, err
}
@@ -298,3 +321,78 @@ func GetPartitionDiskNumber(part *storage.MSFT_Partition) (uint32, error) {
return uint32(diskNumber.(int32)), nil
}
+
+// SetPartitionState takes a partition online or offline.
+//
+// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partition-online and
+// https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partition-offline
+// for the WMI method definition.
+func SetPartitionState(part *storage.MSFT_Partition, online bool) (int, string, error) {
+ method := "Offline"
+ if online {
+ method = "Online"
+ }
+
+ var status string
+ result, err := part.InvokeMethodWithReturn(method, &status)
+ return int(result), status, err
+}
+
+// GetPartitionSupportedSize retrieves the minimum and maximum sizes that the partition can be resized to using the ResizePartition method.
+//
+// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partition-getsupportedsizes
+// for the WMI method definition.
+func GetPartitionSupportedSize(part *storage.MSFT_Partition) (result int, sizeMin, sizeMax int64, status string, err error) {
+ sizeMin = -1
+ sizeMax = -1
+
+ var sizeMinVar, sizeMaxVar ole.VARIANT
+ invokeResult, err := part.InvokeMethodWithReturn("GetSupportedSize", &sizeMinVar, &sizeMaxVar, &status)
+ if invokeResult != 0 || err != nil {
+ result = int(invokeResult)
+ }
+ klog.V(5).Infof("got sizeMin (%v) sizeMax (%v) from partition (%v), status: %s", sizeMinVar, sizeMaxVar, part, status)
+
+ sizeMin, err = strconv.ParseInt(sizeMinVar.ToString(), 10, 64)
+ if err != nil {
+ return
+ }
+
+ sizeMax, err = strconv.ParseInt(sizeMaxVar.ToString(), 10, 64)
+ return
+}
+
+// ResizePartition resizes a partition.
+//
+// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partition-resize
+// for the WMI method definition.
+func ResizePartition(part *storage.MSFT_Partition, size int64) (int, string, error) {
+ var status string
+ result, err := part.InvokeMethodWithReturn("Resize", strconv.Itoa(int(size)), &status)
+ return int(result), status, err
+}
+
+// GetPartitionSize returns the size of a partition.
+func GetPartitionSize(part *storage.MSFT_Partition) (int64, error) {
+ sizeProp, err := part.GetProperty("Size")
+ if err != nil {
+ return -1, err
+ }
+
+ size, err := strconv.ParseInt(sizeProp.(string), 10, 64)
+ if err != nil {
+ return -1, err
+ }
+
+ return size, err
+}
+
+// FilterForPartitionOnDisk creates a WMI query filter to query a disk by its number.
+func FilterForPartitionOnDisk(diskNumber uint32) *query.WmiQueryFilter {
+ return query.NewWmiQueryFilter("DiskNumber", strconv.Itoa(int(diskNumber)), query.Equals)
+}
+
+// FilterForPartitionsOfTypeNormal creates a WMI query filter for all non-reserved partitions.
+func FilterForPartitionsOfTypeNormal() *query.WmiQueryFilter {
+ return query.NewWmiQueryFilter("GptType", GPTPartitionTypeMicrosoftReserved, query.NotEquals)
+}
diff --git a/pkg/cim/wmi.go b/pkg/cim/wmi.go
index 81e17701..ec9c8f08 100644
--- a/pkg/cim/wmi.go
+++ b/pkg/cim/wmi.go
@@ -4,33 +4,32 @@
package cim
import (
+ "errors"
"fmt"
- "strings"
+ "runtime"
"github.com/go-ole/go-ole"
"github.com/go-ole/go-ole/oleutil"
"github.com/microsoft/wmi/pkg/base/query"
- "github.com/microsoft/wmi/pkg/errors"
+ wmierrors "github.com/microsoft/wmi/pkg/errors"
cim "github.com/microsoft/wmi/pkg/wmiinstance"
+ "golang.org/x/sys/windows"
"k8s.io/klog/v2"
)
const (
- WMINamespaceRoot = "Root\\CimV2"
+ WMINamespaceCimV2 = "Root\\CimV2"
WMINamespaceStorage = "Root\\Microsoft\\Windows\\Storage"
WMINamespaceSmb = "Root\\Microsoft\\Windows\\Smb"
)
type InstanceHandler func(instance *cim.WmiInstance) (bool, error)
-// An InstanceIndexer provides index key to a WMI Instance in a map
-type InstanceIndexer func(instance *cim.WmiInstance) (string, error)
-
// NewWMISession creates a new local WMI session for the given namespace, defaulting
// to root namespace if none specified.
func NewWMISession(namespace string) (*cim.WmiSession, error) {
if namespace == "" {
- namespace = WMINamespaceRoot
+ namespace = WMINamespaceCimV2
}
sessionManager := cim.NewWmiSessionManager()
@@ -65,7 +64,7 @@ func QueryFromWMI(namespace string, query *query.WmiQuery, handler InstanceHandl
}
if len(instances) == 0 {
- return errors.NotFound
+ return wmierrors.NotFound
}
var cont bool
@@ -99,7 +98,7 @@ func executeClassMethodParam(classInst *cim.WmiInstance, method *cim.WmiMethod,
iDispatchInstance := classInst.GetIDispatch()
if iDispatchInstance == nil {
- return nil, errors.Wrapf(errors.InvalidInput, "InvalidInstance")
+ return nil, wmierrors.Wrapf(wmierrors.InvalidInput, "InvalidInstance")
}
rawResult, err := iDispatchInstance.GetProperty("Methods_")
if err != nil {
@@ -239,130 +238,52 @@ func InvokeCimMethod(namespace, class, methodName string, inputParameters map[st
return int(result.ReturnValue), outputParameters, nil
}
+// IsNotFound returns true if it's a "not found" error.
+func IsNotFound(err error) bool {
+ return wmierrors.IsNotFound(err)
+}
+
// IgnoreNotFound returns nil if the error is nil or a "not found" error,
// otherwise returns the original error.
func IgnoreNotFound(err error) error {
- if err == nil || errors.IsNotFound(err) {
+ if err == nil || IsNotFound(err) {
return nil
}
return err
}
-// parseObjectRef extracts the object ID from a WMI object reference string.
-// The result string is in this format
-// {1}\\WIN-8E2EVAQ9QSB\ROOT/Microsoft/Windows/Storage/Providers_v2\WSP_Partition.ObjectId="{b65bb3cd-da86-11ee-854b-806e6f6e6963}:PR:{00000000-0000-0000-0000-100000000000}\\?\scsi#disk&ven_vmware&prod_virtual_disk#4&2c28f6c4&0&000000#{53f56307-b6bf-11d0-94f2-00a0c91efb8b}"
-// from an escape string
-func parseObjectRef(input, objectClass, refName string) (string, error) {
- tokens := strings.Split(input, fmt.Sprintf("%s.%s=", objectClass, refName))
- if len(tokens) < 2 {
- return "", fmt.Errorf("invalid object ID value: %s", input)
- }
-
- objectID := tokens[1]
- objectID = strings.ReplaceAll(objectID, "\\\"", "\"")
- objectID = strings.ReplaceAll(objectID, "\\\\", "\\")
- objectID = objectID[1 : len(objectID)-1]
- return objectID, nil
-}
-
-// ListWMIInstanceMappings queries WMI instances and creates a map using custom indexing functions
-// to extract keys and values from each instance.
-func ListWMIInstanceMappings(namespace, mappingClassName string, selectorList []string, keyIndexer InstanceIndexer, valueIndexer InstanceIndexer) (map[string]string, error) {
- q := query.NewWmiQueryWithSelectList(mappingClassName, selectorList)
- mappingInstances, err := QueryInstances(namespace, q)
- if err != nil {
- return nil, err
- }
-
- result := make(map[string]string)
- for _, mapping := range mappingInstances {
- key, err := keyIndexer(mapping)
- if err != nil {
- return nil, err
- }
-
- value, err := valueIndexer(mapping)
- if err != nil {
- return nil, err
- }
-
- result[key] = value
- }
-
- return result, nil
-}
-
-// FindInstancesByMapping filters instances based on a mapping relationship,
-// matching instances through custom indexing and mapping functions.
-func FindInstancesByMapping(instanceToFind []*cim.WmiInstance, instanceToFindIndex InstanceIndexer, associatedInstances []*cim.WmiInstance, associatedInstanceIndexer InstanceIndexer, instanceMappings map[string]string) ([]*cim.WmiInstance, error) {
- associatedInstanceObjectIDMapping := map[string]*cim.WmiInstance{}
- for _, inst := range associatedInstances {
- key, err := associatedInstanceIndexer(inst)
- if err != nil {
- return nil, err
- }
-
- associatedInstanceObjectIDMapping[key] = inst
- }
-
- var filtered []*cim.WmiInstance
- for _, inst := range instanceToFind {
- key, err := instanceToFindIndex(inst)
- if err != nil {
- return nil, err
- }
-
- valueObjectID, ok := instanceMappings[key]
- if !ok {
- continue
- }
-
- _, ok = associatedInstanceObjectIDMapping[strings.ToUpper(valueObjectID)]
- if !ok {
- continue
- }
- filtered = append(filtered, inst)
- }
-
- if len(filtered) == 0 {
- return nil, errors.NotFound
- }
-
- return filtered, nil
-}
-
-// mappingObjectRefIndexer indexes an WMI object by the Object ID reference from a specified property.
-func mappingObjectRefIndexer(propertyName, className, refName string) InstanceIndexer {
- return func(instance *cim.WmiInstance) (string, error) {
- valueVal, err := instance.GetProperty(propertyName)
- if err != nil {
- return "", err
+// WithCOMThread runs the given function `fn` on a locked OS thread
+// with COM initialized using COINIT_MULTITHREADED.
+//
+// This is necessary for using COM/OLE APIs directly (e.g., via go-ole),
+// because COM requires that initialization and usage occur on the same thread.
+//
+// It performs the following steps:
+// - Locks the current goroutine to its OS thread
+// - Calls ole.CoInitializeEx with COINIT_MULTITHREADED
+// - Executes the user-provided function
+// - Uninitializes COM
+// - Unlocks the thread
+//
+// If COM initialization fails, or if the user's function returns an error,
+// that error is returned by WithCOMThread.
+func WithCOMThread(fn func() error) error {
+ runtime.LockOSThread()
+ defer runtime.UnlockOSThread()
+
+ if err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED); err != nil {
+ var oleError *ole.OleError
+ if errors.As(err, &oleError) && oleError != nil && oleError.Code() == uintptr(windows.S_FALSE) {
+ klog.V(10).Infof("COM library has been already initialized for the calling thread, proceeding to the function with no error")
+ err = nil
}
-
- refValue, err := parseObjectRef(valueVal.(string), className, refName)
- return strings.ToUpper(refValue), err
- }
-}
-
-// stringPropertyIndexer indexes a WMI object from a string property.
-func stringPropertyIndexer(propertyName string) InstanceIndexer {
- return func(instance *cim.WmiInstance) (string, error) {
- valueVal, err := instance.GetProperty(propertyName)
if err != nil {
- return "", err
+ return err
}
-
- return strings.ToUpper(valueVal.(string)), err
+ } else {
+ klog.V(10).Infof("COM library is initialized for the calling thread")
}
-}
-
-var (
- // objectIDPropertyIndexer indexes a WMI object from its ObjectId property.
- objectIDPropertyIndexer = stringPropertyIndexer("ObjectId")
-)
+ defer ole.CoUninitialize()
-// FindInstancesByObjectIDMapping filters instances based on ObjectId mapping
-// between two sets of WMI instances.
-func FindInstancesByObjectIDMapping(instanceToFind []*cim.WmiInstance, associatedInstances []*cim.WmiInstance, instanceMappings map[string]string) ([]*cim.WmiInstance, error) {
- return FindInstancesByMapping(instanceToFind, objectIDPropertyIndexer, associatedInstances, objectIDPropertyIndexer, instanceMappings)
+ return fn()
}
diff --git a/pkg/os/disk/api.go b/pkg/os/disk/api.go
index dc8637fd..5b46992c 100644
--- a/pkg/os/disk/api.go
+++ b/pkg/os/disk/api.go
@@ -3,14 +3,12 @@ package disk
import (
"encoding/hex"
"fmt"
- "strconv"
"strings"
"syscall"
"unsafe"
"github.com/kubernetes-csi/csi-proxy/pkg/cim"
shared "github.com/kubernetes-csi/csi-proxy/pkg/shared/disk"
- "github.com/microsoft/wmi/pkg/base/query"
"k8s.io/klog/v2"
)
@@ -66,153 +64,162 @@ func New() DiskAPI {
// ListDiskLocations - constructs a map with the disk number as the key and the DiskLocation structure
// as the value. The DiskLocation struct has various fields like the Adapter, Bus, Target and LUNID.
func (imp DiskAPI) ListDiskLocations() (map[uint32]shared.DiskLocation, error) {
- // "location": "PCI Slot 3 : Adapter 0 : Port 0 : Target 1 : LUN 0"
- disks, err := cim.ListDisks([]string{"Number", "Location"})
- if err != nil {
- return nil, fmt.Errorf("could not query disk locations")
- }
-
m := make(map[uint32]shared.DiskLocation)
- for _, disk := range disks {
- num, err := disk.GetProperty("Number")
+ err := cim.WithCOMThread(func() error {
+ // "location": "PCI Slot 3 : Adapter 0 : Port 0 : Target 1 : LUN 0"
+ disks, err := cim.ListDisks(cim.DiskSelectorListForDiskNumberAndLocation)
if err != nil {
- return m, fmt.Errorf("failed to query disk number: %v, %w", disk, err)
+ return fmt.Errorf("could not query disk locations")
}
- location, err := disk.GetPropertyLocation()
- if err != nil {
- return m, fmt.Errorf("failed to query disk location: %v, %w", disk, err)
- }
+ for _, disk := range disks {
+ num, err := cim.GetDiskNumber(disk)
+ if err != nil {
+ return fmt.Errorf("failed to query disk number: %v, %w", disk, err)
+ }
+
+ location, err := cim.GetDiskLocation(disk)
+ if err != nil {
+ return fmt.Errorf("failed to query disk location: %v, %w", disk, err)
+ }
- found := false
- s := strings.Split(location, ":")
- if len(s) >= 5 {
- var d shared.DiskLocation
- for _, item := range s {
- item = strings.TrimSpace(item)
- itemSplit := strings.Split(item, " ")
- if len(itemSplit) == 2 {
- found = true
- switch strings.TrimSpace(itemSplit[0]) {
- case "Adapter":
- d.Adapter = strings.TrimSpace(itemSplit[1])
- case "Target":
- d.Target = strings.TrimSpace(itemSplit[1])
- case "LUN":
- d.LUNID = strings.TrimSpace(itemSplit[1])
- default:
- klog.Warningf("Got unknown field : %s=%s", itemSplit[0], itemSplit[1])
+ found := false
+ s := strings.Split(location, ":")
+ if len(s) >= 5 {
+ var d shared.DiskLocation
+ for _, item := range s {
+ item = strings.TrimSpace(item)
+ itemSplit := strings.Split(item, " ")
+ if len(itemSplit) == 2 {
+ found = true
+ switch strings.TrimSpace(itemSplit[0]) {
+ case "Adapter":
+ d.Adapter = strings.TrimSpace(itemSplit[1])
+ case "Target":
+ d.Target = strings.TrimSpace(itemSplit[1])
+ case "LUN":
+ d.LUNID = strings.TrimSpace(itemSplit[1])
+ default:
+ klog.Warningf("Got unknown field : %s=%s", itemSplit[0], itemSplit[1])
+ }
}
}
- }
- if found {
- m[uint32(num.(int32))] = d
+ if found {
+ m[num] = d
+ }
}
}
- }
- return m, nil
+ return nil
+ })
+ return m, err
}
func (imp DiskAPI) Rescan() error {
- result, _, err := cim.InvokeCimMethod(cim.WMINamespaceStorage, "MSFT_StorageSetting", "UpdateHostStorageCache", nil)
- if err != nil {
- return fmt.Errorf("error updating host storage cache output. result: %d, err: %v", result, err)
- }
- return nil
+ return cim.WithCOMThread(func() error {
+ result, err := cim.RescanDisks()
+ if err != nil {
+ return fmt.Errorf("error updating host storage cache output. result: %d, err: %v", result, err)
+ }
+ return nil
+ })
}
func (imp DiskAPI) IsDiskInitialized(diskNumber uint32) (bool, error) {
var partitionStyle int32
- disk, err := cim.QueryDiskByNumber(diskNumber, []string{"PartitionStyle"})
- if err != nil {
- return false, fmt.Errorf("error checking initialized status of disk %d. %v", diskNumber, err)
- }
+ err := cim.WithCOMThread(func() error {
+ disk, err := cim.QueryDiskByNumber(diskNumber, cim.DiskSelectorListForPartitionStyle)
+ if err != nil {
+ return fmt.Errorf("error checking initialized status of disk %d: %v", diskNumber, err)
+ }
- retValue, err := disk.GetProperty("PartitionStyle")
- if err != nil {
- return false, fmt.Errorf("failed to query partition style of disk %d: %w", diskNumber, err)
- }
+ partitionStyle, err = cim.GetDiskPartitionStyle(disk)
+ if err != nil {
+ return fmt.Errorf("failed to query partition style of disk %d: %v", diskNumber, err)
+ }
- partitionStyle = retValue.(int32)
- return partitionStyle != cim.PartitionStyleUnknown, nil
+ return nil
+ })
+ return partitionStyle != cim.PartitionStyleUnknown, err
}
func (imp DiskAPI) InitializeDisk(diskNumber uint32) error {
- disk, err := cim.QueryDiskByNumber(diskNumber, nil)
- if err != nil {
- return fmt.Errorf("failed to initializing disk %d. error: %w", diskNumber, err)
- }
+ return cim.WithCOMThread(func() error {
+ disk, err := cim.QueryDiskByNumber(diskNumber, nil)
+ if err != nil {
+ return fmt.Errorf("failed to initializing disk %d. error: %w", diskNumber, err)
+ }
- result, err := disk.InvokeMethodWithReturn("Initialize", int32(cim.PartitionStyleGPT))
- if result != 0 || err != nil {
- return fmt.Errorf("failed to initializing disk %d: result %d, error: %w", diskNumber, result, err)
- }
+ result, err := cim.InitializeDisk(disk, cim.PartitionStyleGPT)
+ if result != 0 || err != nil {
+ return fmt.Errorf("failed to initializing disk %d: result %d, error: %w", diskNumber, result, err)
+ }
- return nil
+ return nil
+ })
}
func (imp DiskAPI) BasicPartitionsExist(diskNumber uint32) (bool, error) {
- partitions, err := cim.ListPartitionsWithFilters(nil,
- query.NewWmiQueryFilter("DiskNumber", strconv.Itoa(int(diskNumber)), query.Equals),
- query.NewWmiQueryFilter("GptType", cim.GPTPartitionTypeMicrosoftReserved, query.NotEquals))
- if cim.IgnoreNotFound(err) != nil {
- return false, fmt.Errorf("error checking presence of partitions on disk %d:, %v", diskNumber, err)
- }
+ var exist bool
+ err := cim.WithCOMThread(func() error {
+ partitions, err := cim.ListPartitionsWithFilters(nil, cim.FilterForPartitionOnDisk(diskNumber), cim.FilterForPartitionsOfTypeNormal())
+ if cim.IgnoreNotFound(err) != nil {
+ return fmt.Errorf("error checking presence of partitions on disk %d:, %v", diskNumber, err)
+ }
- return len(partitions) > 0, nil
+ exist = len(partitions) > 0
+ return nil
+ })
+ return exist, err
}
func (imp DiskAPI) CreateBasicPartition(diskNumber uint32) error {
- disk, err := cim.QueryDiskByNumber(diskNumber, nil)
- if err != nil {
- return err
- }
+ return cim.WithCOMThread(func() error {
+ disk, err := cim.QueryDiskByNumber(diskNumber, nil)
+ if err != nil {
+ return err
+ }
- result, err := disk.InvokeMethodWithReturn(
- "CreatePartition",
- nil, // Size
- true, // UseMaximumSize
- nil, // Offset
- nil, // Alignment
- nil, // DriveLetter
- false, // AssignDriveLetter
- nil, // MbrType,
- cim.GPTPartitionTypeBasicData, // GPT Type
- false, // IsHidden
- false, // IsActive,
- )
- // 42002 is returned by driver letter failed to assign after partition
- if (result != 0 && result != 42002) || err != nil {
- return fmt.Errorf("error creating partition on disk %d. result: %d, err: %v", diskNumber, result, err)
- }
+ result, err := cim.CreatePartition(
+ disk,
+ nil, // Size
+ true, // UseMaximumSize
+ nil, // Offset
+ nil, // Alignment
+ nil, // DriveLetter
+ false, // AssignDriveLetter
+ nil, // MbrType,
+ cim.GPTPartitionTypeBasicData, // GPT Type
+ false, // IsHidden
+ false, // IsActive,
+ )
+ if (result != 0 && result != cim.ErrorCodeCreatePartitionAccessPathAlreadyInUse) || err != nil {
+ return fmt.Errorf("error creating partition on disk %d. result: %d, err: %v", diskNumber, result, err)
+ }
- var status string
- result, err = disk.InvokeMethodWithReturn("Refresh", &status)
- if result != 0 || err != nil {
- return fmt.Errorf("error rescan disk (%d). result %d, error: %v", diskNumber, result, err)
- }
+ result, _, err = cim.RefreshDisk(disk)
+ if result != 0 || err != nil {
+ return fmt.Errorf("error rescan disk (%d). result %d, error: %v", diskNumber, result, err)
+ }
- partitions, err := cim.ListPartitionsWithFilters(nil,
- query.NewWmiQueryFilter("DiskNumber", strconv.Itoa(int(diskNumber)), query.Equals),
- query.NewWmiQueryFilter("GptType", cim.GPTPartitionTypeMicrosoftReserved, query.NotEquals))
- if err != nil {
- return fmt.Errorf("error query basic partition on disk %d:, %v", diskNumber, err)
- }
+ partitions, err := cim.ListPartitionsWithFilters(nil, cim.FilterForPartitionOnDisk(diskNumber), cim.FilterForPartitionsOfTypeNormal())
+ if err != nil {
+ return fmt.Errorf("error query basic partition on disk %d:, %v", diskNumber, err)
+ }
- if len(partitions) == 0 {
- return fmt.Errorf("failed to create basic partition on disk %d:, %v", diskNumber, err)
- }
+ if len(partitions) == 0 {
+ return fmt.Errorf("failed to create basic partition on disk %d:, %v", diskNumber, err)
+ }
- partition := partitions[0]
- result, err = partition.InvokeMethodWithReturn("Online", status)
- if result != 0 || err != nil {
- return fmt.Errorf("error bring partition %v on disk %d online. result: %d, status %s, err: %v", partition, diskNumber, result, status, err)
- }
+ partition := partitions[0]
+ result, status, err := cim.SetPartitionState(partition, true)
+ if result != 0 || err != nil {
+ return fmt.Errorf("error bring partition %v on disk %d online. result: %d, status %s, err: %v", partition, diskNumber, result, status, err)
+ }
- err = partition.Refresh()
- return err
+ return nil
+ })
}
func (imp DiskAPI) GetDiskNumberByName(page83ID string) (uint32, error) {
@@ -272,28 +279,33 @@ func (imp DiskAPI) GetDiskPage83ID(disk syscall.Handle) (string, error) {
}
func (imp DiskAPI) GetDiskNumberWithID(page83ID string) (uint32, error) {
- disks, err := cim.ListDisks([]string{"Path", "SerialNumber"})
- if err != nil {
- return 0, err
- }
-
- for _, disk := range disks {
- path, err := disk.GetPropertyPath()
+ var diskNumberResult uint32
+ err := cim.WithCOMThread(func() error {
+ disks, err := cim.ListDisks(cim.DiskSelectorListForPathAndSerialNumber)
if err != nil {
- return 0, fmt.Errorf("failed to query disk path: %v, %w", disk, err)
+ return err
}
- diskNumber, diskPage83ID, err := imp.GetDiskNumberAndPage83ID(path)
- if err != nil {
- return 0, err
- }
+ for _, disk := range disks {
+ path, err := cim.GetDiskPath(disk)
+ if err != nil {
+ return fmt.Errorf("failed to query disk path: %v, %w", disk, err)
+ }
- if diskPage83ID == page83ID {
- return diskNumber, nil
+ diskNumber, diskPage83ID, err := imp.GetDiskNumberAndPage83ID(path)
+ if err != nil {
+ return err
+ }
+
+ if diskPage83ID == page83ID {
+ diskNumberResult = diskNumber
+ return nil
+ }
}
- }
- return 0, fmt.Errorf("could not find disk with Page83 ID %s", page83ID)
+ return fmt.Errorf("could not find disk with Page83 ID %s", page83ID)
+ })
+ return diskNumberResult, err
}
func (imp DiskAPI) GetDiskNumberAndPage83ID(path string) (uint32, string, error) {
@@ -319,91 +331,98 @@ func (imp DiskAPI) GetDiskNumberAndPage83ID(path string) (uint32, string, error)
// ListDiskIDs - constructs a map with the disk number as the key and the DiskID structure
// as the value. The DiskID struct has a field for the page83 ID.
func (imp DiskAPI) ListDiskIDs() (map[uint32]shared.DiskIDs, error) {
- disks, err := cim.ListDisks([]string{"Path", "SerialNumber"})
- if err != nil {
- return nil, err
- }
-
m := make(map[uint32]shared.DiskIDs)
- for _, disk := range disks {
- path, err := disk.GetPropertyPath()
+ err := cim.WithCOMThread(func() error {
+ disks, err := cim.ListDisks(cim.DiskSelectorListForPathAndSerialNumber)
if err != nil {
- return m, fmt.Errorf("failed to query disk path: %v, %w", disk, err)
+ return err
}
- sn, err := disk.GetPropertySerialNumber()
- if err != nil {
- return m, fmt.Errorf("failed to query disk serial number: %v, %w", disk, err)
- }
+ for _, disk := range disks {
+ path, err := cim.GetDiskPath(disk)
+ if err != nil {
+ return fmt.Errorf("failed to query disk path: %v, %w", disk, err)
+ }
- diskNumber, page83, err := imp.GetDiskNumberAndPage83ID(path)
- if err != nil {
- return m, err
- }
+ sn, err := cim.GetDiskSerialNumber(disk)
+ if err != nil {
+ return fmt.Errorf("failed to query disk serial number: %v, %w", disk, err)
+ }
+
+ diskNumber, page83, err := imp.GetDiskNumberAndPage83ID(path)
+ if err != nil {
+ return err
+ }
- m[diskNumber] = shared.DiskIDs{
- Page83: page83,
- SerialNumber: sn,
+ m[diskNumber] = shared.DiskIDs{
+ Page83: page83,
+ SerialNumber: sn,
+ }
}
- }
- return m, nil
+
+ return nil
+ })
+ return m, err
}
func (imp DiskAPI) GetDiskStats(diskNumber uint32) (int64, error) {
// TODO: change to uint64 as it does not make sense to use int64 for size
- var size int64
- disk, err := cim.QueryDiskByNumber(diskNumber, []string{"Size"})
- if err != nil {
- return -1, err
- }
+ size := int64(-1)
+ err := cim.WithCOMThread(func() error {
+ disk, err := cim.QueryDiskByNumber(diskNumber, cim.DiskSelectorListForSize)
+ if err != nil {
+ return err
+ }
- sz, err := disk.GetProperty("Size")
- if err != nil {
- return -1, fmt.Errorf("failed to query size of disk %d. %v", diskNumber, err)
- }
+ size, err = cim.GetDiskSize(disk)
+ if err != nil {
+ return fmt.Errorf("failed to query size of disk %d. %v", diskNumber, err)
+ }
- size, err = strconv.ParseInt(sz.(string), 10, 64)
+ return nil
+ })
return size, err
}
func (imp DiskAPI) SetDiskState(diskNumber uint32, isOnline bool) error {
- disk, err := cim.QueryDiskByNumber(diskNumber, []string{"IsOffline"})
- if err != nil {
- return err
- }
-
- offline, err := disk.GetPropertyIsOffline()
- if err != nil {
- return fmt.Errorf("error setting disk %d attach state. error: %v", diskNumber, err)
- }
+ return cim.WithCOMThread(func() error {
+ disk, err := cim.QueryDiskByNumber(diskNumber, cim.DiskSelectorListForIsOffline)
+ if err != nil {
+ return err
+ }
- if isOnline == !offline {
- return nil
- }
+ isOffline, err := cim.IsDiskOffline(disk)
+ if err != nil {
+ return fmt.Errorf("error setting disk %d attach state. error: %v", diskNumber, err)
+ }
- method := "Offline"
- if isOnline {
- method = "Online"
- }
+ if isOnline == !isOffline {
+ return nil
+ }
- result, err := disk.InvokeMethodWithReturn(method)
- if result != 0 || err != nil {
- return fmt.Errorf("setting disk %d attach state %s: result %d, error: %w", diskNumber, method, result, err)
- }
+ result, _, err := cim.SetDiskState(disk, isOnline)
+ if result != 0 || err != nil {
+ return fmt.Errorf("setting disk %d attach state (isOnline: %v): result %d, error: %w", diskNumber, isOnline, result, err)
+ }
- return nil
+ return nil
+ })
}
func (imp DiskAPI) GetDiskState(diskNumber uint32) (bool, error) {
- disk, err := cim.QueryDiskByNumber(diskNumber, []string{"IsOffline"})
- if err != nil {
- return false, err
- }
+ var isOffline bool
+ err := cim.WithCOMThread(func() error {
+ disk, err := cim.QueryDiskByNumber(diskNumber, cim.DiskSelectorListForIsOffline)
+ if err != nil {
+ return err
+ }
- isOffline, err := disk.GetPropertyIsOffline()
- if err != nil {
- return false, fmt.Errorf("error parsing disk %d state. error: %v", diskNumber, err)
- }
+ isOffline, err = cim.IsDiskOffline(disk)
+ if err != nil {
+ return fmt.Errorf("error parsing disk %d state. error: %v", diskNumber, err)
+ }
- return !isOffline, nil
+ return nil
+ })
+ return !isOffline, err
}
diff --git a/pkg/os/filesystem/api.go b/pkg/os/filesystem/api.go
index f93fb2e9..458fd89f 100644
--- a/pkg/os/filesystem/api.go
+++ b/pkg/os/filesystem/api.go
@@ -4,7 +4,6 @@ import (
"fmt"
"os"
"path/filepath"
- "strings"
"github.com/kubernetes-csi/csi-proxy/pkg/utils"
)
@@ -49,17 +48,6 @@ func (filesystemAPI) PathExists(path string) (bool, error) {
return pathExists(path)
}
-func pathValid(path string) (bool, error) {
- cmd := `Test-Path $Env:remotepath`
- cmdEnv := fmt.Sprintf("remotepath=%s", path)
- output, err := utils.RunPowershellCmd(cmd, cmdEnv)
- if err != nil {
- return false, fmt.Errorf("returned output: %s, error: %v", string(output), err)
- }
-
- return strings.HasPrefix(strings.ToLower(string(output)), "true"), nil
-}
-
// PathValid determines whether all elements of a path exist
//
// https://docs.microsoft.com/en-us/powershell/module/microsoft.powershell.management/test-path?view=powershell-7
@@ -68,7 +56,7 @@ func pathValid(path string) (bool, error) {
//
// e.g. in a SMB server connection, if password is changed, connection will be lost, this func will return false
func (filesystemAPI) PathValid(path string) (bool, error) {
- return pathValid(path)
+ return utils.IsPathValid(path)
}
// Mkdir makes a dir with `os.MkdirAll`.
@@ -124,13 +112,18 @@ func (filesystemAPI) IsSymlink(tgt string) (bool, error) {
// This code is similar to k8s.io/kubernetes/pkg/util/mount except the pathExists usage.
// Also in a remote call environment the os error cannot be passed directly back, hence the callers
// are expected to perform the isExists check before calling this call in CSI proxy.
- stat, err := os.Lstat(tgt)
+ isSymlink, err := utils.IsPathSymlink(tgt)
+ if err != nil {
+ return false, err
+ }
+
+ // mounted folder created by SetVolumeMountPoint may still report ModeSymlink == 0
+ mountedFolder, err := utils.IsMountedFolder(tgt)
if err != nil {
return false, err
}
- // If its a link and it points to an existing file then its a mount point.
- if stat.Mode()&os.ModeSymlink != 0 {
+ if isSymlink || mountedFolder {
target, err := os.Readlink(tgt)
if err != nil {
return false, fmt.Errorf("readlink error: %v", err)
diff --git a/pkg/os/iscsi/api.go b/pkg/os/iscsi/api.go
index 559ed3b5..af8050e1 100644
--- a/pkg/os/iscsi/api.go
+++ b/pkg/os/iscsi/api.go
@@ -1,10 +1,12 @@
package iscsi
import (
- "encoding/json"
"fmt"
+ "strconv"
+ "strings"
- "github.com/kubernetes-csi/csi-proxy/pkg/utils"
+ "github.com/kubernetes-csi/csi-proxy/pkg/cim"
+ "k8s.io/klog/v2"
)
// Implements the iSCSI OS API calls. All code here should be very simple
@@ -19,157 +21,209 @@ func New() APIImplementor {
}
func (APIImplementor) AddTargetPortal(portal *TargetPortal) error {
- cmdLine := fmt.Sprintf(
- `New-IscsiTargetPortal -TargetPortalAddress ${Env:iscsi_tp_address} ` +
- `-TargetPortalPortNumber ${Env:iscsi_tp_port}`)
- out, err := utils.RunPowershellCmd(cmdLine, fmt.Sprintf("iscsi_tp_address=%s", portal.Address),
- fmt.Sprintf("iscsi_tp_port=%d", portal.Port))
- if err != nil {
- return fmt.Errorf("error adding target portal. cmd %s, output: %s, err: %v", cmdLine, string(out), err)
- }
-
- return nil
+ return cim.WithCOMThread(func() error {
+ existing, err := cim.QueryISCSITargetPortal(portal.Address, portal.Port, nil)
+ if cim.IgnoreNotFound(err) != nil {
+ return err
+ }
+
+ if existing != nil {
+ klog.V(2).Infof("target portal at (%s:%d) already exists", portal.Address, portal.Port)
+ return nil
+ }
+
+ _, err = cim.NewISCSITargetPortal(portal.Address, portal.Port, nil, nil, nil, nil)
+ if err != nil {
+ return fmt.Errorf("error adding target portal at (%s:%d). err: %v", portal.Address, portal.Port, err)
+ }
+
+ return nil
+ })
}
func (APIImplementor) DiscoverTargetPortal(portal *TargetPortal) ([]string, error) {
- // ConvertTo-Json is not part of the pipeline because powershell converts an
- // array with one element to a single element
- cmdLine := fmt.Sprintf(
- `ConvertTo-Json -InputObject @(Get-IscsiTargetPortal -TargetPortalAddress ` +
- `${Env:iscsi_tp_address} -TargetPortalPortNumber ${Env:iscsi_tp_port} | ` +
- `Get-IscsiTarget | Select-Object -ExpandProperty NodeAddress)`)
- out, err := utils.RunPowershellCmd(cmdLine, fmt.Sprintf("iscsi_tp_address=%s", portal.Address),
- fmt.Sprintf("iscsi_tp_port=%d", portal.Port))
- if err != nil {
- return nil, fmt.Errorf("error discovering target portal. cmd: %s, output: %s, err: %w", cmdLine, string(out), err)
- }
-
var iqns []string
- err = json.Unmarshal(out, &iqns)
- if err != nil {
- return nil, fmt.Errorf("failed parsing iqn list. cmd: %s output: %s, err: %w", cmdLine, string(out), err)
- }
-
- return iqns, nil
+ err := cim.WithCOMThread(func() error {
+ targets, err := cim.ListISCSITargetsByTargetPortalAddressAndPort(portal.Address, portal.Port, nil)
+ if err != nil {
+ return err
+ }
+
+ for _, target := range targets {
+ iqn, err := cim.GetISCSITargetNodeAddress(target)
+ if err != nil {
+ return fmt.Errorf("failed parsing node address of target %v to target portal at (%s:%d). err: %w", target, portal.Address, portal.Port, err)
+ }
+
+ iqns = append(iqns, iqn)
+ }
+
+ return nil
+ })
+ return iqns, err
}
func (APIImplementor) ListTargetPortals() ([]TargetPortal, error) {
- cmdLine := fmt.Sprintf(
- `ConvertTo-Json -InputObject @(Get-IscsiTargetPortal | ` +
- `Select-Object TargetPortalAddress, TargetPortalPortNumber)`)
-
- out, err := utils.RunPowershellCmd(cmdLine)
- if err != nil {
- return nil, fmt.Errorf("error listing target portals. cmd %s, output: %s, err: %w", cmdLine, string(out), err)
- }
-
var portals []TargetPortal
- err = json.Unmarshal(out, &portals)
- if err != nil {
- return nil, fmt.Errorf("failed parsing target portal list. cmd: %s output: %s, err: %w", cmdLine, string(out), err)
- }
-
- return portals, nil
+ err := cim.WithCOMThread(func() error {
+ instances, err := cim.ListISCSITargetPortals(cim.ISCSITargetPortalDefaultSelectorList)
+ if err != nil {
+ return err
+ }
+
+ for _, instance := range instances {
+ address, port, err := cim.ParseISCSITargetPortal(instance)
+ if err != nil {
+ return fmt.Errorf("failed parsing target portal %v. err: %w", instance, err)
+ }
+
+ portals = append(portals, TargetPortal{
+ Address: address,
+ Port: port,
+ })
+ }
+
+ return nil
+ })
+ return portals, err
}
func (APIImplementor) RemoveTargetPortal(portal *TargetPortal) error {
- cmdLine := fmt.Sprintf(
- `Get-IscsiTargetPortal -TargetPortalAddress ${Env:iscsi_tp_address} ` +
- `-TargetPortalPortNumber ${Env:iscsi_tp_port} | Remove-IscsiTargetPortal ` +
- `-Confirm:$false`)
-
- out, err := utils.RunPowershellCmd(cmdLine, fmt.Sprintf("iscsi_tp_address=%s", portal.Address),
- fmt.Sprintf("iscsi_tp_port=%d", portal.Port))
- if err != nil {
- return fmt.Errorf("error removing target portal. cmd %s, output: %s, err: %w", cmdLine, string(out), err)
- }
-
- return nil
+ return cim.WithCOMThread(func() error {
+ instance, err := cim.QueryISCSITargetPortal(portal.Address, portal.Port, nil)
+ if err != nil {
+ return err
+ }
+
+ result, err := cim.RemoveISCSITargetPortal(instance)
+ if result != 0 || err != nil {
+ return fmt.Errorf("error removing target portal at (%s:%d). result: %d, err: %w", portal.Address, portal.Port, result, err)
+ }
+
+ return nil
+ })
}
-func (APIImplementor) ConnectTarget(portal *TargetPortal, iqn string,
- authType string, chapUser string, chapSecret string) error {
- // Not using InputObject as Connect-IscsiTarget's InputObject does not work.
- // This is due to being a static WMI method together with a bug in the
- // powershell version of the API.
- cmdLine := fmt.Sprintf(
- `Connect-IscsiTarget -TargetPortalAddress ${Env:iscsi_tp_address}` +
- ` -TargetPortalPortNumber ${Env:iscsi_tp_port} -NodeAddress ${Env:iscsi_target_iqn}` +
- ` -AuthenticationType ${Env:iscsi_auth_type}`)
-
- if chapUser != "" {
- cmdLine += ` -ChapUsername ${Env:iscsi_chap_user}`
- }
-
- if chapSecret != "" {
- cmdLine += ` -ChapSecret ${Env:iscsi_chap_secret}`
- }
-
- out, err := utils.RunPowershellCmd(cmdLine, fmt.Sprintf("iscsi_tp_address=%s", portal.Address),
- fmt.Sprintf("iscsi_tp_port=%d", portal.Port),
- fmt.Sprintf("iscsi_target_iqn=%s", iqn),
- fmt.Sprintf("iscsi_auth_type=%s", authType),
- fmt.Sprintf("iscsi_chap_user=%s", chapUser),
- fmt.Sprintf("iscsi_chap_secret=%s", chapSecret))
- if err != nil {
- return fmt.Errorf("error connecting to target portal. cmd %s, output: %s, err: %w", cmdLine, string(out), err)
- }
-
- return nil
+func (APIImplementor) ConnectTarget(portal *TargetPortal, iqn string, authType string, chapUser string, chapSecret string) error {
+ return cim.WithCOMThread(func() error {
+ target, err := cim.QueryISCSITarget(portal.Address, portal.Port, iqn)
+ if err != nil {
+ return err
+ }
+
+ connected, err := cim.IsISCSITargetConnected(target)
+ if err != nil {
+ return err
+ }
+
+ if connected {
+ klog.V(2).Infof("target %s from target portal at (%s:%d) is connected.", iqn, portal.Address, portal.Port)
+ return nil
+ }
+
+ targetAuthType := strings.ToUpper(strings.ReplaceAll(authType, "_", ""))
+
+ result, err := cim.ConnectISCSITarget(portal.Address, portal.Port, iqn, targetAuthType, &chapUser, &chapSecret)
+ if err != nil {
+ return fmt.Errorf("error connecting to target portal. result: %d, err: %w", result, err)
+ }
+
+ return nil
+ })
}
func (APIImplementor) DisconnectTarget(portal *TargetPortal, iqn string) error {
- // Using InputObject instead of pipe to verify input is not empty
- cmdLine := fmt.Sprintf(
- `Disconnect-IscsiTarget -InputObject (Get-IscsiTargetPortal ` +
- `-TargetPortalAddress ${Env:iscsi_tp_address} -TargetPortalPortNumber ${Env:iscsi_tp_port} ` +
- ` | Get-IscsiTarget | Where-Object { $_.NodeAddress -eq ${Env:iscsi_target_iqn} }) ` +
- `-Confirm:$false`)
-
- out, err := utils.RunPowershellCmd(cmdLine, fmt.Sprintf("iscsi_tp_address=%s", portal.Address),
- fmt.Sprintf("iscsi_tp_port=%d", portal.Port),
- fmt.Sprintf("iscsi_target_iqn=%s", iqn))
- if err != nil {
- return fmt.Errorf("error disconnecting from target portal. cmd %s, output: %s, err: %w", cmdLine, string(out), err)
- }
-
- return nil
+ return cim.WithCOMThread(func() error {
+ target, err := cim.QueryISCSITarget(portal.Address, portal.Port, iqn)
+ if err != nil {
+ return err
+ }
+
+ connected, err := cim.IsISCSITargetConnected(target)
+ if err != nil {
+ return fmt.Errorf("error query connected of target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err)
+ }
+
+ if !connected {
+ klog.V(2).Infof("target %s from target portal at (%s:%d) is not connected.", iqn, portal.Address, portal.Port)
+ return nil
+ }
+
+ // get session
+ session, err := cim.QueryISCSISessionByTarget(target)
+ if err != nil {
+ return fmt.Errorf("error query session of target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err)
+ }
+
+ sessionIdentifier, err := cim.GetISCSISessionIdentifier(session)
+ if err != nil {
+ return fmt.Errorf("error query session identifier of target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err)
+ }
+
+ persistent, err := cim.IsISCSISessionPersistent(session)
+ if err != nil {
+ return fmt.Errorf("error query session persistency of target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err)
+ }
+
+ if persistent {
+ result, err := cim.UnregisterISCSISession(session)
+ if err != nil {
+ return fmt.Errorf("error unregister session on target %s from target portal at (%s:%d). result: %d, err: %w", iqn, portal.Address, portal.Port, result, err)
+ }
+ }
+
+ result, err := cim.DisconnectISCSITarget(target, sessionIdentifier)
+ if err != nil {
+ return fmt.Errorf("error disconnecting target %s from target portal at (%s:%d). result: %d, err: %w", iqn, portal.Address, portal.Port, result, err)
+ }
+
+ return nil
+ })
}
func (APIImplementor) GetTargetDisks(portal *TargetPortal, iqn string) ([]string, error) {
- // Converting DiskNumber to string for compatibility with disk api group
- // Not using pipeline in order to validate that items are non-empty
- cmdLine := fmt.Sprintf(
- `$ErrorActionPreference = "Stop"; ` +
- `$tp = Get-IscsiTargetPortal -TargetPortalAddress ${Env:iscsi_tp_address} -TargetPortalPortNumber ${Env:iscsi_tp_port}; ` +
- `$t = $tp | Get-IscsiTarget | Where-Object { $_.NodeAddress -eq ${Env:iscsi_target_iqn} }; ` +
- `$c = Get-IscsiConnection -IscsiTarget $t; ` +
- `$ids = $c | Get-Disk | Select -ExpandProperty Number | Out-String -Stream; ` +
- `ConvertTo-Json -InputObject @($ids)`)
-
- out, err := utils.RunPowershellCmd(cmdLine, fmt.Sprintf("iscsi_tp_address=%s", portal.Address),
- fmt.Sprintf("iscsi_tp_port=%d", portal.Port),
- fmt.Sprintf("iscsi_target_iqn=%s", iqn))
- if err != nil {
- return nil, fmt.Errorf("error getting target disks. cmd %s, output: %s, err: %w", cmdLine, string(out), err)
- }
-
var ids []string
- err = json.Unmarshal(out, &ids)
- if err != nil {
- return nil, fmt.Errorf("error parsing iqn target disks. cmd: %s output: %s, err: %w", cmdLine, string(out), err)
- }
-
- return ids, nil
+ err := cim.WithCOMThread(func() error {
+ target, err := cim.QueryISCSITarget(portal.Address, portal.Port, iqn)
+ if err != nil {
+ return err
+ }
+
+ connected, err := cim.IsISCSITargetConnected(target)
+ if err != nil {
+ return fmt.Errorf("error query connected of target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err)
+ }
+
+ if !connected {
+ klog.V(2).Infof("target %s from target portal at (%s:%d) is not connected.", iqn, portal.Address, portal.Port)
+ return nil
+ }
+
+ disks, err := cim.ListDisksByTarget(target)
+ if err != nil {
+ return fmt.Errorf("error getting target disks on target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err)
+ }
+
+ for _, disk := range disks {
+ number, err := cim.GetDiskNumber(disk)
+ if err != nil {
+ return fmt.Errorf("error getting number of disk %v on target %s from target portal at (%s:%d). err: %w", disk, iqn, portal.Address, portal.Port, err)
+ }
+
+ ids = append(ids, strconv.Itoa(int(number)))
+ }
+ return nil
+ })
+ return ids, err
}
func (APIImplementor) SetMutualChapSecret(mutualChapSecret string) error {
- cmdLine := `Set-IscsiChapSecret -ChapSecret ${Env:iscsi_mutual_chap_secret}`
- out, err := utils.RunPowershellCmd(cmdLine, fmt.Sprintf("iscsi_mutual_chap_secret=%s", mutualChapSecret))
- if err != nil {
- return fmt.Errorf("error setting mutual chap secret. cmd %s,"+
- " output: %s, err: %v", cmdLine, string(out), err)
- }
-
- return nil
+ return cim.WithCOMThread(func() error {
+ result, err := cim.SetISCSISessionChapSecret(mutualChapSecret)
+ if err != nil {
+ return fmt.Errorf("error setting mutual chap secret. result: %d, err: %v", result, err)
+ }
+
+ return nil
+ })
}
diff --git a/pkg/os/smb/api.go b/pkg/os/smb/api.go
index 20b9544e..9e60c9dd 100644
--- a/pkg/os/smb/api.go
+++ b/pkg/os/smb/api.go
@@ -3,15 +3,9 @@ package smb
import (
"fmt"
"strings"
- "syscall"
"github.com/kubernetes-csi/csi-proxy/pkg/cim"
"github.com/kubernetes-csi/csi-proxy/pkg/utils"
- "golang.org/x/sys/windows"
-)
-
-const (
- credentialDelimiter = ":"
)
type API interface {
@@ -33,61 +27,28 @@ func New(requirePrivacy bool) *SmbAPI {
}
}
-func remotePathForQuery(remotePath string) string {
- return strings.ReplaceAll(remotePath, "\\", "\\\\")
-}
-
-func escapeUserName(userName string) string {
- // refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L169-L170
- escaped := strings.ReplaceAll(userName, "\\", "\\\\")
- escaped = strings.ReplaceAll(escaped, credentialDelimiter, "\\"+credentialDelimiter)
- return escaped
-}
-
-func createSymlink(link, target string, isDir bool) error {
- linkPtr, err := syscall.UTF16PtrFromString(link)
- if err != nil {
- return err
- }
- targetPtr, err := syscall.UTF16PtrFromString(target)
- if err != nil {
- return err
- }
-
- var flags uint32
- if isDir {
- flags = windows.SYMBOLIC_LINK_FLAG_DIRECTORY
- }
-
- err = windows.CreateSymbolicLink(
- linkPtr,
- targetPtr,
- flags,
- )
- return err
-}
-
func (*SmbAPI) IsSmbMapped(remotePath string) (bool, error) {
- inst, err := cim.QuerySmbGlobalMappingByRemotePath(remotePathForQuery(remotePath))
- if err != nil {
- return false, cim.IgnoreNotFound(err)
- }
-
- status, err := inst.GetProperty("Status")
- if err != nil {
- return false, err
- }
-
- return status.(int32) == cim.SmbMappingStatusOK, nil
+ var isMapped bool
+ err := cim.WithCOMThread(func() error {
+ inst, err := cim.QuerySmbGlobalMappingByRemotePath(remotePath)
+ if err != nil {
+ return err
+ }
+
+ status, err := cim.GetSmbGlobalMappingStatus(inst)
+ if err != nil {
+ return err
+ }
+
+ isMapped = status == cim.SmbMappingStatusOK
+ return nil
+ })
+ return isMapped, cim.IgnoreNotFound(err)
}
// NewSmbLink - creates a directory symbolic link to the remote share.
// The os.Symlink was having issue for cases where the destination was an SMB share - the container
-// runtime would complain stating "Access Denied". Because of this, we had to perform
-// this operation with powershell commandlet creating an directory softlink.
-// Since os.Symlink is currently being used in working code paths, no attempt is made in
-// alpha to merge the paths.
-// TODO (for beta release): Merge the link paths - os.Symlink and Powershell link path.
+// runtime would complain stating "Access Denied".
func (*SmbAPI) NewSmbLink(remotePath, localPath string) error {
if !strings.HasSuffix(remotePath, "\\") {
// Golang has issues resolving paths mapped to file shares if they do not end in a trailing \
@@ -97,7 +58,7 @@ func (*SmbAPI) NewSmbLink(remotePath, localPath string) error {
longRemotePath := utils.EnsureLongPath(remotePath)
longLocalPath := utils.EnsureLongPath(localPath)
- err := createSymlink(longLocalPath, longRemotePath, true)
+ err := utils.CreateSymlink(longLocalPath, longRemotePath, true)
if err != nil {
return fmt.Errorf("error linking %s to %s. err: %v", remotePath, localPath, err)
}
@@ -106,29 +67,21 @@ func (*SmbAPI) NewSmbLink(remotePath, localPath string) error {
}
func (api *SmbAPI) NewSmbGlobalMapping(remotePath, username, password string) error {
- params := map[string]interface{}{
- "RemotePath": remotePath,
- "RequirePrivacy": api.RequirePrivacy,
- }
- if username != "" {
- // refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L166-L178
- // on how SMB credential is handled in PowerShell
- params["Credential"] = escapeUserName(username) + credentialDelimiter + password
- }
-
- result, _, err := cim.InvokeCimMethod(cim.WMINamespaceSmb, "MSFT_SmbGlobalMapping", "Create", params)
- if err != nil {
- return fmt.Errorf("NewSmbGlobalMapping failed. result: %d, err: %v", result, err)
- }
-
- return nil
+ return cim.WithCOMThread(func() error {
+ result, err := cim.NewSmbGlobalMapping(remotePath, username, password, api.RequirePrivacy)
+ if err != nil {
+ return fmt.Errorf("NewSmbGlobalMapping failed. result: %d, err: %v", result, err)
+ }
+ return nil
+ })
}
func (*SmbAPI) RemoveSmbGlobalMapping(remotePath string) error {
- err := cim.RemoveSmbGlobalMappingByRemotePath(remotePathForQuery(remotePath))
- if err != nil {
- return fmt.Errorf("error remove smb mapping '%s'. err: %v", remotePath, err)
- }
-
- return nil
+ return cim.WithCOMThread(func() error {
+ err := cim.RemoveSmbGlobalMappingByRemotePath(remotePath)
+ if err != nil {
+ return fmt.Errorf("error remove smb mapping '%s'. err: %v", remotePath, err)
+ }
+ return nil
+ })
}
diff --git a/pkg/os/system/api.go b/pkg/os/system/api.go
index a09c83a4..128704ba 100644
--- a/pkg/os/system/api.go
+++ b/pkg/os/system/api.go
@@ -2,10 +2,12 @@ package system
import (
"fmt"
+ "time"
"github.com/kubernetes-csi/csi-proxy/pkg/cim"
"github.com/kubernetes-csi/csi-proxy/pkg/server/system/impl"
- "github.com/kubernetes-csi/csi-proxy/pkg/utils"
+ "github.com/pkg/errors"
+ "k8s.io/klog/v2"
)
// Implements the System OS API calls. All code here should be very simple
@@ -24,6 +26,29 @@ type ServiceInfo struct {
Status uint32 `json:"Status"`
}
+type stateCheckFunc func(cim.ServiceInterface, string) (bool, string, error)
+type stateTransitionFunc func(cim.ServiceInterface) error
+
+const (
+ // startServiceErrorCodeAccepted indicates the request is accepted
+ startServiceErrorCodeAccepted = 0
+
+ // startServiceErrorCodeAlreadyRunning indicates a service is already running
+ startServiceErrorCodeAlreadyRunning = 10
+
+ // stopServiceErrorCodeAccepted indicates the request is accepted
+ stopServiceErrorCodeAccepted = 0
+
+ // stopServiceErrorCodeStopPending indicates the request cannot be sent to the service because the state of the service is 0,1,2 (pending)
+ stopServiceErrorCodeStopPending = 5
+
+ // stopServiceErrorCodeDependentRunning indicates a service cannot be stopped as its dependents may still be running
+ stopServiceErrorCodeDependentRunning = 3
+
+ serviceStateRunning = "Running"
+ serviceStateStopped = "Stopped"
+)
+
var (
startModeMappings = map[string]uint32{
"Boot": impl.START_TYPE_BOOT,
@@ -33,16 +58,20 @@ var (
"Disabled": impl.START_TYPE_DISABLED,
}
- statusMappings = map[string]uint32{
- "Unknown": impl.SERVICE_STATUS_UNKNOWN,
- "Stopped": impl.SERVICE_STATUS_STOPPED,
- "Start Pending": impl.SERVICE_STATUS_START_PENDING,
- "Stop Pending": impl.SERVICE_STATUS_STOP_PENDING,
- "Running": impl.SERVICE_STATUS_RUNNING,
- "Continue Pending": impl.SERVICE_STATUS_CONTINUE_PENDING,
- "Pause Pending": impl.SERVICE_STATUS_PAUSE_PENDING,
- "Paused": impl.SERVICE_STATUS_PAUSED,
+ stateMappings = map[string]uint32{
+ "Unknown": impl.SERVICE_STATUS_UNKNOWN,
+ serviceStateStopped: impl.SERVICE_STATUS_STOPPED,
+ "Start Pending": impl.SERVICE_STATUS_START_PENDING,
+ "Stop Pending": impl.SERVICE_STATUS_STOP_PENDING,
+ serviceStateRunning: impl.SERVICE_STATUS_RUNNING,
+ "Continue Pending": impl.SERVICE_STATUS_CONTINUE_PENDING,
+ "Pause Pending": impl.SERVICE_STATUS_PAUSE_PENDING,
+ "Paused": impl.SERVICE_STATUS_PAUSED,
}
+
+ serviceStateCheckInternal = 200 * time.Millisecond
+ serviceStateCheckTimeout = 30 * time.Second
+ errTimedOut = errors.New("Timed out")
)
func serviceStartModeToStartType(startMode string) uint32 {
@@ -50,75 +79,277 @@ func serviceStartModeToStartType(startMode string) uint32 {
}
func serviceState(status string) uint32 {
- return statusMappings[status]
+ return stateMappings[status]
+}
+
+type ServiceManager interface {
+ WaitUntilServiceState(cim.ServiceInterface, stateTransitionFunc, stateCheckFunc, time.Duration, time.Duration) (string, error)
+ GetDependentsForService(string) ([]string, error)
}
-type APIImplementor struct{}
+type ServiceFactory interface {
+ GetService(name string) (cim.ServiceInterface, error)
+}
+
+type APIImplementor struct {
+ serviceFactory ServiceFactory
+ serviceManager ServiceManager
+}
func New() APIImplementor {
- return APIImplementor{}
+ serviceFactory := cim.Win32ServiceFactory{}
+ return APIImplementor{
+ serviceFactory: serviceFactory,
+ serviceManager: ServiceManagerImpl{
+ serviceFactory: serviceFactory,
+ },
+ }
}
func (APIImplementor) GetBIOSSerialNumber() (string, error) {
- bios, err := cim.QueryBIOSElement([]string{"SerialNumber"})
- if err != nil {
- return "", fmt.Errorf("failed to get BIOS element: %w", err)
+ var sn string
+ err := cim.WithCOMThread(func() error {
+ bios, err := cim.QueryBIOSElement(cim.BIOSSelectorList)
+ if err != nil {
+ return fmt.Errorf("failed to get BIOS element: %w", err)
+ }
+
+ sn, err = cim.GetBIOSSerialNumber(bios)
+ if err != nil {
+ return fmt.Errorf("failed to get BIOS serial number property: %w", err)
+ }
+
+ return nil
+ })
+ return sn, err
+}
+
+func (impl APIImplementor) GetService(name string) (*ServiceInfo, error) {
+ var serviceInfo *ServiceInfo
+ err := cim.WithCOMThread(func() error {
+ service, err := impl.serviceFactory.GetService(name)
+ if err != nil {
+ return fmt.Errorf("failed to get service %s: %w", name, err)
+ }
+
+ displayName, err := cim.GetServiceDisplayName(service)
+ if err != nil {
+ return fmt.Errorf("failed to get displayName property of service %s: %w", name, err)
+ }
+
+ state, err := cim.GetServiceState(service)
+ if err != nil {
+ return fmt.Errorf("failed to get state property of service %s: %w", name, err)
+ }
+
+ startMode, err := cim.GetServiceStartMode(service)
+ if err != nil {
+ return fmt.Errorf("failed to get startMode property of service %s: %w", name, err)
+ }
+
+ serviceInfo = &ServiceInfo{
+ DisplayName: displayName,
+ StartType: serviceStartModeToStartType(startMode),
+ Status: serviceState(state),
+ }
+ return nil
+ })
+ return serviceInfo, err
+}
+
+func (impl APIImplementor) StartService(name string) error {
+ startService := func(service cim.ServiceInterface) error {
+ retVal, err := service.StartService()
+ if err != nil || (retVal != startServiceErrorCodeAccepted && retVal != startServiceErrorCodeAlreadyRunning) {
+ return fmt.Errorf("error starting service name %s. return value: %d, error: %v", name, retVal, err)
+ }
+ return nil
}
+ serviceRunningCheck := func(service cim.ServiceInterface, state string) (bool, string, error) {
+ err := service.Refresh()
+ if err != nil {
+ return false, "", err
+ }
- sn, err := bios.GetPropertySerialNumber()
- if err != nil {
- return "", fmt.Errorf("failed to get BIOS serial number property: %w", err)
+ newState, err := cim.GetServiceState(service)
+ if err != nil {
+ return false, state, err
+ }
+
+ klog.V(6).Infof("service (%v) state check: %s => %s", service, state, newState)
+ return state == serviceStateRunning, newState, err
}
- return sn, nil
+ return cim.WithCOMThread(func() error {
+ service, err := impl.serviceFactory.GetService(name)
+ if err != nil {
+ return err
+ }
+
+ state, err := impl.serviceManager.WaitUntilServiceState(service, startService, serviceRunningCheck, serviceStateCheckInternal, serviceStateCheckTimeout)
+ if err != nil && !errors.Is(err, errTimedOut) {
+ return err
+ }
+
+ if state != serviceStateRunning {
+ return fmt.Errorf("timed out waiting for service %s to become running", name)
+ }
+
+ return nil
+ })
}
-func (APIImplementor) GetService(name string) (*ServiceInfo, error) {
- service, err := cim.QueryServiceByName(name, []string{"DisplayName", "State", "StartMode"})
- if err != nil {
- return nil, fmt.Errorf("failed to get service %s: %w", name, err)
+func (impl APIImplementor) stopSingleService(name string) (bool, error) {
+ var dependentRunning bool
+ stopService := func(service cim.ServiceInterface) error {
+ retVal, err := service.StopService()
+ if err != nil || (retVal != stopServiceErrorCodeAccepted && retVal != stopServiceErrorCodeStopPending) {
+ if retVal == stopServiceErrorCodeDependentRunning {
+ dependentRunning = true
+ return fmt.Errorf("error stopping service %s as dependent services are not stopped", name)
+ }
+ return fmt.Errorf("error stopping service %s. return value: %d, error: %v", name, retVal, err)
+ }
+ return nil
}
+ serviceStoppedCheck := func(service cim.ServiceInterface, state string) (bool, string, error) {
+ err := service.Refresh()
+ if err != nil {
+ return false, "", err
+ }
- displayName, err := service.GetPropertyDisplayName()
- if err != nil {
- return nil, fmt.Errorf("failed to get displayName property of service %s: %w", name, err)
+ newState, err := cim.GetServiceState(service)
+ if err != nil {
+ return false, state, err
+ }
+
+ klog.V(6).Infof("service (%v) state check: %s => %s", service, state, newState)
+ return newState == serviceStateStopped, newState, err
}
- state, err := service.GetPropertyState()
+ service, err := impl.serviceFactory.GetService(name)
if err != nil {
- return nil, fmt.Errorf("failed to get state property of service %s: %w", name, err)
+ return dependentRunning, err
}
- startMode, err := service.GetPropertyStartMode()
- if err != nil {
- return nil, fmt.Errorf("failed to get startMode property of service %s: %w", name, err)
+ state, err := impl.serviceManager.WaitUntilServiceState(service, stopService, serviceStoppedCheck, serviceStateCheckInternal, serviceStateCheckTimeout)
+ if err != nil && !errors.Is(err, errTimedOut) {
+ return dependentRunning, fmt.Errorf("error stopping service name %s. current state: %s", name, state)
+ }
+
+ if state != serviceStateStopped {
+ return dependentRunning, fmt.Errorf("timed out waiting for service %s to stop", name)
}
- return &ServiceInfo{
- DisplayName: displayName,
- StartType: serviceStartModeToStartType(startMode),
- Status: serviceState(state),
- }, nil
+ return dependentRunning, nil
}
-func (APIImplementor) StartService(name string) error {
- // Note: both StartService and StopService are not implemented by WMI
- script := `Start-Service -Name $env:ServiceName`
- cmdEnv := fmt.Sprintf("ServiceName=%s", name)
- out, err := utils.RunPowershellCmd(script, cmdEnv)
+func (impl APIImplementor) StopService(name string, force bool) error {
+ return cim.WithCOMThread(func() error {
+ dependentRunning, err := impl.stopSingleService(name)
+ if err == nil || !dependentRunning || !force {
+ return err
+ }
+
+ serviceNames, err := impl.serviceManager.GetDependentsForService(name)
+ if err != nil {
+ return fmt.Errorf("error getting dependent services for service name %s", name)
+ }
+
+ for _, serviceName := range serviceNames {
+ _, err = impl.stopSingleService(serviceName)
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+ })
+}
+
+type ServiceManagerImpl struct {
+ serviceFactory ServiceFactory
+}
+
+func (impl ServiceManagerImpl) WaitUntilServiceState(service cim.ServiceInterface, stateTransition stateTransitionFunc, stateCheck stateCheckFunc, interval time.Duration, timeout time.Duration) (string, error) {
+ done, state, err := stateCheck(service, "")
if err != nil {
- return fmt.Errorf("error starting service name=%s. cmd: %s, output: %s, error: %v", name, script, string(out), err)
+ return state, err
+ }
+ if done {
+ return state, err
}
- return nil
+ // Perform transition if not already in desired state
+ if err := stateTransition(service); err != nil {
+ return state, err
+ }
+
+ ticker := time.NewTicker(interval)
+ defer ticker.Stop()
+
+ timeoutChan := time.After(timeout)
+
+ for {
+ select {
+ case <-ticker.C:
+ klog.V(6).Infof("Checking service (%v) state...", service)
+ done, state, err = stateCheck(service, state)
+ if err != nil {
+ return state, fmt.Errorf("check failed: %w", err)
+ }
+ if done {
+ klog.V(6).Infof("service (%v) state is %s and transition done.", service, state)
+ return state, nil
+ }
+ case <-timeoutChan:
+ done, state, err = stateCheck(service, state)
+ return state, errTimedOut
+ }
+ }
}
-func (APIImplementor) StopService(name string, force bool) error {
- script := `Stop-Service -Name $env:ServiceName -Force:$([System.Convert]::ToBoolean($env:Force))`
- out, err := utils.RunPowershellCmd(script, fmt.Sprintf("ServiceName=%s", name), fmt.Sprintf("Force=%t", force))
+func (impl ServiceManagerImpl) GetDependentsForService(name string) ([]string, error) {
+ var serviceNames []string
+ var servicesToCheck []cim.ServiceInterface
+ servicesByName := map[string]string{}
+
+ service, err := impl.serviceFactory.GetService(name)
if err != nil {
- return fmt.Errorf("error stopping service name=%s. cmd: %s, output: %s, error: %v", name, script, string(out), err)
+ return serviceNames, err
+ }
+
+ servicesToCheck = append(servicesToCheck, service)
+ i := 0
+ for i < len(servicesToCheck) {
+ service = servicesToCheck[i]
+ i += 1
+
+ serviceName, err := cim.GetServiceName(service)
+ if err != nil {
+ return serviceNames, err
+ }
+
+ currentState, err := cim.GetServiceState(service)
+ if err != nil {
+ return serviceNames, err
+ }
+
+ if currentState != serviceStateRunning {
+ continue
+ }
+
+ servicesByName[serviceName] = serviceName
+ // prepend the current service to the front
+ serviceNames = append([]string{serviceName}, serviceNames...)
+
+ dependents, err := service.GetDependents()
+ if err != nil {
+ return serviceNames, err
+ }
+
+ servicesToCheck = append(servicesToCheck, dependents...)
}
- return nil
+ return serviceNames, nil
}
diff --git a/pkg/os/system/api_test.go b/pkg/os/system/api_test.go
new file mode 100644
index 00000000..b977c74c
--- /dev/null
+++ b/pkg/os/system/api_test.go
@@ -0,0 +1,214 @@
+package system
+
+import (
+ "fmt"
+ "testing"
+ "time"
+
+ "github.com/kubernetes-csi/csi-proxy/pkg/cim"
+ "github.com/pkg/errors"
+)
+
+type MockService struct {
+ Name string
+ DisplayName string
+ State string
+ StartMode string
+ Dependents []cim.ServiceInterface
+
+ StartResult uint32
+ StopResult uint32
+
+ Err error
+}
+
+func (m *MockService) GetPropertyName() (string, error) {
+ return m.Name, m.Err
+}
+
+func (m *MockService) GetPropertyDisplayName() (string, error) {
+ return m.DisplayName, m.Err
+}
+
+func (m *MockService) GetPropertyState() (string, error) {
+ return m.State, m.Err
+}
+
+func (m *MockService) GetPropertyStartMode() (string, error) {
+ return m.StartMode, m.Err
+}
+
+func (m *MockService) GetDependents() ([]cim.ServiceInterface, error) {
+ return m.Dependents, m.Err
+}
+
+func (m *MockService) StartService() (uint32, error) {
+ m.State = "Running"
+ return m.StartResult, m.Err
+}
+
+func (m *MockService) StopService() (uint32, error) {
+ m.State = "Stopped"
+ return m.StopResult, m.Err
+}
+
+func (m *MockService) Refresh() error {
+ return nil
+}
+
+type MockServiceFactory struct {
+ Services map[string]cim.ServiceInterface
+ Err error
+}
+
+func (f *MockServiceFactory) GetService(name string) (cim.ServiceInterface, error) {
+ svc, ok := f.Services[name]
+ if !ok {
+ return nil, fmt.Errorf("service not found: %s", name)
+ }
+ return svc, f.Err
+}
+
+func TestWaitUntilServiceState_Success(t *testing.T) {
+ svc := &MockService{Name: "svc", State: "Stopped"}
+
+ stateChanged := false
+
+ stateCheck := func(s cim.ServiceInterface, state string) (bool, string, error) {
+ if stateChanged {
+ svc.State = serviceStateRunning
+ return true, svc.State, nil
+ }
+ return false, svc.State, nil
+ }
+
+ stateTransition := func(s cim.ServiceInterface) error {
+ stateChanged = true
+ return nil
+ }
+
+ impl := ServiceManagerImpl{}
+ state, err := impl.WaitUntilServiceState(svc, stateTransition, stateCheck, 10*time.Millisecond, 500*time.Millisecond)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if state != serviceStateRunning {
+ t.Fatalf("expected state %q, got %q", serviceStateRunning, state)
+ }
+}
+
+func TestWaitUntilServiceState_Timeout(t *testing.T) {
+ svc := &MockService{Name: "svc", State: "Stopped"}
+
+ stateCheck := func(s cim.ServiceInterface, state string) (bool, string, error) {
+ return false, svc.State, nil
+ }
+
+ stateTransition := func(s cim.ServiceInterface) error {
+ return nil
+ }
+
+ impl := ServiceManagerImpl{}
+ state, err := impl.WaitUntilServiceState(svc, stateTransition, stateCheck, 10*time.Millisecond, 50*time.Millisecond)
+ if !errors.Is(err, errTimedOut) {
+ t.Fatalf("expected timeout error, got %v", err)
+ }
+ if state != svc.State {
+ t.Fatalf("expected state %q, got %q", svc.State, state)
+ }
+}
+
+func TestWaitUntilServiceState_TransitionFails(t *testing.T) {
+ svc := &MockService{Name: "svc", State: "Stopped"}
+
+ stateCheck := func(s cim.ServiceInterface, state string) (bool, string, error) {
+ return false, svc.State, nil
+ }
+
+ stateTransition := func(s cim.ServiceInterface) error {
+ return fmt.Errorf("transition failed")
+ }
+
+ impl := ServiceManagerImpl{}
+ _, err := impl.WaitUntilServiceState(svc, stateTransition, stateCheck, 10*time.Millisecond, 50*time.Millisecond)
+ if err == nil || err.Error() != "transition failed" {
+ t.Fatalf("expected transition error, got %v", err)
+ }
+}
+
+func TestGetDependentsForService(t *testing.T) {
+ // Construct the dependency tree
+ svcC := &MockService{Name: "C", State: serviceStateRunning}
+ svcB := &MockService{Name: "B", State: serviceStateRunning, Dependents: []cim.ServiceInterface{svcC}}
+ svcA := &MockService{Name: "A", State: serviceStateRunning, Dependents: []cim.ServiceInterface{svcB}}
+
+ factory := &MockServiceFactory{
+ Services: map[string]cim.ServiceInterface{
+ "A": svcA,
+ "B": svcB,
+ "C": svcC,
+ },
+ }
+
+ impl := ServiceManagerImpl{
+ serviceFactory: factory,
+ }
+
+ names, err := impl.GetDependentsForService("A")
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ expected := []string{"C", "B", "A"}
+ if len(names) != len(expected) {
+ t.Fatalf("expected %d services, got %d", len(expected), len(names))
+ }
+ for i, name := range expected {
+ if names[i] != name {
+ t.Errorf("expected %s at position %d, got %s", name, i, names[i])
+ }
+ }
+}
+
+func TestGetDependentsForService_SkipsNonRunning(t *testing.T) {
+ svcB := &MockService{Name: "B", State: "Stopped"}
+ svcA := &MockService{Name: "A", State: serviceStateRunning, Dependents: []cim.ServiceInterface{svcB}}
+
+ factory := &MockServiceFactory{
+ Services: map[string]cim.ServiceInterface{
+ "A": svcA,
+ "B": svcB,
+ },
+ }
+
+ impl := ServiceManagerImpl{
+ serviceFactory: factory,
+ }
+
+ names, err := impl.GetDependentsForService("A")
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ expected := []string{"A"} // B is skipped due to stopped state
+ if len(names) != len(expected) {
+ t.Fatalf("expected %d services, got %d", len(expected), len(names))
+ }
+}
+
+func TestGetDependenciesForService_Winmgmt(t *testing.T) {
+ impl := ServiceManagerImpl{
+ serviceFactory: cim.Win32ServiceFactory{},
+ }
+
+ serviceName := "Winmgmt"
+ names, err := impl.GetDependentsForService(serviceName)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ expected := 4
+ if len(names) != expected || names[len(names)-1] != serviceName {
+ t.Fatalf("expected %d services, got %d", expected, len(names))
+ }
+}
diff --git a/pkg/os/volume/api.go b/pkg/os/volume/api.go
index 5bdf0e04..87a10f0f 100644
--- a/pkg/os/volume/api.go
+++ b/pkg/os/volume/api.go
@@ -2,21 +2,21 @@ package volume
import (
"fmt"
- "os"
"path/filepath"
"regexp"
- "strconv"
"strings"
- "github.com/go-ole/go-ole"
"github.com/kubernetes-csi/csi-proxy/pkg/cim"
"github.com/kubernetes-csi/csi-proxy/pkg/utils"
- wmierrors "github.com/microsoft/wmi/pkg/errors"
"github.com/pkg/errors"
"golang.org/x/sys/windows"
"k8s.io/klog/v2"
)
+const (
+ minimumResizeSize = 100 * 1024 * 1024
+)
+
// API exposes the internal volume operations available in the server
type API interface {
// ListVolumesOnDisk lists volumes on a disk identified by a `diskNumber` and optionally a partition identified by `partitionNumber`.
@@ -69,57 +69,55 @@ func New() VolumeAPI {
// ListVolumesOnDisk - returns back list of volumes(volumeIDs) in a disk and a partition.
func (VolumeAPI) ListVolumesOnDisk(diskNumber uint32, partitionNumber uint32) (volumeIDs []string, err error) {
- partitions, err := cim.ListPartitionsOnDisk(diskNumber, partitionNumber, []string{"ObjectId"})
- if err != nil {
- return nil, errors.Wrapf(err, "failed to list partition on disk %d", diskNumber)
- }
-
- volumes, err := cim.ListVolumes([]string{"ObjectId", "UniqueId"})
- if err != nil {
- return nil, errors.Wrapf(err, "failed to list volumes")
- }
+ err = cim.WithCOMThread(func() error {
+ partitions, err := cim.ListPartitionsOnDisk(diskNumber, partitionNumber, cim.PartitionSelectorListObjectID)
+ if err != nil {
+ return errors.Wrapf(err, "failed to list partition on disk %d", diskNumber)
+ }
- filtered, err := cim.FindVolumesByPartition(volumes, partitions)
- if cim.IgnoreNotFound(err) != nil {
- return nil, errors.Wrapf(err, "failed to list volumes on disk %d", diskNumber)
- }
+ volumes, err := cim.FindVolumesByPartition(partitions)
+ if cim.IgnoreNotFound(err) != nil {
+ return errors.Wrapf(err, "failed to list volumes on disk %d", diskNumber)
+ }
- for _, volume := range filtered {
- uniqueID, err := volume.GetPropertyUniqueId()
- if err != nil {
- return nil, errors.Wrapf(err, "failed to list volumes")
+ for _, volume := range volumes {
+ uniqueID, err := cim.GetVolumeUniqueID(volume)
+ if err != nil {
+ return errors.Wrapf(err, "failed to get unique ID for volume %v", volume)
+ }
+ volumeIDs = append(volumeIDs, uniqueID)
}
- volumeIDs = append(volumeIDs, uniqueID)
- }
- return volumeIDs, nil
+ return nil
+ })
+ return
}
// FormatVolume - Formats a volume with the NTFS format.
func (VolumeAPI) FormatVolume(volumeID string) (err error) {
- volume, err := cim.QueryVolumeByUniqueID(volumeID, nil)
- if err != nil {
- return fmt.Errorf("error formatting volume (%s). error: %v", volumeID, err)
- }
+ return cim.WithCOMThread(func() error {
+ volume, err := cim.QueryVolumeByUniqueID(volumeID, nil)
+ if err != nil {
+ return fmt.Errorf("error formatting volume (%s). error: %v", volumeID, err)
+ }
- result, err := volume.InvokeMethodWithReturn(
- "Format",
- "NTFS", // Format,
- "", // FileSystemLabel,
- nil, // AllocationUnitSize,
- false, // Full,
- true, // Force
- nil, // Compress,
- nil, // ShortFileNameSupport,
- nil, // SetIntegrityStreams,
- nil, // UseLargeFRS,
- nil, // DisableHeatGathering,
- )
- if result != 0 || err != nil {
- return fmt.Errorf("error formatting volume (%s). result: %d, error: %v", volumeID, result, err)
- }
- // TODO: Do we need to handle anything for len(out) == 0
- return nil
+ result, err := cim.FormatVolume(volume,
+ "NTFS", // Format,
+ "", // FileSystemLabel,
+ nil, // AllocationUnitSize,
+ false, // Full,
+ true, // Force
+ nil, // Compress,
+ nil, // ShortFileNameSupport,
+ nil, // SetIntegrityStreams,
+ nil, // UseLargeFRS,
+ nil, // DisableHeatGathering,
+ )
+ if result != 0 || err != nil {
+ return fmt.Errorf("error formatting volume (%s). result: %d, error: %v", volumeID, result, err)
+ }
+ return nil
+ })
}
// WriteVolumeCache - Writes the file system cache to disk with the given volume id
@@ -129,18 +127,22 @@ func (VolumeAPI) WriteVolumeCache(volumeID string) (err error) {
// IsVolumeFormatted - Check if the volume is formatted with the pre specified filesystem(typically ntfs).
func (VolumeAPI) IsVolumeFormatted(volumeID string) (bool, error) {
- volume, err := cim.QueryVolumeByUniqueID(volumeID, []string{"FileSystemType"})
- if err != nil {
- return false, fmt.Errorf("error checking if volume (%s) is formatted. error: %v", volumeID, err)
- }
+ var formatted bool
+ err := cim.WithCOMThread(func() error {
+ volume, err := cim.QueryVolumeByUniqueID(volumeID, cim.VolumeSelectorListForFileSystemType)
+ if err != nil {
+ return fmt.Errorf("error checking if volume (%s) is formatted. error: %v", volumeID, err)
+ }
- fsType, err := volume.GetProperty("FileSystemType")
- if err != nil {
- return false, fmt.Errorf("failed to query volume file system type (%s): %w", volumeID, err)
- }
+ fsType, err := cim.GetVolumeFileSystemType(volume)
+ if err != nil {
+ return fmt.Errorf("failed to query volume file system type (%s): %w", volumeID, err)
+ }
- const FileSystemUnknown = 0
- return fsType.(int32) != FileSystemUnknown, nil
+ formatted = fsType != cim.FileSystemUnknown
+ return nil
+ })
+ return formatted, err
}
// MountVolume - mounts a volume to a path. This is done using Win32 API SetVolumeMountPoint for presenting the volume via a path.
@@ -190,124 +192,112 @@ func (VolumeAPI) UnmountVolume(volumeID, path string) error {
// ResizeVolume - resizes a volume with the given size, if size == 0 then max supported size is used
func (VolumeAPI) ResizeVolume(volumeID string, size int64) error {
- var err error
- var finalSize int64
- part, err := cim.GetPartitionByVolumeUniqueID(volumeID, nil)
- if err != nil {
- return err
- }
-
- // If size is 0 then we will resize to the maximum size possible, otherwise just resize to size
- if size == 0 {
- var sizeMin, sizeMax ole.VARIANT
- var status string
- result, err := part.InvokeMethodWithReturn("GetSupportedSize", &sizeMin, &sizeMax, &status)
- if result != 0 || err != nil {
- return fmt.Errorf("error getting sizeMin, sizeMax from volume(%s). result: %d, status: %s, error: %v", volumeID, result, status, err)
- }
- klog.V(5).Infof("got sizeMin(%v) sizeMax(%v) from volume(%s), status: %s", sizeMin, sizeMax, volumeID, status)
-
- finalSizeStr := sizeMax.ToString()
- finalSize, err = strconv.ParseInt(finalSizeStr, 10, 64)
+ return cim.WithCOMThread(func() error {
+ var err error
+ var finalSize int64
+ part, err := cim.GetPartitionByVolumeUniqueID(volumeID)
if err != nil {
- return fmt.Errorf("error parsing the sizeMax of volume (%s) with error (%v)", volumeID, err)
+ return err
}
- } else {
- finalSize = size
- }
- currentSizeVal, err := part.GetProperty("Size")
- if err != nil {
- return fmt.Errorf("error getting the current size of volume (%s) with error (%v)", volumeID, err)
- }
+ // If size is 0 then we will resize to the maximum size possible, otherwise just resize to size
+ if size == 0 {
+ var result int
+ var status string
+ result, _, finalSize, status, err = cim.GetPartitionSupportedSize(part)
+ if result != 0 || err != nil {
+ return fmt.Errorf("error getting sizeMin, sizeMax from volume (%s). result: %d, status: %s, error: %v", volumeID, result, status, err)
+ }
- currentSize, err := strconv.ParseInt(currentSizeVal.(string), 10, 64)
- if err != nil {
- return fmt.Errorf("error parsing the current size of volume (%s) with error (%v)", volumeID, err)
- }
+ } else {
+ finalSize = size
+ }
- // only resize if finalSize - currentSize is greater than 100MB
- if finalSize-currentSize < 100*1024*1024 {
- klog.V(2).Infof("minimum resize difference(1GB) not met, skipping resize. volumeID=%s currentSize=%d finalSize=%d", volumeID, currentSize, finalSize)
- return nil
- }
+ currentSize, err := cim.GetPartitionSize(part)
+ if err != nil {
+ return fmt.Errorf("error getting the current size of volume (%s) with error (%v)", volumeID, err)
+ }
- //if the partition's size is already the size we want this is a noop, just return
- if currentSize >= finalSize {
- klog.V(2).Infof("Attempted to resize volume (%s) to a lower size, from currentBytes=%d wantedBytes=%d", volumeID, currentSize, finalSize)
- return nil
- }
+ // only resize if finalSize - currentSize is greater than 100MB
+ if finalSize-currentSize < minimumResizeSize {
+ klog.V(2).Infof("minimum resize difference (100MB) not met, skipping resize. volumeID=%s currentSize=%d finalSize=%d", volumeID, currentSize, finalSize)
+ return nil
+ }
- var status string
- result, err := part.InvokeMethodWithReturn("Resize", strconv.Itoa(int(finalSize)), &status)
+ //if the partition's size is already the size we want this is a noop, just return
+ if currentSize >= finalSize {
+ klog.V(2).Infof("Attempted to resize volume (%s) to a lower size, from currentBytes=%d wantedBytes=%d", volumeID, currentSize, finalSize)
+ return nil
+ }
- if result != 0 || err != nil {
- return fmt.Errorf("error resizing volume (%s). size:%v, finalSize %v, error: %v", volumeID, size, finalSize, err)
- }
+ result, _, err := cim.ResizePartition(part, finalSize)
+ if result != 0 || err != nil {
+ return fmt.Errorf("error resizing volume (%s). size:%v, finalSize %v, error: %v", volumeID, size, finalSize, err)
+ }
- diskNumber, err := cim.GetPartitionDiskNumber(part)
- if err != nil {
- return fmt.Errorf("error parsing disk number of volume (%s). error: %v", volumeID, err)
- }
+ diskNumber, err := cim.GetPartitionDiskNumber(part)
+ if err != nil {
+ return fmt.Errorf("error parsing disk number of volume (%s). error: %v", volumeID, err)
+ }
- disk, err := cim.QueryDiskByNumber(diskNumber, nil)
- if err != nil {
- return fmt.Errorf("error parsing disk number of volume (%s). error: %v", volumeID, err)
- }
+ disk, err := cim.QueryDiskByNumber(diskNumber, nil)
+ if err != nil {
+ return fmt.Errorf("error query disk of volume (%s). error: %v", volumeID, err)
+ }
- result, err = disk.InvokeMethodWithReturn("Refresh", &status)
- if result != 0 || err != nil {
- return fmt.Errorf("error rescan disk (%d). result %d, error: %v", diskNumber, result, err)
- }
+ result, _, err = cim.RefreshDisk(disk)
+ if result != 0 || err != nil {
+ return fmt.Errorf("error rescan disk (%d). result %d, error: %v", diskNumber, result, err)
+ }
- return nil
+ return nil
+ })
}
// GetVolumeStats - retrieves the volume stats for a given volume
-func (VolumeAPI) GetVolumeStats(volumeID string) (int64, int64, error) {
- volume, err := cim.QueryVolumeByUniqueID(volumeID, []string{"UniqueId", "SizeRemaining", "Size"})
- if err != nil {
- return -1, -1, fmt.Errorf("error getting capacity and used size of volume (%s). error: %v", volumeID, err)
- }
-
- volumeSizeVal, err := volume.GetProperty("Size")
- if err != nil {
- return -1, -1, fmt.Errorf("failed to query volume size (%s): %w", volumeID, err)
- }
-
- volumeSize, err := strconv.ParseInt(volumeSizeVal.(string), 10, 64)
- if err != nil {
- return -1, -1, fmt.Errorf("failed to parse volume size (%s): %w", volumeID, err)
- }
+func (VolumeAPI) GetVolumeStats(volumeID string) (volumeSize, volumeUsedSize int64, err error) {
+ volumeSize = -1
+ volumeUsedSize = -1
+ err = cim.WithCOMThread(func() error {
+ volume, err := cim.QueryVolumeByUniqueID(volumeID, cim.VolumeSelectorListForStats)
+ if err != nil {
+ return fmt.Errorf("error getting capacity and used size of volume (%s). error: %v", volumeID, err)
+ }
- volumeSizeRemainingVal, err := volume.GetProperty("SizeRemaining")
- if err != nil {
- return -1, -1, fmt.Errorf("failed to query volume remaining size (%s): %w", volumeID, err)
- }
+ volumeSize, err = cim.GetVolumeSize(volume)
+ if err != nil {
+ return fmt.Errorf("failed to query volume size (%s): %w", volumeID, err)
+ }
- volumeSizeRemaining, err := strconv.ParseInt(volumeSizeRemainingVal.(string), 10, 64)
- if err != nil {
- return -1, -1, fmt.Errorf("failed to parse volume remaining size (%s): %w", volumeID, err)
- }
+ volumeSizeRemaining, err := cim.GetVolumeSizeRemaining(volume)
+ if err != nil {
+ return fmt.Errorf("failed to query volume remaining size (%s): %w", volumeID, err)
+ }
- volumeUsedSize := volumeSize - volumeSizeRemaining
- return volumeSize, volumeUsedSize, nil
+ volumeUsedSize = volumeSize - volumeSizeRemaining
+ return nil
+ })
+ return
}
// GetDiskNumberFromVolumeID - gets the disk number where the volume is.
func (VolumeAPI) GetDiskNumberFromVolumeID(volumeID string) (uint32, error) {
- // get the size and sizeRemaining for the volume
- part, err := cim.GetPartitionByVolumeUniqueID(volumeID, []string{"DiskNumber"})
- if err != nil {
- return 0, err
- }
+ var diskNumber uint32
+ err := cim.WithCOMThread(func() error {
+ // get the size and sizeRemaining for the volume
+ part, err := cim.GetPartitionByVolumeUniqueID(volumeID)
+ if err != nil {
+ return err
+ }
- diskNumber, err := part.GetProperty("DiskNumber")
- if err != nil {
- return 0, fmt.Errorf("error query disk number of volume (%s). error: %v", volumeID, err)
- }
+ diskNumber, err = cim.GetPartitionDiskNumber(part)
+ if err != nil {
+ return fmt.Errorf("error query disk number of volume (%s). error: %v", volumeID, err)
+ }
- return uint32(diskNumber.(int32)), nil
+ return nil
+ })
+ return diskNumber, err
}
// GetVolumeIDFromTargetPath - gets the volume ID given a mount point, the function is recursive until it find a volume or errors out
@@ -321,7 +311,7 @@ func (VolumeAPI) GetVolumeIDFromTargetPath(mount string) (string, error) {
}
func getTarget(mount string) (string, error) {
- mountedFolder, err := isMountedFolder(mount)
+ mountedFolder, err := utils.IsMountedFolder(mount)
if err != nil {
return "", err
}
@@ -361,7 +351,7 @@ func (VolumeAPI) GetClosestVolumeIDFromTargetPath(targetPath string) (string, er
}
// findClosestVolume finds the closest volume id for a given target path
-// by following symlinks and moving up in the filesystem, if after moving up in the filesystem
+// by following symlinks and moving up in the filesystem. if after moving up in the filesystem
// we get to a DriveLetter then the volume corresponding to this drive letter is returned instead.
func findClosestVolume(path string) (string, error) {
candidatePath := path
@@ -370,22 +360,20 @@ func findClosestVolume(path string) (string, error) {
// while trying to follow symlinks
//
// The maximum path length in Windows is 260, it could be possible to end
- // up in a sceneario where we do more than 256 iterations (e.g. by following symlinks from
+ // up in a scenario where we do more than 256 iterations (e.g. by following symlinks from
// a place high in the hierarchy to a nested sibling location many times)
// https://docs.microsoft.com/en-us/windows/win32/fileio/naming-a-file#:~:text=In%20editions%20of%20Windows%20before,required%20to%20remove%20the%20limit.
//
// The number of iterations is 256, which is similar to the number of iterations in filepath-securejoin
// https://github.com/cyphar/filepath-securejoin/blob/64536a8a66ae59588c981e2199f1dcf410508e07/join.go#L51
for i := 0; i < 256; i += 1 {
- fi, err := os.Lstat(candidatePath)
+ isSymlink, err := utils.IsPathSymlink(candidatePath)
if err != nil {
return "", err
}
- // for windows NTFS, check if the path is symlink instead of directory.
- isSymlink := fi.Mode()&os.ModeSymlink != 0 || fi.Mode()&os.ModeIrregular != 0
// mounted folder created by SetVolumeMountPoint may still report ModeSymlink == 0
- mountedFolder, err := isMountedFolder(candidatePath)
+ mountedFolder, err := utils.IsMountedFolder(candidatePath)
if err != nil {
return "", err
}
@@ -422,67 +410,40 @@ func findClosestVolume(path string) (string, error) {
return "", fmt.Errorf("failed to find the closest volume for path=%s", path)
}
-// isMountedFolder checks whether the `path` is a mounted folder.
-func isMountedFolder(path string) (bool, error) {
- // https://learn.microsoft.com/en-us/windows/win32/fileio/determining-whether-a-directory-is-a-volume-mount-point
- utf16Path, _ := windows.UTF16PtrFromString(path)
- attrs, err := windows.GetFileAttributes(utf16Path)
- if err != nil {
- return false, err
- }
-
- if (attrs & windows.FILE_ATTRIBUTE_REPARSE_POINT) == 0 {
- return false, nil
- }
-
- var findData windows.Win32finddata
- findHandle, err := windows.FindFirstFile(utf16Path, &findData)
- if err != nil && !errors.Is(err, windows.ERROR_NO_MORE_FILES) {
- return false, err
- }
-
- for err == nil {
- if findData.Reserved0&windows.IO_REPARSE_TAG_MOUNT_POINT != 0 {
- return true, nil
- }
-
- err = windows.FindNextFile(findHandle, &findData)
- if err != nil && !errors.Is(err, windows.ERROR_NO_MORE_FILES) {
- return false, err
- }
- }
-
- return false, nil
-}
-
// getVolumeForDriveLetter gets a volume from a drive letter (e.g. C:/).
func getVolumeForDriveLetter(path string) (string, error) {
if len(path) != 1 {
return "", fmt.Errorf("the path %s is not a valid drive letter", path)
}
- volume, err := cim.GetVolumeByDriveLetter(path, []string{"UniqueId"})
- if err != nil {
- return "", nil
- }
+ var uniqueID string
+ err := cim.WithCOMThread(func() error {
+ volume, err := cim.GetVolumeByDriveLetter(path, cim.VolumeSelectorListUniqueID)
+ if err != nil {
+ return err
+ }
- uniqueID, err := volume.GetPropertyUniqueId()
- if err != nil {
- return "", fmt.Errorf("error query unique ID of volume (%v). error: %v", volume, err)
- }
+ uniqueID, err = cim.GetVolumeUniqueID(volume)
+ if err != nil {
+ return fmt.Errorf("error query unique ID of volume (%v). error: %v", volume, err)
+ }
- return uniqueID, nil
+ return nil
+ })
+ return uniqueID, err
}
func writeCache(volumeID string) error {
- volume, err := cim.QueryVolumeByUniqueID(volumeID, []string{})
- if err != nil && !wmierrors.IsNotFound(err) {
- return fmt.Errorf("error writing volume (%s) cache. error: %v", volumeID, err)
- }
+ return cim.WithCOMThread(func() error {
+ volume, err := cim.QueryVolumeByUniqueID(volumeID, nil)
+ if err != nil {
+ return fmt.Errorf("error writing volume (%s) cache. error: %v", volumeID, err)
+ }
- result, err := volume.Flush()
- if result != 0 || err != nil {
- return fmt.Errorf("error writing volume (%s) cache. result: %d, error: %v", volumeID, result, err)
- }
- return nil
+ result, err := cim.FlushVolume(volume)
+ if result != 0 || err != nil {
+ return fmt.Errorf("error writing volume (%s) cache. result: %d, error: %v", volumeID, result, err)
+ }
+ return nil
+ })
}
diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go
index 102675ac..f7a69492 100644
--- a/pkg/utils/utils.go
+++ b/pkg/utils/utils.go
@@ -1,10 +1,12 @@
package utils
import (
+ "errors"
+ "fmt"
"os"
- "os/exec"
"strings"
+ "golang.org/x/sys/windows"
"k8s.io/klog/v2"
)
@@ -22,10 +24,88 @@ func EnsureLongPath(path string) string {
return path
}
-func RunPowershellCmd(command string, envs ...string) ([]byte, error) {
- cmd := exec.Command("powershell", "-Mta", "-NoProfile", "-Command", command)
- cmd.Env = append(os.Environ(), envs...)
- klog.V(8).Infof("Executing command: %q", cmd.String())
- out, err := cmd.CombinedOutput()
- return out, err
+func IsPathValid(path string) (bool, error) {
+ pathString, err := windows.UTF16PtrFromString(path)
+ if err != nil {
+ return false, fmt.Errorf("invalid path: %w", err)
+ }
+
+ attrs, err := windows.GetFileAttributes(pathString)
+ if err != nil {
+ if errors.Is(err, windows.ERROR_PATH_NOT_FOUND) || errors.Is(err, windows.ERROR_FILE_NOT_FOUND) || errors.Is(err, windows.ERROR_INVALID_NAME) {
+ return false, nil
+ }
+
+ // GetFileAttribute returns user or password incorrect for a disconnected SMB connection after the password is changed
+ return false, fmt.Errorf("failed to get path %s attribute: %w", path, err)
+ }
+
+ klog.V(6).Infof("Path %s attribute: %d", path, attrs)
+ return attrs != windows.INVALID_FILE_ATTRIBUTES, nil
+}
+
+// IsMountedFolder checks whether the `path` is a mounted folder.
+func IsMountedFolder(path string) (bool, error) {
+ // https://learn.microsoft.com/en-us/windows/win32/fileio/determining-whether-a-directory-is-a-volume-mount-point
+ utf16Path, _ := windows.UTF16PtrFromString(path)
+ attrs, err := windows.GetFileAttributes(utf16Path)
+ if err != nil {
+ return false, err
+ }
+
+ if (attrs & windows.FILE_ATTRIBUTE_REPARSE_POINT) == 0 {
+ return false, nil
+ }
+
+ var findData windows.Win32finddata
+ findHandle, err := windows.FindFirstFile(utf16Path, &findData)
+ if err != nil && !errors.Is(err, windows.ERROR_NO_MORE_FILES) {
+ return false, err
+ }
+
+ for err == nil {
+ if findData.Reserved0&windows.IO_REPARSE_TAG_MOUNT_POINT != 0 {
+ return true, nil
+ }
+
+ err = windows.FindNextFile(findHandle, &findData)
+ if err != nil && !errors.Is(err, windows.ERROR_NO_MORE_FILES) {
+ return false, err
+ }
+ }
+
+ return false, nil
+}
+
+func IsPathSymlink(path string) (bool, error) {
+ fi, err := os.Lstat(path)
+ if err != nil {
+ return false, err
+ }
+ // for windows NTFS, check if the path is symlink instead of directory.
+ isSymlink := fi.Mode()&os.ModeSymlink != 0 || fi.Mode()&os.ModeIrregular != 0
+ return isSymlink, nil
+}
+
+func CreateSymlink(link, target string, isDir bool) error {
+ linkPtr, err := windows.UTF16PtrFromString(link)
+ if err != nil {
+ return err
+ }
+ targetPtr, err := windows.UTF16PtrFromString(target)
+ if err != nil {
+ return err
+ }
+
+ var flags uint32
+ if isDir {
+ flags = windows.SYMBOLIC_LINK_FLAG_DIRECTORY
+ }
+
+ err = windows.CreateSymbolicLink(
+ linkPtr,
+ targetPtr,
+ flags,
+ )
+ return err
}