diff --git a/docs/IMPLEMENTATION.md b/docs/IMPLEMENTATION.md index 3326f2f2..3468e3c4 100644 --- a/docs/IMPLEMENTATION.md +++ b/docs/IMPLEMENTATION.md @@ -121,6 +121,20 @@ func CallMethod(disp *ole.IDispatch, name string, params ...interface{}) (result } ``` +### Association + +Association can be used to retrieve all instances that are associated with +a particular source instance. + +There are a few Association classes in WMI. + +For example, association class [MSFT_PartitionToVolume](https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partitiontovolume) +can be used to retrieve a volume (`MSFT_Volume`) from a partition (`MSFT_Partition`), and vice versa. + +```go +collection, err := part.GetAssociated("MSFT_PartitionToVolume", "MSFT_Volume", "Volume", "Partition") +``` + ## Debug with PowerShell @@ -181,6 +195,13 @@ PS C:\Users\Administrator> $vol.FileSystem NTFS ``` +### Association + +```powershell +PS C:\Users\Administrator> $partition = (Get-CimInstance -Namespace root\Microsoft\Windows\Storage -ClassName MSFT_Partition -Filter "DiskNumber = 0")[0] +PS C:\Users\Administrator> Get-CimAssociatedInstance -InputObject $partition -Association MSFT_PartitionToVolume +``` + ### Call Class Method You may get Class Methods for a single CIM class using `$class.CimClassMethods`. diff --git a/integrationtests/iscsi_ps_scripts.go b/integrationtests/iscsi_ps_scripts.go index 89bec253..202e390b 100644 --- a/integrationtests/iscsi_ps_scripts.go +++ b/integrationtests/iscsi_ps_scripts.go @@ -42,14 +42,14 @@ $ProgressPreference = "SilentlyContinue" $targetName = "%s" # Get local IPv4 (e.g. 10.30.1.15, not 127.0.0.1) -$address = $(Get-NetIPAddress | Where-Object { $_.InterfaceAlias -eq "Ethernet" -and $_.AddressFamily -eq "IPv4" }).IPAddress +$address = $(Get-NetIPAddress | Where-Object { $_.InterfaceAlias -eq "%s" -and $_.AddressFamily -eq "IPv4" }).IPAddress # Create virtual disk in RAM -New-IscsiVirtualDisk -Path "ramdisk:scratch-${targetName}.vhdx" -Size 100MB | Out-Null +New-IscsiVirtualDisk -Path "ramdisk:scratch-${targetName}.vhdx" -Size 100MB -ComputerName $env:computername | Out-Null # Create a target that allows all initiator IQNs and map a disk to the new target -$target = New-IscsiServerTarget -TargetName $targetName -InitiatorIds @("Iqn:*") -Add-IscsiVirtualDiskTargetMapping -TargetName $targetName -DevicePath "ramdisk:scratch-${targetName}.vhdx" | Out-Null +$target = New-IscsiServerTarget -TargetName $targetName -InitiatorIds @("Iqn:*") -ComputerName $env:computername +Add-IscsiVirtualDiskTargetMapping -TargetName $targetName -DevicePath "ramdisk:scratch-${targetName}.vhdx" -ComputerName $env:computername | Out-Null $output = @{ "iqn" = "$($target.TargetIqn)" @@ -68,7 +68,7 @@ $username = "%s" $password = "%s" $securestring = ConvertTo-SecureString -String $password -AsPlainText -Force $chap = New-Object -TypeName System.Management.Automation.PSCredential -ArgumentList ($username, $securestring) -Set-IscsiServerTarget -TargetName $targetName -EnableChap $true -Chap $chap +Set-IscsiServerTarget -TargetName $targetName -EnableChap $true -Chap $chap -ComputerName $env:computername ` func setChap(targetName string, username string, password string) error { @@ -92,7 +92,7 @@ $securestring = ConvertTo-SecureString -String $password -AsPlainText -Force # Windows initiator does not uses the username for mutual authentication $chap = New-Object -TypeName System.Management.Automation.PSCredential -ArgumentList ($username, $securestring) -Set-IscsiServerTarget -TargetName $targetName -EnableReverseChap $true -ReverseChap $chap +Set-IscsiServerTarget -TargetName $targetName -EnableReverseChap $true -ReverseChap $chap -ComputerName $env:computername ` func setReverseChap(targetName string, password string) error { @@ -131,8 +131,8 @@ Get-IscsiTarget | Disconnect-IscsiTarget -Confirm:$false Get-IscsiTargetPortal | Remove-IscsiTargetPortal -confirm:$false # Clean target -Get-IscsiServerTarget | Remove-IscsiServerTarget -Get-IscsiVirtualDisk | Remove-IscsiVirtualDisk +Get-IscsiServerTarget -ComputerName $env:computername | Remove-IscsiServerTarget +Get-IscsiVirtualDisk -ComputerName $env:computername | Remove-IscsiVirtualDisk # Stop iSCSI initiator Get-Service "MsiSCSI" | Stop-Service @@ -173,7 +173,12 @@ func runPowershellScript(script string) (string, error) { } func setupEnv(targetName string) (*IscsiSetupConfig, error) { - script := fmt.Sprintf(IscsiEnvironmentSetupScript, targetName) + ethernetName := "Ethernet" + if val, ok := os.LookupEnv("ETHERNET_NAME"); ok { + ethernetName = val + } + + script := fmt.Sprintf(IscsiEnvironmentSetupScript, targetName, ethernetName) out, err := runPowershellScript(script) if err != nil { return nil, fmt.Errorf("failed setting up environment. err=%v", err) diff --git a/pkg/cim/disk.go b/pkg/cim/disk.go index 0b03c2ac..58c8f376 100644 --- a/pkg/cim/disk.go +++ b/pkg/cim/disk.go @@ -23,6 +23,17 @@ const ( // GPTPartitionTypeMicrosoftReserved is the GUID for Microsoft Reserved Partition (MSR) // Reserved by Windows for system use GPTPartitionTypeMicrosoftReserved = "{e3c9e316-0b5c-4db8-817d-f92df00215ae}" + + // ErrorCodeCreatePartitionAccessPathAlreadyInUse is the error code (42002) returned when the driver letter failed to assign after partition created + ErrorCodeCreatePartitionAccessPathAlreadyInUse = 42002 +) + +var ( + DiskSelectorListForDiskNumberAndLocation = []string{"Number", "Location"} + DiskSelectorListForPartitionStyle = []string{"PartitionStyle"} + DiskSelectorListForPathAndSerialNumber = []string{"Path", "SerialNumber"} + DiskSelectorListForIsOffline = []string{"IsOffline"} + DiskSelectorListForSize = []string{"Size"} ) // QueryDiskByNumber retrieves disk information for a specific disk identified by its number. @@ -76,3 +87,104 @@ func ListDisks(selectorList []string) ([]*storage.MSFT_Disk, error) { return disks, nil } + +// InitializeDisk initializes a RAW disk with a particular partition style. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/initialize-msft-disk +// for the WMI method definition. +func InitializeDisk(disk *storage.MSFT_Disk, partitionStyle int) (int, error) { + result, err := disk.InvokeMethodWithReturn("Initialize", int32(partitionStyle)) + return int(result), err +} + +// RefreshDisk Refreshes the cached disk layout information. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-disk-refresh +// for the WMI method definition. +func RefreshDisk(disk *storage.MSFT_Disk) (int, string, error) { + var status string + result, err := disk.InvokeMethodWithReturn("Refresh", &status) + return int(result), status, err +} + +// CreatePartition creates a partition on a disk. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/createpartition-msft-disk +// for the WMI method definition. +func CreatePartition(disk *storage.MSFT_Disk, params ...interface{}) (int, error) { + result, err := disk.InvokeMethodWithReturn("CreatePartition", params...) + return int(result), err +} + +// SetDiskState takes a disk online or offline. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-disk-online and +// https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-disk-offline +// for the WMI method definition. +func SetDiskState(disk *storage.MSFT_Disk, online bool) (int, string, error) { + method := "Offline" + if online { + method = "Online" + } + + var status string + result, err := disk.InvokeMethodWithReturn(method, &status) + return int(result), status, err +} + +// RescanDisks rescans all changes by updating the internal cache of software objects (that is, Disks, Partitions, Volumes) +// for the storage setting. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-storagesetting-updatehoststoragecache +// for the WMI method definition. +func RescanDisks() (int, error) { + result, _, err := InvokeCimMethod(WMINamespaceStorage, "MSFT_StorageSetting", "UpdateHostStorageCache", nil) + return result, err +} + +// GetDiskNumber returns the number of a disk. +func GetDiskNumber(disk *storage.MSFT_Disk) (uint32, error) { + number, err := disk.GetProperty("Number") + if err != nil { + return 0, err + } + return uint32(number.(int32)), err +} + +// GetDiskLocation returns the location of a disk. +func GetDiskLocation(disk *storage.MSFT_Disk) (string, error) { + return disk.GetPropertyLocation() +} + +// GetDiskPartitionStyle returns the partition style of a disk. +func GetDiskPartitionStyle(disk *storage.MSFT_Disk) (int32, error) { + retValue, err := disk.GetProperty("PartitionStyle") + if err != nil { + return 0, err + } + return retValue.(int32), err +} + +// IsDiskOffline returns whether a disk is offline. +func IsDiskOffline(disk *storage.MSFT_Disk) (bool, error) { + return disk.GetPropertyIsOffline() +} + +// GetDiskSize returns the size of a disk. +func GetDiskSize(disk *storage.MSFT_Disk) (int64, error) { + sz, err := disk.GetProperty("Size") + if err != nil { + return -1, err + } + return strconv.ParseInt(sz.(string), 10, 64) +} + +// GetDiskPath returns the path of a disk. +func GetDiskPath(disk *storage.MSFT_Disk) (string, error) { + return disk.GetPropertyPath() +} + +// GetDiskSerialNumber returns the serial number of a disk. +func GetDiskSerialNumber(disk *storage.MSFT_Disk) (string, error) { + return disk.GetPropertySerialNumber() +} diff --git a/pkg/cim/iscsi.go b/pkg/cim/iscsi.go new file mode 100644 index 00000000..3a4e2541 --- /dev/null +++ b/pkg/cim/iscsi.go @@ -0,0 +1,362 @@ +//go:build windows +// +build windows + +package cim + +import ( + "fmt" + "strconv" + + "github.com/microsoft/wmi/pkg/base/query" + "github.com/microsoft/wmi/server2019/root/microsoft/windows/storage" +) + +var ( + ISCSITargetPortalDefaultSelectorList = []string{"TargetPortalAddress", "TargetPortalPortNumber"} +) + +// ListISCSITargetPortals retrieves a list of iSCSI target portals. +// +// The equivalent WMI query is: +// +// SELECT [selectors] FROM MSFT_IscsiTargetPortal +// +// Refer to https://learn.microsoft.com/en-us/previous-versions/windows/desktop/iscsidisc/msft-iscsitargetportal +// for the WMI class definition. +func ListISCSITargetPortals(selectorList []string) ([]*storage.MSFT_iSCSITargetPortal, error) { + q := query.NewWmiQueryWithSelectList("MSFT_IscsiTargetPortal", selectorList) + instances, err := QueryInstances(WMINamespaceStorage, q) + if IgnoreNotFound(err) != nil { + return nil, err + } + + var targetPortals []*storage.MSFT_iSCSITargetPortal + for _, instance := range instances { + portal, err := storage.NewMSFT_iSCSITargetPortalEx1(instance) + if err != nil { + return nil, fmt.Errorf("failed to query iSCSI target portal %v. error: %v", instance, err) + } + + targetPortals = append(targetPortals, portal) + } + + return targetPortals, nil +} + +// QueryISCSITargetPortal retrieves information about a specific iSCSI target portal +// identified by its network address and port number. +// +// The equivalent WMI query is: +// +// SELECT [selectors] FROM MSFT_IscsiTargetPortal +// WHERE TargetPortalAddress = '
' +// AND TargetPortalPortNumber = '' +// +// Refer to https://learn.microsoft.com/en-us/previous-versions/windows/desktop/iscsidisc/msft-iscsitargetportal +// for the WMI class definition. +func QueryISCSITargetPortal(address string, port uint32, selectorList []string) (*storage.MSFT_iSCSITargetPortal, error) { + portalQuery := query.NewWmiQueryWithSelectList( + "MSFT_iSCSITargetPortal", selectorList, + "TargetPortalAddress", address, + "TargetPortalPortNumber", strconv.Itoa(int(port))) + instances, err := QueryInstances(WMINamespaceStorage, portalQuery) + if err != nil { + return nil, err + } + + targetPortal, err := storage.NewMSFT_iSCSITargetPortalEx1(instances[0]) + if err != nil { + return nil, fmt.Errorf("failed to query iSCSI target portal at (%s:%d). error: %v", address, port, err) + } + + return targetPortal, nil +} + +// ListISCSITargetsByTargetPortalAddressAndPort retrieves ISCSI targets by address and port of an iSCSI target portal. +func ListISCSITargetsByTargetPortalAddressAndPort(address string, port uint32, selectorList []string) ([]*storage.MSFT_iSCSITarget, error) { + instance, err := QueryISCSITargetPortal(address, port, selectorList) + if err != nil { + return nil, err + } + + targets, err := ListISCSITargetsByTargetPortal([]*storage.MSFT_iSCSITargetPortal{instance}) + return targets, err +} + +// NewISCSITargetPortal creates a new iSCSI target portal. +// +// Refer to https://learn.microsoft.com/en-us/previous-versions/windows/desktop/iscsidisc/msft-iscsitargetportal-new +// for the WMI method definition. +func NewISCSITargetPortal(targetPortalAddress string, + targetPortalPortNumber uint32, + initiatorInstanceName *string, + initiatorPortalAddress *string, + isHeaderDigest *bool, + isDataDigest *bool) (*storage.MSFT_iSCSITargetPortal, error) { + params := map[string]interface{}{ + "TargetPortalAddress": targetPortalAddress, + "TargetPortalPortNumber": targetPortalPortNumber, + } + if initiatorInstanceName != nil { + params["InitiatorInstanceName"] = *initiatorInstanceName + } + if initiatorPortalAddress != nil { + params["InitiatorPortalAddress"] = *initiatorPortalAddress + } + if isHeaderDigest != nil { + params["IsHeaderDigest"] = *isHeaderDigest + } + if isDataDigest != nil { + params["IsDataDigest"] = *isDataDigest + } + result, _, err := InvokeCimMethod(WMINamespaceStorage, "MSFT_iSCSITargetPortal", "New", params) + if err != nil { + return nil, fmt.Errorf("failed to create iSCSI target portal with %v. result: %d, error: %v", params, result, err) + } + + return QueryISCSITargetPortal(targetPortalAddress, targetPortalPortNumber, nil) +} + +// ParseISCSITargetPortal retrieves the portal address and port number of an iSCSI target portal. +func ParseISCSITargetPortal(instance *storage.MSFT_iSCSITargetPortal) (string, uint32, error) { + portalAddress, err := instance.GetPropertyTargetPortalAddress() + if err != nil { + return "", 0, fmt.Errorf("failed parsing target portal address %v. err: %w", instance, err) + } + + portalPort, err := instance.GetProperty("TargetPortalPortNumber") + if err != nil { + return "", 0, fmt.Errorf("failed parsing target portal port number %v. err: %w", instance, err) + } + + return portalAddress, uint32(portalPort.(int32)), nil +} + +// RemoveISCSITargetPortal removes an iSCSI target portal. +// +// Refer to https://learn.microsoft.com/en-us/previous-versions/windows/desktop/iscsidisc/msft-iscsitargetportal-remove +// for the WMI method definition. +func RemoveISCSITargetPortal(instance *storage.MSFT_iSCSITargetPortal) (int, error) { + address, port, err := ParseISCSITargetPortal(instance) + if err != nil { + return 0, fmt.Errorf("failed to parse target portal %v. error: %v", instance, err) + } + + result, err := instance.InvokeMethodWithReturn("Remove", + nil, + nil, + int(port), + address, + ) + return int(result), err +} + +// ListISCSITargetsByTargetPortal retrieves all iSCSI targets from the specified iSCSI target portal +// using MSFT_iSCSITargetToiSCSITargetPortal association. +// +// WMI association MSFT_iSCSITargetToiSCSITargetPortal: +// +// iSCSITarget | iSCSITargetPortal +// ----------- | ----------------- +// MSFT_iSCSITarget (NodeAddress = "iqn.1991-05.com.microsoft:win-8e2evaq9q...) | MSFT_iSCSITargetPortal (TargetPortalAdd... +// +// Refer to https://learn.microsoft.com/en-us/previous-versions/windows/desktop/iscsidisc/msft-iscsitarget +// for the WMI class definition. +func ListISCSITargetsByTargetPortal(portals []*storage.MSFT_iSCSITargetPortal) ([]*storage.MSFT_iSCSITarget, error) { + var targets []*storage.MSFT_iSCSITarget + for _, portal := range portals { + collection, err := portal.GetAssociated("MSFT_iSCSITargetToiSCSITargetPortal", "MSFT_iSCSITarget", "iSCSITarget", "iSCSITargetPortal") + if err != nil { + return nil, fmt.Errorf("failed to query associated iSCSITarget for %v. error: %v", portal, err) + } + + for _, instance := range collection { + target, err := storage.NewMSFT_iSCSITargetEx1(instance) + if err != nil { + return nil, fmt.Errorf("failed to query iSCSI target %v. error: %v", instance, err) + } + + targets = append(targets, target) + } + } + + return targets, nil +} + +// QueryISCSITarget retrieves the iSCSI target from the specified portal address, portal and node address. +func QueryISCSITarget(address string, port uint32, nodeAddress string) (*storage.MSFT_iSCSITarget, error) { + portal, err := QueryISCSITargetPortal(address, port, nil) + if err != nil { + return nil, err + } + + targets, err := ListISCSITargetsByTargetPortal([]*storage.MSFT_iSCSITargetPortal{portal}) + if err != nil { + return nil, err + } + + for _, target := range targets { + targetNodeAddress, err := GetISCSITargetNodeAddress(target) + if err != nil { + return nil, fmt.Errorf("failed to query iSCSI target %v. error: %v", target, err) + } + + if targetNodeAddress == nodeAddress { + return target, nil + } + } + + return nil, nil +} + +// GetISCSITargetNodeAddress returns the node address of an iSCSI target. +func GetISCSITargetNodeAddress(target *storage.MSFT_iSCSITarget) (string, error) { + nodeAddress, err := target.GetProperty("NodeAddress") + if err != nil { + return "", err + } + + return nodeAddress.(string), err +} + +// IsISCSITargetConnected returns whether the iSCSI target is connected. +func IsISCSITargetConnected(target *storage.MSFT_iSCSITarget) (bool, error) { + return target.GetPropertyIsConnected() +} + +// QueryISCSISessionByTarget retrieves the iSCSI session from the specified iSCSI target +// using MSFT_iSCSITargetToiSCSISession association. +// +// WMI association MSFT_iSCSITargetToiSCSISession: +// +// iSCSISession | iSCSITarget +// ------------ | ----------- +// MSFT_iSCSISession (SessionIdentifier = "ffffac0cacbff010-4000013700000016") | MSFT_iSCSITarget (NodeAddress = "iqn.199... +// +// Refer to https://learn.microsoft.com/en-us/previous-versions/windows/desktop/iscsidisc/msft-iscsisession +// for the WMI class definition. +func QueryISCSISessionByTarget(target *storage.MSFT_iSCSITarget) (*storage.MSFT_iSCSISession, error) { + collection, err := target.GetAssociated("MSFT_iSCSITargetToiSCSISession", "MSFT_iSCSISession", "iSCSISession", "iSCSITarget") + if err != nil { + return nil, fmt.Errorf("failed to query associated iSCSISession for %v. error: %v", target, err) + } + + if len(collection) == 0 { + return nil, nil + } + + session, err := storage.NewMSFT_iSCSISessionEx1(collection[0]) + return session, err +} + +// UnregisterISCSISession unregisters the iSCSI session so that it is no longer persistent. +// +// Refer https://learn.microsoft.com/en-us/previous-versions/windows/desktop/iscsidisc/msft-iscsisession-unregister +// for the WMI method definition. +func UnregisterISCSISession(session *storage.MSFT_iSCSISession) (int, error) { + result, err := session.InvokeMethodWithReturn("Unregister") + return int(result), err +} + +// SetISCSISessionChapSecret sets a CHAP secret key for use with iSCSI initiator connections. +// +// Refer https://learn.microsoft.com/en-us/previous-versions/windows/desktop/iscsidisc/msft-iscsitarget-disconnect +// for the WMI method definition. +func SetISCSISessionChapSecret(mutualChapSecret string) (int, error) { + result, _, err := InvokeCimMethod(WMINamespaceStorage, "MSFT_iSCSISession", "SetCHAPSecret", map[string]interface{}{"ChapSecret": mutualChapSecret}) + return result, err +} + +// GetISCSISessionIdentifier returns the identifier of an iSCSI session. +func GetISCSISessionIdentifier(session *storage.MSFT_iSCSISession) (string, error) { + return session.GetPropertySessionIdentifier() +} + +// IsISCSISessionPersistent returns whether an iSCSI session is persistent. +func IsISCSISessionPersistent(session *storage.MSFT_iSCSISession) (bool, error) { + return session.GetPropertyIsPersistent() +} + +// ListDisksByTarget find all disks associated with an iSCSITarget. +// It finds out the iSCSIConnections from MSFT_iSCSITargetToiSCSIConnection association, +// then locate MSFT_Disk objects from MSFT_iSCSIConnectionToDisk association. +// +// WMI association MSFT_iSCSITargetToiSCSIConnection: +// +// iSCSIConnection | iSCSITarget +// --------------- | ----------- +// MSFT_iSCSIConnection (ConnectionIdentifier = "ffffac0cacbff010-15") | MSFT_iSCSITarget (NodeAddress = "iqn.1991-05.com... +// +// WMI association MSFT_iSCSIConnectionToDisk: +// +// Disk | iSCSIConnection +// ---- | --------------- +// MSFT_Disk (ObjectId = "{1}\\WIN-8E2EVAQ9QSB\root/Microsoft/Win...) | MSFT_iSCSIConnection (ConnectionIdentifier = "fff... +// +// Refer to https://learn.microsoft.com/en-us/previous-versions/windows/desktop/iscsidisc/msft-iscsiconnection +// for the WMI class definition. +func ListDisksByTarget(target *storage.MSFT_iSCSITarget) ([]*storage.MSFT_Disk, error) { + // list connections to the given iSCSI target + collection, err := target.GetAssociated("MSFT_iSCSITargetToiSCSIConnection", "MSFT_iSCSIConnection", "iSCSIConnection", "iSCSITarget") + if err != nil { + return nil, fmt.Errorf("failed to query associated iSCSISession for %v. error: %v", target, err) + } + + if len(collection) == 0 { + return nil, nil + } + + var result []*storage.MSFT_Disk + for _, conn := range collection { + instances, err := conn.GetAssociated("MSFT_iSCSIConnectionToDisk", "MSFT_Disk", "Disk", "iSCSIConnection") + if err != nil { + return nil, fmt.Errorf("failed to query associated disk for %v. error: %v", target, err) + } + + for _, instance := range instances { + disk, err := storage.NewMSFT_DiskEx1(instance) + if err != nil { + return nil, fmt.Errorf("failed to query associated disk %v. error: %v", instance, err) + } + + result = append(result, disk) + } + } + + return result, err +} + +// ConnectISCSITarget establishes a connection to an iSCSI target with optional CHAP authentication credential. +// +// Refer https://learn.microsoft.com/en-us/previous-versions/windows/desktop/iscsidisc/msft-iscsitarget-connect +// for the WMI method definition. +func ConnectISCSITarget(portalAddress string, portalPortNumber uint32, nodeAddress string, authType string, chapUsername *string, chapSecret *string) (int, error) { + inParams := map[string]interface{}{ + "NodeAddress": nodeAddress, + "TargetPortalAddress": portalAddress, + "TargetPortalPortNumber": int(portalPortNumber), + "AuthenticationType": authType, + } + // InitiatorPortalAddress + // IsDataDigest + // IsHeaderDigest + // ReportToPnP + if chapUsername != nil { + inParams["ChapUsername"] = *chapUsername + } + if chapSecret != nil { + inParams["ChapSecret"] = *chapSecret + } + + result, _, err := InvokeCimMethod(WMINamespaceStorage, "MSFT_iSCSITarget", "Connect", inParams) + return result, err +} + +// DisconnectISCSITarget disconnects the specified session between an iSCSI initiator and an iSCSI target. +// +// Refer https://learn.microsoft.com/en-us/previous-versions/windows/desktop/iscsidisc/msft-iscsitarget-disconnect +// for the WMI method definition. +func DisconnectISCSITarget(target *storage.MSFT_iSCSITarget, sessionIdentifier string) (int, error) { + result, err := target.InvokeMethodWithReturn("Disconnect", sessionIdentifier) + return int(result), err +} diff --git a/pkg/cim/smb.go b/pkg/cim/smb.go index 5868d456..2850ab78 100644 --- a/pkg/cim/smb.go +++ b/pkg/cim/smb.go @@ -4,6 +4,8 @@ package cim import ( + "strings" + "github.com/microsoft/wmi/pkg/base/query" cim "github.com/microsoft/wmi/pkg/wmiinstance" ) @@ -17,8 +19,24 @@ const ( SmbMappingStatusConnecting SmbMappingStatusReconnecting SmbMappingStatusUnavailable + + credentialDelimiter = ":" ) +// escapeQueryParameter escapes a parameter for WMI Queries +func escapeQueryParameter(s string) string { + s = strings.ReplaceAll(s, "'", "''") + s = strings.ReplaceAll(s, "\\", "\\\\") + return s +} + +func escapeUserName(userName string) string { + // refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L169-L170 + userName = strings.ReplaceAll(userName, "\\", "\\\\") + userName = strings.ReplaceAll(userName, credentialDelimiter, "\\"+credentialDelimiter) + return userName +} + // QuerySmbGlobalMappingByRemotePath retrieves the SMB global mapping from its remote path. // // The equivalent WMI query is: @@ -28,7 +46,7 @@ const ( // Refer to https://pkg.go.dev/github.com/microsoft/wmi/server2019/root/microsoft/windows/smb#MSFT_SmbGlobalMapping // for the WMI class definition. func QuerySmbGlobalMappingByRemotePath(remotePath string) (*cim.WmiInstance, error) { - smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", remotePath) + smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", escapeQueryParameter(remotePath)) instances, err := QueryInstances(WMINamespaceSmb, smbQuery) if err != nil { return nil, err @@ -37,12 +55,22 @@ func QuerySmbGlobalMappingByRemotePath(remotePath string) (*cim.WmiInstance, err return instances[0], err } -// RemoveSmbGlobalMappingByRemotePath removes a SMB global mapping matching to the remote path. +// GetSmbGlobalMappingStatus returns the status of an SMB global mapping. +func GetSmbGlobalMappingStatus(inst *cim.WmiInstance) (int32, error) { + statusProp, err := inst.GetProperty("Status") + if err != nil { + return SmbMappingStatusUnavailable, err + } + + return statusProp.(int32), nil +} + +// RemoveSmbGlobalMappingByRemotePath removes an SMB global mapping matching to the remote path. // // Refer to https://pkg.go.dev/github.com/microsoft/wmi/server2019/root/microsoft/windows/smb#MSFT_SmbGlobalMapping // for the WMI class definition. func RemoveSmbGlobalMappingByRemotePath(remotePath string) error { - smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", remotePath) + smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", escapeQueryParameter(remotePath)) instances, err := QueryInstances(WMINamespaceSmb, smbQuery) if err != nil { return err @@ -51,3 +79,22 @@ func RemoveSmbGlobalMappingByRemotePath(remotePath string) error { _, err = instances[0].InvokeMethod("Remove", true) return err } + +// NewSmbGlobalMapping creates a new SMB global mapping to the remote path. +// +// Refer to https://pkg.go.dev/github.com/microsoft/wmi/server2019/root/microsoft/windows/smb#MSFT_SmbGlobalMapping +// for the WMI class definition. +func NewSmbGlobalMapping(remotePath, username, password string, requirePrivacy bool) (int, error) { + params := map[string]interface{}{ + "RemotePath": remotePath, + "RequirePrivacy": requirePrivacy, + } + if username != "" { + // refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L166-L178 + // on how SMB credential is handled in PowerShell + params["Credential"] = escapeUserName(username) + credentialDelimiter + password + } + + result, _, err := InvokeCimMethod(WMINamespaceSmb, "MSFT_SmbGlobalMapping", "Create", params) + return result, err +} diff --git a/pkg/cim/system.go b/pkg/cim/system.go index 3ab32af6..80181794 100644 --- a/pkg/cim/system.go +++ b/pkg/cim/system.go @@ -10,6 +10,22 @@ import ( "github.com/microsoft/wmi/server2019/root/cimv2" ) +var ( + BIOSSelectorList = []string{"SerialNumber"} + ServiceSelectorList = []string{"DisplayName", "State", "StartMode"} +) + +type ServiceInterface interface { + GetPropertyName() (string, error) + GetPropertyDisplayName() (string, error) + GetPropertyState() (string, error) + GetPropertyStartMode() (string, error) + GetDependents() ([]ServiceInterface, error) + StartService() (result uint32, err error) + StopService() (result uint32, err error) + Refresh() error +} + // QueryBIOSElement retrieves the BIOS element. // // The equivalent WMI query is: @@ -33,6 +49,11 @@ func QueryBIOSElement(selectorList []string) (*cimv2.CIM_BIOSElement, error) { return bios, err } +// GetBIOSSerialNumber returns the BIOS serial number. +func GetBIOSSerialNumber(bios *cimv2.CIM_BIOSElement) (string, error) { + return bios.GetPropertySerialNumber() +} + // QueryServiceByName retrieves a specific service by its name. // // The equivalent WMI query is: @@ -55,3 +76,60 @@ func QueryServiceByName(name string, selectorList []string) (*cimv2.Win32_Servic return service, err } + +// GetServiceName returns the name of a service. +func GetServiceName(service ServiceInterface) (string, error) { + return service.GetPropertyName() +} + +// GetServiceDisplayName returns the display name of a service. +func GetServiceDisplayName(service ServiceInterface) (string, error) { + return service.GetPropertyDisplayName() +} + +// GetServiceState returns the state of a service. +func GetServiceState(service ServiceInterface) (string, error) { + return service.GetPropertyState() +} + +// GetServiceStartMode returns the start mode of a service. +func GetServiceStartMode(service ServiceInterface) (string, error) { + return service.GetPropertyStartMode() +} + +// Win32Service wraps the WMI class Win32_Service (mainly for testing) +type Win32Service struct { + *cimv2.Win32_Service +} + +func (s *Win32Service) GetDependents() ([]ServiceInterface, error) { + collection, err := s.GetAssociated("Win32_DependentService", "Win32_Service", "Dependent", "Antecedent") + if err != nil { + return nil, err + } + + var result []ServiceInterface + for _, coll := range collection { + service, err := cimv2.NewWin32_ServiceEx1(coll) + if err != nil { + return nil, err + } + + result = append(result, &Win32Service{ + service, + }) + } + return result, nil +} + +type Win32ServiceFactory struct { +} + +func (impl Win32ServiceFactory) GetService(name string) (ServiceInterface, error) { + service, err := QueryServiceByName(name, ServiceSelectorList) + if err != nil { + return nil, err + } + + return &Win32Service{Win32_Service: service}, nil +} diff --git a/pkg/cim/volume.go b/pkg/cim/volume.go index 085289ba..0c880fe1 100644 --- a/pkg/cim/volume.go +++ b/pkg/cim/volume.go @@ -7,10 +7,23 @@ import ( "fmt" "strconv" + "github.com/go-ole/go-ole" "github.com/microsoft/wmi/pkg/base/query" "github.com/microsoft/wmi/pkg/errors" - cim "github.com/microsoft/wmi/pkg/wmiinstance" "github.com/microsoft/wmi/server2019/root/microsoft/windows/storage" + "k8s.io/klog/v2" +) + +const ( + FileSystemUnknown = 0 +) + +var ( + VolumeSelectorListForFileSystemType = []string{"FileSystemType"} + VolumeSelectorListForStats = []string{"UniqueId", "SizeRemaining", "Size"} + VolumeSelectorListUniqueID = []string{"UniqueId"} + + PartitionSelectorListObjectID = []string{"ObjectId"} ) // QueryVolumeByUniqueID retrieves a specific volume by its unique identifier, @@ -79,6 +92,68 @@ func ListVolumes(selectorList []string) ([]*storage.MSFT_Volume, error) { return volumes, nil } +// FormatVolume formats the specified volume. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/format-msft-volume +// for the WMI method definition. +func FormatVolume(volume *storage.MSFT_Volume, params ...interface{}) (int, error) { + result, err := volume.InvokeMethodWithReturn("Format", params...) + return int(result), err +} + +// FlushVolume flushes the cached data in the volume's file system to disk. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-volume-flush +// for the WMI method definition. +func FlushVolume(volume *storage.MSFT_Volume) (int, error) { + result, err := volume.Flush() + return int(result), err +} + +// GetVolumeUniqueID returns the unique ID (object ID) of a volume. +func GetVolumeUniqueID(volume *storage.MSFT_Volume) (string, error) { + return volume.GetPropertyUniqueId() +} + +// GetVolumeFileSystemType returns the file system type of a volume. +func GetVolumeFileSystemType(volume *storage.MSFT_Volume) (int32, error) { + fsType, err := volume.GetProperty("FileSystemType") + if err != nil { + return 0, err + } + return fsType.(int32), nil +} + +// GetVolumeSize returns the size of a volume. +func GetVolumeSize(volume *storage.MSFT_Volume) (int64, error) { + volumeSizeVal, err := volume.GetProperty("Size") + if err != nil { + return -1, err + } + + volumeSize, err := strconv.ParseInt(volumeSizeVal.(string), 10, 64) + if err != nil { + return -1, err + } + + return volumeSize, err +} + +// GetVolumeSizeRemaining returns the remaining size of a volume. +func GetVolumeSizeRemaining(volume *storage.MSFT_Volume) (int64, error) { + volumeSizeRemainingVal, err := volume.GetProperty("SizeRemaining") + if err != nil { + return -1, err + } + + volumeSizeRemaining, err := strconv.ParseInt(volumeSizeRemainingVal.(string), 10, 64) + if err != nil { + return -1, err + } + + return volumeSizeRemaining, err +} + // ListPartitionsOnDisk retrieves all partitions or a partition with the specified number on a disk. // // The equivalent WMI query is: @@ -129,131 +204,84 @@ func ListPartitionsWithFilters(selectorList []string, filters ...*query.WmiQuery return partitions, nil } -// ListPartitionToVolumeMappings builds a mapping between partition and volume with partition Object ID as the key. -// -// The equivalent WMI query is: -// -// SELECT [selectors] FROM MSFT_PartitionToVolume -// -// Partition | Volume -// --------- | ------ -// MSFT_Partition (ObjectId = "{1}\\WIN-8E2EVAQ9QSB\ROOT/Microsoft/Win...) | MSFT_Volume (ObjectId = "{1}\\WIN-8E2EVAQ9QS... -// -// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partitiontovolume -// for the WMI class definition. -func ListPartitionToVolumeMappings() (map[string]string, error) { - return ListWMIInstanceMappings(WMINamespaceStorage, "MSFT_PartitionToVolume", nil, - mappingObjectRefIndexer("Partition", "MSFT_Partition", "ObjectId"), - mappingObjectRefIndexer("Volume", "MSFT_Volume", "ObjectId"), - ) -} - -// ListVolumeToPartitionMappings builds a mapping between volume and partition with volume Object ID as the key. -// -// The equivalent WMI query is: +// FindPartitionsByVolume finds all partitions associated with the given volumes +// using MSFT_PartitionToVolume association. // -// SELECT [selectors] FROM MSFT_PartitionToVolume +// WMI association MSFT_PartitionToVolume: // -// Partition | Volume -// --------- | ------ -// MSFT_Partition (ObjectId = "{1}\\WIN-8E2EVAQ9QSB\ROOT/Microsoft/Win...) | MSFT_Volume (ObjectId = "{1}\\WIN-8E2EVAQ9QS... +// Partition | Volume +// --------- | ------ +// MSFT_Partition (ObjectId = "{1}\\WIN-8E2EVAQ9QSB\ROOT/Microsoft/Win...) | MSFT_Volume (ObjectId = "{1}\\WIN-8E2EVAQ9QS... // // Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partitiontovolume // for the WMI class definition. -func ListVolumeToPartitionMappings() (map[string]string, error) { - return ListWMIInstanceMappings(WMINamespaceStorage, "MSFT_PartitionToVolume", nil, - mappingObjectRefIndexer("Volume", "MSFT_Volume", "ObjectId"), - mappingObjectRefIndexer("Partition", "MSFT_Partition", "ObjectId"), - ) -} - -// FindPartitionsByVolume finds all partitions associated with the given volumes -// using partition-to-volume mapping. -func FindPartitionsByVolume(partitions []*storage.MSFT_Partition, volumes []*storage.MSFT_Volume) ([]*storage.MSFT_Partition, error) { - var partitionInstances []*cim.WmiInstance - for _, part := range partitions { - partitionInstances = append(partitionInstances, part.WmiInstance) - } - - var volumeInstances []*cim.WmiInstance - for _, volume := range volumes { - volumeInstances = append(volumeInstances, volume.WmiInstance) - } - - partitionToVolumeMappings, err := ListPartitionToVolumeMappings() - if err != nil { - return nil, err - } - - filtered, err := FindInstancesByObjectIDMapping(partitionInstances, volumeInstances, partitionToVolumeMappings) - if err != nil { - return nil, err - } - +func FindPartitionsByVolume(volumes []*storage.MSFT_Volume) ([]*storage.MSFT_Partition, error) { var result []*storage.MSFT_Partition - for _, instance := range filtered { - part, err := storage.NewMSFT_PartitionEx1(instance) + for _, vol := range volumes { + collection, err := vol.GetAssociated("MSFT_PartitionToVolume", "MSFT_Partition", "Partition", "Volume") if err != nil { - return nil, fmt.Errorf("failed to query partition %v. error: %v", instance, err) + return nil, fmt.Errorf("failed to query associated partition for %v. error: %v", vol, err) } - result = append(result, part) + for _, instance := range collection { + part, err := storage.NewMSFT_PartitionEx1(instance) + if err != nil { + return nil, fmt.Errorf("failed to query partition %v. error: %v", instance, err) + } + + result = append(result, part) + } } return result, nil } // FindVolumesByPartition finds all volumes associated with the given partitions -// using volume-to-partition mapping. -func FindVolumesByPartition(volumes []*storage.MSFT_Volume, partitions []*storage.MSFT_Partition) ([]*storage.MSFT_Volume, error) { - var volumeInstances []*cim.WmiInstance - for _, volume := range volumes { - volumeInstances = append(volumeInstances, volume.WmiInstance) - } - - var partitionInstances []*cim.WmiInstance - for _, part := range partitions { - partitionInstances = append(partitionInstances, part.WmiInstance) - } - - volumeToPartitionMappings, err := ListVolumeToPartitionMappings() - if err != nil { - return nil, err - } - - filtered, err := FindInstancesByObjectIDMapping(volumeInstances, partitionInstances, volumeToPartitionMappings) - if err != nil { - return nil, err - } - +// using MSFT_PartitionToVolume association. +// +// WMI association MSFT_PartitionToVolume: +// +// Partition | Volume +// --------- | ------ +// MSFT_Partition (ObjectId = "{1}\\WIN-8E2EVAQ9QSB\ROOT/Microsoft/Win...) | MSFT_Volume (ObjectId = "{1}\\WIN-8E2EVAQ9QS... +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partitiontovolume +// for the WMI class definition. +func FindVolumesByPartition(partitions []*storage.MSFT_Partition) ([]*storage.MSFT_Volume, error) { var result []*storage.MSFT_Volume - for _, instance := range filtered { - volume, err := storage.NewMSFT_VolumeEx1(instance) + for _, part := range partitions { + collection, err := part.GetAssociated("MSFT_PartitionToVolume", "MSFT_Volume", "Volume", "Partition") if err != nil { - return nil, fmt.Errorf("failed to query volume %v. error: %v", instance, err) + return nil, fmt.Errorf("failed to query associated volumes for %v. error: %v", part, err) } - result = append(result, volume) + for _, instance := range collection { + volume, err := storage.NewMSFT_VolumeEx1(instance) + if err != nil { + return nil, fmt.Errorf("failed to query volume %v. error: %v", instance, err) + } + + result = append(result, volume) + } } return result, nil } // GetPartitionByVolumeUniqueID retrieves a specific partition from a volume identified by its unique ID. -func GetPartitionByVolumeUniqueID(volumeID string, partitionSelectorList []string) (*storage.MSFT_Partition, error) { +func GetPartitionByVolumeUniqueID(volumeID string) (*storage.MSFT_Partition, error) { volume, err := QueryVolumeByUniqueID(volumeID, []string{"ObjectId"}) if err != nil { return nil, err } - partitions, err := ListPartitionsWithFilters(partitionSelectorList) + result, err := FindPartitionsByVolume([]*storage.MSFT_Volume{volume}) if err != nil { return nil, err } - result, err := FindPartitionsByVolume(partitions, []*storage.MSFT_Volume{volume}) - if err != nil { - return nil, err + if len(result) == 0 { + return nil, errors.NotFound } return result[0], nil @@ -269,12 +297,7 @@ func GetVolumeByDriveLetter(driveLetter string, partitionSelectorList []string) return nil, err } - volumes, err := ListVolumes(partitionSelectorList) - if err != nil { - return nil, err - } - - result, err := FindVolumesByPartition(volumes, partitions) + result, err := FindVolumesByPartition(partitions) if err != nil { return nil, err } @@ -298,3 +321,78 @@ func GetPartitionDiskNumber(part *storage.MSFT_Partition) (uint32, error) { return uint32(diskNumber.(int32)), nil } + +// SetPartitionState takes a partition online or offline. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partition-online and +// https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partition-offline +// for the WMI method definition. +func SetPartitionState(part *storage.MSFT_Partition, online bool) (int, string, error) { + method := "Offline" + if online { + method = "Online" + } + + var status string + result, err := part.InvokeMethodWithReturn(method, &status) + return int(result), status, err +} + +// GetPartitionSupportedSize retrieves the minimum and maximum sizes that the partition can be resized to using the ResizePartition method. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partition-getsupportedsizes +// for the WMI method definition. +func GetPartitionSupportedSize(part *storage.MSFT_Partition) (result int, sizeMin, sizeMax int64, status string, err error) { + sizeMin = -1 + sizeMax = -1 + + var sizeMinVar, sizeMaxVar ole.VARIANT + invokeResult, err := part.InvokeMethodWithReturn("GetSupportedSize", &sizeMinVar, &sizeMaxVar, &status) + if invokeResult != 0 || err != nil { + result = int(invokeResult) + } + klog.V(5).Infof("got sizeMin (%v) sizeMax (%v) from partition (%v), status: %s", sizeMinVar, sizeMaxVar, part, status) + + sizeMin, err = strconv.ParseInt(sizeMinVar.ToString(), 10, 64) + if err != nil { + return + } + + sizeMax, err = strconv.ParseInt(sizeMaxVar.ToString(), 10, 64) + return +} + +// ResizePartition resizes a partition. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partition-resize +// for the WMI method definition. +func ResizePartition(part *storage.MSFT_Partition, size int64) (int, string, error) { + var status string + result, err := part.InvokeMethodWithReturn("Resize", strconv.Itoa(int(size)), &status) + return int(result), status, err +} + +// GetPartitionSize returns the size of a partition. +func GetPartitionSize(part *storage.MSFT_Partition) (int64, error) { + sizeProp, err := part.GetProperty("Size") + if err != nil { + return -1, err + } + + size, err := strconv.ParseInt(sizeProp.(string), 10, 64) + if err != nil { + return -1, err + } + + return size, err +} + +// FilterForPartitionOnDisk creates a WMI query filter to query a disk by its number. +func FilterForPartitionOnDisk(diskNumber uint32) *query.WmiQueryFilter { + return query.NewWmiQueryFilter("DiskNumber", strconv.Itoa(int(diskNumber)), query.Equals) +} + +// FilterForPartitionsOfTypeNormal creates a WMI query filter for all non-reserved partitions. +func FilterForPartitionsOfTypeNormal() *query.WmiQueryFilter { + return query.NewWmiQueryFilter("GptType", GPTPartitionTypeMicrosoftReserved, query.NotEquals) +} diff --git a/pkg/cim/wmi.go b/pkg/cim/wmi.go index 81e17701..ec9c8f08 100644 --- a/pkg/cim/wmi.go +++ b/pkg/cim/wmi.go @@ -4,33 +4,32 @@ package cim import ( + "errors" "fmt" - "strings" + "runtime" "github.com/go-ole/go-ole" "github.com/go-ole/go-ole/oleutil" "github.com/microsoft/wmi/pkg/base/query" - "github.com/microsoft/wmi/pkg/errors" + wmierrors "github.com/microsoft/wmi/pkg/errors" cim "github.com/microsoft/wmi/pkg/wmiinstance" + "golang.org/x/sys/windows" "k8s.io/klog/v2" ) const ( - WMINamespaceRoot = "Root\\CimV2" + WMINamespaceCimV2 = "Root\\CimV2" WMINamespaceStorage = "Root\\Microsoft\\Windows\\Storage" WMINamespaceSmb = "Root\\Microsoft\\Windows\\Smb" ) type InstanceHandler func(instance *cim.WmiInstance) (bool, error) -// An InstanceIndexer provides index key to a WMI Instance in a map -type InstanceIndexer func(instance *cim.WmiInstance) (string, error) - // NewWMISession creates a new local WMI session for the given namespace, defaulting // to root namespace if none specified. func NewWMISession(namespace string) (*cim.WmiSession, error) { if namespace == "" { - namespace = WMINamespaceRoot + namespace = WMINamespaceCimV2 } sessionManager := cim.NewWmiSessionManager() @@ -65,7 +64,7 @@ func QueryFromWMI(namespace string, query *query.WmiQuery, handler InstanceHandl } if len(instances) == 0 { - return errors.NotFound + return wmierrors.NotFound } var cont bool @@ -99,7 +98,7 @@ func executeClassMethodParam(classInst *cim.WmiInstance, method *cim.WmiMethod, iDispatchInstance := classInst.GetIDispatch() if iDispatchInstance == nil { - return nil, errors.Wrapf(errors.InvalidInput, "InvalidInstance") + return nil, wmierrors.Wrapf(wmierrors.InvalidInput, "InvalidInstance") } rawResult, err := iDispatchInstance.GetProperty("Methods_") if err != nil { @@ -239,130 +238,52 @@ func InvokeCimMethod(namespace, class, methodName string, inputParameters map[st return int(result.ReturnValue), outputParameters, nil } +// IsNotFound returns true if it's a "not found" error. +func IsNotFound(err error) bool { + return wmierrors.IsNotFound(err) +} + // IgnoreNotFound returns nil if the error is nil or a "not found" error, // otherwise returns the original error. func IgnoreNotFound(err error) error { - if err == nil || errors.IsNotFound(err) { + if err == nil || IsNotFound(err) { return nil } return err } -// parseObjectRef extracts the object ID from a WMI object reference string. -// The result string is in this format -// {1}\\WIN-8E2EVAQ9QSB\ROOT/Microsoft/Windows/Storage/Providers_v2\WSP_Partition.ObjectId="{b65bb3cd-da86-11ee-854b-806e6f6e6963}:PR:{00000000-0000-0000-0000-100000000000}\\?\scsi#disk&ven_vmware&prod_virtual_disk#4&2c28f6c4&0&000000#{53f56307-b6bf-11d0-94f2-00a0c91efb8b}" -// from an escape string -func parseObjectRef(input, objectClass, refName string) (string, error) { - tokens := strings.Split(input, fmt.Sprintf("%s.%s=", objectClass, refName)) - if len(tokens) < 2 { - return "", fmt.Errorf("invalid object ID value: %s", input) - } - - objectID := tokens[1] - objectID = strings.ReplaceAll(objectID, "\\\"", "\"") - objectID = strings.ReplaceAll(objectID, "\\\\", "\\") - objectID = objectID[1 : len(objectID)-1] - return objectID, nil -} - -// ListWMIInstanceMappings queries WMI instances and creates a map using custom indexing functions -// to extract keys and values from each instance. -func ListWMIInstanceMappings(namespace, mappingClassName string, selectorList []string, keyIndexer InstanceIndexer, valueIndexer InstanceIndexer) (map[string]string, error) { - q := query.NewWmiQueryWithSelectList(mappingClassName, selectorList) - mappingInstances, err := QueryInstances(namespace, q) - if err != nil { - return nil, err - } - - result := make(map[string]string) - for _, mapping := range mappingInstances { - key, err := keyIndexer(mapping) - if err != nil { - return nil, err - } - - value, err := valueIndexer(mapping) - if err != nil { - return nil, err - } - - result[key] = value - } - - return result, nil -} - -// FindInstancesByMapping filters instances based on a mapping relationship, -// matching instances through custom indexing and mapping functions. -func FindInstancesByMapping(instanceToFind []*cim.WmiInstance, instanceToFindIndex InstanceIndexer, associatedInstances []*cim.WmiInstance, associatedInstanceIndexer InstanceIndexer, instanceMappings map[string]string) ([]*cim.WmiInstance, error) { - associatedInstanceObjectIDMapping := map[string]*cim.WmiInstance{} - for _, inst := range associatedInstances { - key, err := associatedInstanceIndexer(inst) - if err != nil { - return nil, err - } - - associatedInstanceObjectIDMapping[key] = inst - } - - var filtered []*cim.WmiInstance - for _, inst := range instanceToFind { - key, err := instanceToFindIndex(inst) - if err != nil { - return nil, err - } - - valueObjectID, ok := instanceMappings[key] - if !ok { - continue - } - - _, ok = associatedInstanceObjectIDMapping[strings.ToUpper(valueObjectID)] - if !ok { - continue - } - filtered = append(filtered, inst) - } - - if len(filtered) == 0 { - return nil, errors.NotFound - } - - return filtered, nil -} - -// mappingObjectRefIndexer indexes an WMI object by the Object ID reference from a specified property. -func mappingObjectRefIndexer(propertyName, className, refName string) InstanceIndexer { - return func(instance *cim.WmiInstance) (string, error) { - valueVal, err := instance.GetProperty(propertyName) - if err != nil { - return "", err +// WithCOMThread runs the given function `fn` on a locked OS thread +// with COM initialized using COINIT_MULTITHREADED. +// +// This is necessary for using COM/OLE APIs directly (e.g., via go-ole), +// because COM requires that initialization and usage occur on the same thread. +// +// It performs the following steps: +// - Locks the current goroutine to its OS thread +// - Calls ole.CoInitializeEx with COINIT_MULTITHREADED +// - Executes the user-provided function +// - Uninitializes COM +// - Unlocks the thread +// +// If COM initialization fails, or if the user's function returns an error, +// that error is returned by WithCOMThread. +func WithCOMThread(fn func() error) error { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + if err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED); err != nil { + var oleError *ole.OleError + if errors.As(err, &oleError) && oleError != nil && oleError.Code() == uintptr(windows.S_FALSE) { + klog.V(10).Infof("COM library has been already initialized for the calling thread, proceeding to the function with no error") + err = nil } - - refValue, err := parseObjectRef(valueVal.(string), className, refName) - return strings.ToUpper(refValue), err - } -} - -// stringPropertyIndexer indexes a WMI object from a string property. -func stringPropertyIndexer(propertyName string) InstanceIndexer { - return func(instance *cim.WmiInstance) (string, error) { - valueVal, err := instance.GetProperty(propertyName) if err != nil { - return "", err + return err } - - return strings.ToUpper(valueVal.(string)), err + } else { + klog.V(10).Infof("COM library is initialized for the calling thread") } -} - -var ( - // objectIDPropertyIndexer indexes a WMI object from its ObjectId property. - objectIDPropertyIndexer = stringPropertyIndexer("ObjectId") -) + defer ole.CoUninitialize() -// FindInstancesByObjectIDMapping filters instances based on ObjectId mapping -// between two sets of WMI instances. -func FindInstancesByObjectIDMapping(instanceToFind []*cim.WmiInstance, associatedInstances []*cim.WmiInstance, instanceMappings map[string]string) ([]*cim.WmiInstance, error) { - return FindInstancesByMapping(instanceToFind, objectIDPropertyIndexer, associatedInstances, objectIDPropertyIndexer, instanceMappings) + return fn() } diff --git a/pkg/os/disk/api.go b/pkg/os/disk/api.go index dc8637fd..5b46992c 100644 --- a/pkg/os/disk/api.go +++ b/pkg/os/disk/api.go @@ -3,14 +3,12 @@ package disk import ( "encoding/hex" "fmt" - "strconv" "strings" "syscall" "unsafe" "github.com/kubernetes-csi/csi-proxy/pkg/cim" shared "github.com/kubernetes-csi/csi-proxy/pkg/shared/disk" - "github.com/microsoft/wmi/pkg/base/query" "k8s.io/klog/v2" ) @@ -66,153 +64,162 @@ func New() DiskAPI { // ListDiskLocations - constructs a map with the disk number as the key and the DiskLocation structure // as the value. The DiskLocation struct has various fields like the Adapter, Bus, Target and LUNID. func (imp DiskAPI) ListDiskLocations() (map[uint32]shared.DiskLocation, error) { - // "location": "PCI Slot 3 : Adapter 0 : Port 0 : Target 1 : LUN 0" - disks, err := cim.ListDisks([]string{"Number", "Location"}) - if err != nil { - return nil, fmt.Errorf("could not query disk locations") - } - m := make(map[uint32]shared.DiskLocation) - for _, disk := range disks { - num, err := disk.GetProperty("Number") + err := cim.WithCOMThread(func() error { + // "location": "PCI Slot 3 : Adapter 0 : Port 0 : Target 1 : LUN 0" + disks, err := cim.ListDisks(cim.DiskSelectorListForDiskNumberAndLocation) if err != nil { - return m, fmt.Errorf("failed to query disk number: %v, %w", disk, err) + return fmt.Errorf("could not query disk locations") } - location, err := disk.GetPropertyLocation() - if err != nil { - return m, fmt.Errorf("failed to query disk location: %v, %w", disk, err) - } + for _, disk := range disks { + num, err := cim.GetDiskNumber(disk) + if err != nil { + return fmt.Errorf("failed to query disk number: %v, %w", disk, err) + } + + location, err := cim.GetDiskLocation(disk) + if err != nil { + return fmt.Errorf("failed to query disk location: %v, %w", disk, err) + } - found := false - s := strings.Split(location, ":") - if len(s) >= 5 { - var d shared.DiskLocation - for _, item := range s { - item = strings.TrimSpace(item) - itemSplit := strings.Split(item, " ") - if len(itemSplit) == 2 { - found = true - switch strings.TrimSpace(itemSplit[0]) { - case "Adapter": - d.Adapter = strings.TrimSpace(itemSplit[1]) - case "Target": - d.Target = strings.TrimSpace(itemSplit[1]) - case "LUN": - d.LUNID = strings.TrimSpace(itemSplit[1]) - default: - klog.Warningf("Got unknown field : %s=%s", itemSplit[0], itemSplit[1]) + found := false + s := strings.Split(location, ":") + if len(s) >= 5 { + var d shared.DiskLocation + for _, item := range s { + item = strings.TrimSpace(item) + itemSplit := strings.Split(item, " ") + if len(itemSplit) == 2 { + found = true + switch strings.TrimSpace(itemSplit[0]) { + case "Adapter": + d.Adapter = strings.TrimSpace(itemSplit[1]) + case "Target": + d.Target = strings.TrimSpace(itemSplit[1]) + case "LUN": + d.LUNID = strings.TrimSpace(itemSplit[1]) + default: + klog.Warningf("Got unknown field : %s=%s", itemSplit[0], itemSplit[1]) + } } } - } - if found { - m[uint32(num.(int32))] = d + if found { + m[num] = d + } } } - } - return m, nil + return nil + }) + return m, err } func (imp DiskAPI) Rescan() error { - result, _, err := cim.InvokeCimMethod(cim.WMINamespaceStorage, "MSFT_StorageSetting", "UpdateHostStorageCache", nil) - if err != nil { - return fmt.Errorf("error updating host storage cache output. result: %d, err: %v", result, err) - } - return nil + return cim.WithCOMThread(func() error { + result, err := cim.RescanDisks() + if err != nil { + return fmt.Errorf("error updating host storage cache output. result: %d, err: %v", result, err) + } + return nil + }) } func (imp DiskAPI) IsDiskInitialized(diskNumber uint32) (bool, error) { var partitionStyle int32 - disk, err := cim.QueryDiskByNumber(diskNumber, []string{"PartitionStyle"}) - if err != nil { - return false, fmt.Errorf("error checking initialized status of disk %d. %v", diskNumber, err) - } + err := cim.WithCOMThread(func() error { + disk, err := cim.QueryDiskByNumber(diskNumber, cim.DiskSelectorListForPartitionStyle) + if err != nil { + return fmt.Errorf("error checking initialized status of disk %d: %v", diskNumber, err) + } - retValue, err := disk.GetProperty("PartitionStyle") - if err != nil { - return false, fmt.Errorf("failed to query partition style of disk %d: %w", diskNumber, err) - } + partitionStyle, err = cim.GetDiskPartitionStyle(disk) + if err != nil { + return fmt.Errorf("failed to query partition style of disk %d: %v", diskNumber, err) + } - partitionStyle = retValue.(int32) - return partitionStyle != cim.PartitionStyleUnknown, nil + return nil + }) + return partitionStyle != cim.PartitionStyleUnknown, err } func (imp DiskAPI) InitializeDisk(diskNumber uint32) error { - disk, err := cim.QueryDiskByNumber(diskNumber, nil) - if err != nil { - return fmt.Errorf("failed to initializing disk %d. error: %w", diskNumber, err) - } + return cim.WithCOMThread(func() error { + disk, err := cim.QueryDiskByNumber(diskNumber, nil) + if err != nil { + return fmt.Errorf("failed to initializing disk %d. error: %w", diskNumber, err) + } - result, err := disk.InvokeMethodWithReturn("Initialize", int32(cim.PartitionStyleGPT)) - if result != 0 || err != nil { - return fmt.Errorf("failed to initializing disk %d: result %d, error: %w", diskNumber, result, err) - } + result, err := cim.InitializeDisk(disk, cim.PartitionStyleGPT) + if result != 0 || err != nil { + return fmt.Errorf("failed to initializing disk %d: result %d, error: %w", diskNumber, result, err) + } - return nil + return nil + }) } func (imp DiskAPI) BasicPartitionsExist(diskNumber uint32) (bool, error) { - partitions, err := cim.ListPartitionsWithFilters(nil, - query.NewWmiQueryFilter("DiskNumber", strconv.Itoa(int(diskNumber)), query.Equals), - query.NewWmiQueryFilter("GptType", cim.GPTPartitionTypeMicrosoftReserved, query.NotEquals)) - if cim.IgnoreNotFound(err) != nil { - return false, fmt.Errorf("error checking presence of partitions on disk %d:, %v", diskNumber, err) - } + var exist bool + err := cim.WithCOMThread(func() error { + partitions, err := cim.ListPartitionsWithFilters(nil, cim.FilterForPartitionOnDisk(diskNumber), cim.FilterForPartitionsOfTypeNormal()) + if cim.IgnoreNotFound(err) != nil { + return fmt.Errorf("error checking presence of partitions on disk %d:, %v", diskNumber, err) + } - return len(partitions) > 0, nil + exist = len(partitions) > 0 + return nil + }) + return exist, err } func (imp DiskAPI) CreateBasicPartition(diskNumber uint32) error { - disk, err := cim.QueryDiskByNumber(diskNumber, nil) - if err != nil { - return err - } + return cim.WithCOMThread(func() error { + disk, err := cim.QueryDiskByNumber(diskNumber, nil) + if err != nil { + return err + } - result, err := disk.InvokeMethodWithReturn( - "CreatePartition", - nil, // Size - true, // UseMaximumSize - nil, // Offset - nil, // Alignment - nil, // DriveLetter - false, // AssignDriveLetter - nil, // MbrType, - cim.GPTPartitionTypeBasicData, // GPT Type - false, // IsHidden - false, // IsActive, - ) - // 42002 is returned by driver letter failed to assign after partition - if (result != 0 && result != 42002) || err != nil { - return fmt.Errorf("error creating partition on disk %d. result: %d, err: %v", diskNumber, result, err) - } + result, err := cim.CreatePartition( + disk, + nil, // Size + true, // UseMaximumSize + nil, // Offset + nil, // Alignment + nil, // DriveLetter + false, // AssignDriveLetter + nil, // MbrType, + cim.GPTPartitionTypeBasicData, // GPT Type + false, // IsHidden + false, // IsActive, + ) + if (result != 0 && result != cim.ErrorCodeCreatePartitionAccessPathAlreadyInUse) || err != nil { + return fmt.Errorf("error creating partition on disk %d. result: %d, err: %v", diskNumber, result, err) + } - var status string - result, err = disk.InvokeMethodWithReturn("Refresh", &status) - if result != 0 || err != nil { - return fmt.Errorf("error rescan disk (%d). result %d, error: %v", diskNumber, result, err) - } + result, _, err = cim.RefreshDisk(disk) + if result != 0 || err != nil { + return fmt.Errorf("error rescan disk (%d). result %d, error: %v", diskNumber, result, err) + } - partitions, err := cim.ListPartitionsWithFilters(nil, - query.NewWmiQueryFilter("DiskNumber", strconv.Itoa(int(diskNumber)), query.Equals), - query.NewWmiQueryFilter("GptType", cim.GPTPartitionTypeMicrosoftReserved, query.NotEquals)) - if err != nil { - return fmt.Errorf("error query basic partition on disk %d:, %v", diskNumber, err) - } + partitions, err := cim.ListPartitionsWithFilters(nil, cim.FilterForPartitionOnDisk(diskNumber), cim.FilterForPartitionsOfTypeNormal()) + if err != nil { + return fmt.Errorf("error query basic partition on disk %d:, %v", diskNumber, err) + } - if len(partitions) == 0 { - return fmt.Errorf("failed to create basic partition on disk %d:, %v", diskNumber, err) - } + if len(partitions) == 0 { + return fmt.Errorf("failed to create basic partition on disk %d:, %v", diskNumber, err) + } - partition := partitions[0] - result, err = partition.InvokeMethodWithReturn("Online", status) - if result != 0 || err != nil { - return fmt.Errorf("error bring partition %v on disk %d online. result: %d, status %s, err: %v", partition, diskNumber, result, status, err) - } + partition := partitions[0] + result, status, err := cim.SetPartitionState(partition, true) + if result != 0 || err != nil { + return fmt.Errorf("error bring partition %v on disk %d online. result: %d, status %s, err: %v", partition, diskNumber, result, status, err) + } - err = partition.Refresh() - return err + return nil + }) } func (imp DiskAPI) GetDiskNumberByName(page83ID string) (uint32, error) { @@ -272,28 +279,33 @@ func (imp DiskAPI) GetDiskPage83ID(disk syscall.Handle) (string, error) { } func (imp DiskAPI) GetDiskNumberWithID(page83ID string) (uint32, error) { - disks, err := cim.ListDisks([]string{"Path", "SerialNumber"}) - if err != nil { - return 0, err - } - - for _, disk := range disks { - path, err := disk.GetPropertyPath() + var diskNumberResult uint32 + err := cim.WithCOMThread(func() error { + disks, err := cim.ListDisks(cim.DiskSelectorListForPathAndSerialNumber) if err != nil { - return 0, fmt.Errorf("failed to query disk path: %v, %w", disk, err) + return err } - diskNumber, diskPage83ID, err := imp.GetDiskNumberAndPage83ID(path) - if err != nil { - return 0, err - } + for _, disk := range disks { + path, err := cim.GetDiskPath(disk) + if err != nil { + return fmt.Errorf("failed to query disk path: %v, %w", disk, err) + } - if diskPage83ID == page83ID { - return diskNumber, nil + diskNumber, diskPage83ID, err := imp.GetDiskNumberAndPage83ID(path) + if err != nil { + return err + } + + if diskPage83ID == page83ID { + diskNumberResult = diskNumber + return nil + } } - } - return 0, fmt.Errorf("could not find disk with Page83 ID %s", page83ID) + return fmt.Errorf("could not find disk with Page83 ID %s", page83ID) + }) + return diskNumberResult, err } func (imp DiskAPI) GetDiskNumberAndPage83ID(path string) (uint32, string, error) { @@ -319,91 +331,98 @@ func (imp DiskAPI) GetDiskNumberAndPage83ID(path string) (uint32, string, error) // ListDiskIDs - constructs a map with the disk number as the key and the DiskID structure // as the value. The DiskID struct has a field for the page83 ID. func (imp DiskAPI) ListDiskIDs() (map[uint32]shared.DiskIDs, error) { - disks, err := cim.ListDisks([]string{"Path", "SerialNumber"}) - if err != nil { - return nil, err - } - m := make(map[uint32]shared.DiskIDs) - for _, disk := range disks { - path, err := disk.GetPropertyPath() + err := cim.WithCOMThread(func() error { + disks, err := cim.ListDisks(cim.DiskSelectorListForPathAndSerialNumber) if err != nil { - return m, fmt.Errorf("failed to query disk path: %v, %w", disk, err) + return err } - sn, err := disk.GetPropertySerialNumber() - if err != nil { - return m, fmt.Errorf("failed to query disk serial number: %v, %w", disk, err) - } + for _, disk := range disks { + path, err := cim.GetDiskPath(disk) + if err != nil { + return fmt.Errorf("failed to query disk path: %v, %w", disk, err) + } - diskNumber, page83, err := imp.GetDiskNumberAndPage83ID(path) - if err != nil { - return m, err - } + sn, err := cim.GetDiskSerialNumber(disk) + if err != nil { + return fmt.Errorf("failed to query disk serial number: %v, %w", disk, err) + } + + diskNumber, page83, err := imp.GetDiskNumberAndPage83ID(path) + if err != nil { + return err + } - m[diskNumber] = shared.DiskIDs{ - Page83: page83, - SerialNumber: sn, + m[diskNumber] = shared.DiskIDs{ + Page83: page83, + SerialNumber: sn, + } } - } - return m, nil + + return nil + }) + return m, err } func (imp DiskAPI) GetDiskStats(diskNumber uint32) (int64, error) { // TODO: change to uint64 as it does not make sense to use int64 for size - var size int64 - disk, err := cim.QueryDiskByNumber(diskNumber, []string{"Size"}) - if err != nil { - return -1, err - } + size := int64(-1) + err := cim.WithCOMThread(func() error { + disk, err := cim.QueryDiskByNumber(diskNumber, cim.DiskSelectorListForSize) + if err != nil { + return err + } - sz, err := disk.GetProperty("Size") - if err != nil { - return -1, fmt.Errorf("failed to query size of disk %d. %v", diskNumber, err) - } + size, err = cim.GetDiskSize(disk) + if err != nil { + return fmt.Errorf("failed to query size of disk %d. %v", diskNumber, err) + } - size, err = strconv.ParseInt(sz.(string), 10, 64) + return nil + }) return size, err } func (imp DiskAPI) SetDiskState(diskNumber uint32, isOnline bool) error { - disk, err := cim.QueryDiskByNumber(diskNumber, []string{"IsOffline"}) - if err != nil { - return err - } - - offline, err := disk.GetPropertyIsOffline() - if err != nil { - return fmt.Errorf("error setting disk %d attach state. error: %v", diskNumber, err) - } + return cim.WithCOMThread(func() error { + disk, err := cim.QueryDiskByNumber(diskNumber, cim.DiskSelectorListForIsOffline) + if err != nil { + return err + } - if isOnline == !offline { - return nil - } + isOffline, err := cim.IsDiskOffline(disk) + if err != nil { + return fmt.Errorf("error setting disk %d attach state. error: %v", diskNumber, err) + } - method := "Offline" - if isOnline { - method = "Online" - } + if isOnline == !isOffline { + return nil + } - result, err := disk.InvokeMethodWithReturn(method) - if result != 0 || err != nil { - return fmt.Errorf("setting disk %d attach state %s: result %d, error: %w", diskNumber, method, result, err) - } + result, _, err := cim.SetDiskState(disk, isOnline) + if result != 0 || err != nil { + return fmt.Errorf("setting disk %d attach state (isOnline: %v): result %d, error: %w", diskNumber, isOnline, result, err) + } - return nil + return nil + }) } func (imp DiskAPI) GetDiskState(diskNumber uint32) (bool, error) { - disk, err := cim.QueryDiskByNumber(diskNumber, []string{"IsOffline"}) - if err != nil { - return false, err - } + var isOffline bool + err := cim.WithCOMThread(func() error { + disk, err := cim.QueryDiskByNumber(diskNumber, cim.DiskSelectorListForIsOffline) + if err != nil { + return err + } - isOffline, err := disk.GetPropertyIsOffline() - if err != nil { - return false, fmt.Errorf("error parsing disk %d state. error: %v", diskNumber, err) - } + isOffline, err = cim.IsDiskOffline(disk) + if err != nil { + return fmt.Errorf("error parsing disk %d state. error: %v", diskNumber, err) + } - return !isOffline, nil + return nil + }) + return !isOffline, err } diff --git a/pkg/os/filesystem/api.go b/pkg/os/filesystem/api.go index f93fb2e9..458fd89f 100644 --- a/pkg/os/filesystem/api.go +++ b/pkg/os/filesystem/api.go @@ -4,7 +4,6 @@ import ( "fmt" "os" "path/filepath" - "strings" "github.com/kubernetes-csi/csi-proxy/pkg/utils" ) @@ -49,17 +48,6 @@ func (filesystemAPI) PathExists(path string) (bool, error) { return pathExists(path) } -func pathValid(path string) (bool, error) { - cmd := `Test-Path $Env:remotepath` - cmdEnv := fmt.Sprintf("remotepath=%s", path) - output, err := utils.RunPowershellCmd(cmd, cmdEnv) - if err != nil { - return false, fmt.Errorf("returned output: %s, error: %v", string(output), err) - } - - return strings.HasPrefix(strings.ToLower(string(output)), "true"), nil -} - // PathValid determines whether all elements of a path exist // // https://docs.microsoft.com/en-us/powershell/module/microsoft.powershell.management/test-path?view=powershell-7 @@ -68,7 +56,7 @@ func pathValid(path string) (bool, error) { // // e.g. in a SMB server connection, if password is changed, connection will be lost, this func will return false func (filesystemAPI) PathValid(path string) (bool, error) { - return pathValid(path) + return utils.IsPathValid(path) } // Mkdir makes a dir with `os.MkdirAll`. @@ -124,13 +112,18 @@ func (filesystemAPI) IsSymlink(tgt string) (bool, error) { // This code is similar to k8s.io/kubernetes/pkg/util/mount except the pathExists usage. // Also in a remote call environment the os error cannot be passed directly back, hence the callers // are expected to perform the isExists check before calling this call in CSI proxy. - stat, err := os.Lstat(tgt) + isSymlink, err := utils.IsPathSymlink(tgt) + if err != nil { + return false, err + } + + // mounted folder created by SetVolumeMountPoint may still report ModeSymlink == 0 + mountedFolder, err := utils.IsMountedFolder(tgt) if err != nil { return false, err } - // If its a link and it points to an existing file then its a mount point. - if stat.Mode()&os.ModeSymlink != 0 { + if isSymlink || mountedFolder { target, err := os.Readlink(tgt) if err != nil { return false, fmt.Errorf("readlink error: %v", err) diff --git a/pkg/os/iscsi/api.go b/pkg/os/iscsi/api.go index 559ed3b5..af8050e1 100644 --- a/pkg/os/iscsi/api.go +++ b/pkg/os/iscsi/api.go @@ -1,10 +1,12 @@ package iscsi import ( - "encoding/json" "fmt" + "strconv" + "strings" - "github.com/kubernetes-csi/csi-proxy/pkg/utils" + "github.com/kubernetes-csi/csi-proxy/pkg/cim" + "k8s.io/klog/v2" ) // Implements the iSCSI OS API calls. All code here should be very simple @@ -19,157 +21,209 @@ func New() APIImplementor { } func (APIImplementor) AddTargetPortal(portal *TargetPortal) error { - cmdLine := fmt.Sprintf( - `New-IscsiTargetPortal -TargetPortalAddress ${Env:iscsi_tp_address} ` + - `-TargetPortalPortNumber ${Env:iscsi_tp_port}`) - out, err := utils.RunPowershellCmd(cmdLine, fmt.Sprintf("iscsi_tp_address=%s", portal.Address), - fmt.Sprintf("iscsi_tp_port=%d", portal.Port)) - if err != nil { - return fmt.Errorf("error adding target portal. cmd %s, output: %s, err: %v", cmdLine, string(out), err) - } - - return nil + return cim.WithCOMThread(func() error { + existing, err := cim.QueryISCSITargetPortal(portal.Address, portal.Port, nil) + if cim.IgnoreNotFound(err) != nil { + return err + } + + if existing != nil { + klog.V(2).Infof("target portal at (%s:%d) already exists", portal.Address, portal.Port) + return nil + } + + _, err = cim.NewISCSITargetPortal(portal.Address, portal.Port, nil, nil, nil, nil) + if err != nil { + return fmt.Errorf("error adding target portal at (%s:%d). err: %v", portal.Address, portal.Port, err) + } + + return nil + }) } func (APIImplementor) DiscoverTargetPortal(portal *TargetPortal) ([]string, error) { - // ConvertTo-Json is not part of the pipeline because powershell converts an - // array with one element to a single element - cmdLine := fmt.Sprintf( - `ConvertTo-Json -InputObject @(Get-IscsiTargetPortal -TargetPortalAddress ` + - `${Env:iscsi_tp_address} -TargetPortalPortNumber ${Env:iscsi_tp_port} | ` + - `Get-IscsiTarget | Select-Object -ExpandProperty NodeAddress)`) - out, err := utils.RunPowershellCmd(cmdLine, fmt.Sprintf("iscsi_tp_address=%s", portal.Address), - fmt.Sprintf("iscsi_tp_port=%d", portal.Port)) - if err != nil { - return nil, fmt.Errorf("error discovering target portal. cmd: %s, output: %s, err: %w", cmdLine, string(out), err) - } - var iqns []string - err = json.Unmarshal(out, &iqns) - if err != nil { - return nil, fmt.Errorf("failed parsing iqn list. cmd: %s output: %s, err: %w", cmdLine, string(out), err) - } - - return iqns, nil + err := cim.WithCOMThread(func() error { + targets, err := cim.ListISCSITargetsByTargetPortalAddressAndPort(portal.Address, portal.Port, nil) + if err != nil { + return err + } + + for _, target := range targets { + iqn, err := cim.GetISCSITargetNodeAddress(target) + if err != nil { + return fmt.Errorf("failed parsing node address of target %v to target portal at (%s:%d). err: %w", target, portal.Address, portal.Port, err) + } + + iqns = append(iqns, iqn) + } + + return nil + }) + return iqns, err } func (APIImplementor) ListTargetPortals() ([]TargetPortal, error) { - cmdLine := fmt.Sprintf( - `ConvertTo-Json -InputObject @(Get-IscsiTargetPortal | ` + - `Select-Object TargetPortalAddress, TargetPortalPortNumber)`) - - out, err := utils.RunPowershellCmd(cmdLine) - if err != nil { - return nil, fmt.Errorf("error listing target portals. cmd %s, output: %s, err: %w", cmdLine, string(out), err) - } - var portals []TargetPortal - err = json.Unmarshal(out, &portals) - if err != nil { - return nil, fmt.Errorf("failed parsing target portal list. cmd: %s output: %s, err: %w", cmdLine, string(out), err) - } - - return portals, nil + err := cim.WithCOMThread(func() error { + instances, err := cim.ListISCSITargetPortals(cim.ISCSITargetPortalDefaultSelectorList) + if err != nil { + return err + } + + for _, instance := range instances { + address, port, err := cim.ParseISCSITargetPortal(instance) + if err != nil { + return fmt.Errorf("failed parsing target portal %v. err: %w", instance, err) + } + + portals = append(portals, TargetPortal{ + Address: address, + Port: port, + }) + } + + return nil + }) + return portals, err } func (APIImplementor) RemoveTargetPortal(portal *TargetPortal) error { - cmdLine := fmt.Sprintf( - `Get-IscsiTargetPortal -TargetPortalAddress ${Env:iscsi_tp_address} ` + - `-TargetPortalPortNumber ${Env:iscsi_tp_port} | Remove-IscsiTargetPortal ` + - `-Confirm:$false`) - - out, err := utils.RunPowershellCmd(cmdLine, fmt.Sprintf("iscsi_tp_address=%s", portal.Address), - fmt.Sprintf("iscsi_tp_port=%d", portal.Port)) - if err != nil { - return fmt.Errorf("error removing target portal. cmd %s, output: %s, err: %w", cmdLine, string(out), err) - } - - return nil + return cim.WithCOMThread(func() error { + instance, err := cim.QueryISCSITargetPortal(portal.Address, portal.Port, nil) + if err != nil { + return err + } + + result, err := cim.RemoveISCSITargetPortal(instance) + if result != 0 || err != nil { + return fmt.Errorf("error removing target portal at (%s:%d). result: %d, err: %w", portal.Address, portal.Port, result, err) + } + + return nil + }) } -func (APIImplementor) ConnectTarget(portal *TargetPortal, iqn string, - authType string, chapUser string, chapSecret string) error { - // Not using InputObject as Connect-IscsiTarget's InputObject does not work. - // This is due to being a static WMI method together with a bug in the - // powershell version of the API. - cmdLine := fmt.Sprintf( - `Connect-IscsiTarget -TargetPortalAddress ${Env:iscsi_tp_address}` + - ` -TargetPortalPortNumber ${Env:iscsi_tp_port} -NodeAddress ${Env:iscsi_target_iqn}` + - ` -AuthenticationType ${Env:iscsi_auth_type}`) - - if chapUser != "" { - cmdLine += ` -ChapUsername ${Env:iscsi_chap_user}` - } - - if chapSecret != "" { - cmdLine += ` -ChapSecret ${Env:iscsi_chap_secret}` - } - - out, err := utils.RunPowershellCmd(cmdLine, fmt.Sprintf("iscsi_tp_address=%s", portal.Address), - fmt.Sprintf("iscsi_tp_port=%d", portal.Port), - fmt.Sprintf("iscsi_target_iqn=%s", iqn), - fmt.Sprintf("iscsi_auth_type=%s", authType), - fmt.Sprintf("iscsi_chap_user=%s", chapUser), - fmt.Sprintf("iscsi_chap_secret=%s", chapSecret)) - if err != nil { - return fmt.Errorf("error connecting to target portal. cmd %s, output: %s, err: %w", cmdLine, string(out), err) - } - - return nil +func (APIImplementor) ConnectTarget(portal *TargetPortal, iqn string, authType string, chapUser string, chapSecret string) error { + return cim.WithCOMThread(func() error { + target, err := cim.QueryISCSITarget(portal.Address, portal.Port, iqn) + if err != nil { + return err + } + + connected, err := cim.IsISCSITargetConnected(target) + if err != nil { + return err + } + + if connected { + klog.V(2).Infof("target %s from target portal at (%s:%d) is connected.", iqn, portal.Address, portal.Port) + return nil + } + + targetAuthType := strings.ToUpper(strings.ReplaceAll(authType, "_", "")) + + result, err := cim.ConnectISCSITarget(portal.Address, portal.Port, iqn, targetAuthType, &chapUser, &chapSecret) + if err != nil { + return fmt.Errorf("error connecting to target portal. result: %d, err: %w", result, err) + } + + return nil + }) } func (APIImplementor) DisconnectTarget(portal *TargetPortal, iqn string) error { - // Using InputObject instead of pipe to verify input is not empty - cmdLine := fmt.Sprintf( - `Disconnect-IscsiTarget -InputObject (Get-IscsiTargetPortal ` + - `-TargetPortalAddress ${Env:iscsi_tp_address} -TargetPortalPortNumber ${Env:iscsi_tp_port} ` + - ` | Get-IscsiTarget | Where-Object { $_.NodeAddress -eq ${Env:iscsi_target_iqn} }) ` + - `-Confirm:$false`) - - out, err := utils.RunPowershellCmd(cmdLine, fmt.Sprintf("iscsi_tp_address=%s", portal.Address), - fmt.Sprintf("iscsi_tp_port=%d", portal.Port), - fmt.Sprintf("iscsi_target_iqn=%s", iqn)) - if err != nil { - return fmt.Errorf("error disconnecting from target portal. cmd %s, output: %s, err: %w", cmdLine, string(out), err) - } - - return nil + return cim.WithCOMThread(func() error { + target, err := cim.QueryISCSITarget(portal.Address, portal.Port, iqn) + if err != nil { + return err + } + + connected, err := cim.IsISCSITargetConnected(target) + if err != nil { + return fmt.Errorf("error query connected of target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err) + } + + if !connected { + klog.V(2).Infof("target %s from target portal at (%s:%d) is not connected.", iqn, portal.Address, portal.Port) + return nil + } + + // get session + session, err := cim.QueryISCSISessionByTarget(target) + if err != nil { + return fmt.Errorf("error query session of target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err) + } + + sessionIdentifier, err := cim.GetISCSISessionIdentifier(session) + if err != nil { + return fmt.Errorf("error query session identifier of target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err) + } + + persistent, err := cim.IsISCSISessionPersistent(session) + if err != nil { + return fmt.Errorf("error query session persistency of target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err) + } + + if persistent { + result, err := cim.UnregisterISCSISession(session) + if err != nil { + return fmt.Errorf("error unregister session on target %s from target portal at (%s:%d). result: %d, err: %w", iqn, portal.Address, portal.Port, result, err) + } + } + + result, err := cim.DisconnectISCSITarget(target, sessionIdentifier) + if err != nil { + return fmt.Errorf("error disconnecting target %s from target portal at (%s:%d). result: %d, err: %w", iqn, portal.Address, portal.Port, result, err) + } + + return nil + }) } func (APIImplementor) GetTargetDisks(portal *TargetPortal, iqn string) ([]string, error) { - // Converting DiskNumber to string for compatibility with disk api group - // Not using pipeline in order to validate that items are non-empty - cmdLine := fmt.Sprintf( - `$ErrorActionPreference = "Stop"; ` + - `$tp = Get-IscsiTargetPortal -TargetPortalAddress ${Env:iscsi_tp_address} -TargetPortalPortNumber ${Env:iscsi_tp_port}; ` + - `$t = $tp | Get-IscsiTarget | Where-Object { $_.NodeAddress -eq ${Env:iscsi_target_iqn} }; ` + - `$c = Get-IscsiConnection -IscsiTarget $t; ` + - `$ids = $c | Get-Disk | Select -ExpandProperty Number | Out-String -Stream; ` + - `ConvertTo-Json -InputObject @($ids)`) - - out, err := utils.RunPowershellCmd(cmdLine, fmt.Sprintf("iscsi_tp_address=%s", portal.Address), - fmt.Sprintf("iscsi_tp_port=%d", portal.Port), - fmt.Sprintf("iscsi_target_iqn=%s", iqn)) - if err != nil { - return nil, fmt.Errorf("error getting target disks. cmd %s, output: %s, err: %w", cmdLine, string(out), err) - } - var ids []string - err = json.Unmarshal(out, &ids) - if err != nil { - return nil, fmt.Errorf("error parsing iqn target disks. cmd: %s output: %s, err: %w", cmdLine, string(out), err) - } - - return ids, nil + err := cim.WithCOMThread(func() error { + target, err := cim.QueryISCSITarget(portal.Address, portal.Port, iqn) + if err != nil { + return err + } + + connected, err := cim.IsISCSITargetConnected(target) + if err != nil { + return fmt.Errorf("error query connected of target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err) + } + + if !connected { + klog.V(2).Infof("target %s from target portal at (%s:%d) is not connected.", iqn, portal.Address, portal.Port) + return nil + } + + disks, err := cim.ListDisksByTarget(target) + if err != nil { + return fmt.Errorf("error getting target disks on target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err) + } + + for _, disk := range disks { + number, err := cim.GetDiskNumber(disk) + if err != nil { + return fmt.Errorf("error getting number of disk %v on target %s from target portal at (%s:%d). err: %w", disk, iqn, portal.Address, portal.Port, err) + } + + ids = append(ids, strconv.Itoa(int(number))) + } + return nil + }) + return ids, err } func (APIImplementor) SetMutualChapSecret(mutualChapSecret string) error { - cmdLine := `Set-IscsiChapSecret -ChapSecret ${Env:iscsi_mutual_chap_secret}` - out, err := utils.RunPowershellCmd(cmdLine, fmt.Sprintf("iscsi_mutual_chap_secret=%s", mutualChapSecret)) - if err != nil { - return fmt.Errorf("error setting mutual chap secret. cmd %s,"+ - " output: %s, err: %v", cmdLine, string(out), err) - } - - return nil + return cim.WithCOMThread(func() error { + result, err := cim.SetISCSISessionChapSecret(mutualChapSecret) + if err != nil { + return fmt.Errorf("error setting mutual chap secret. result: %d, err: %v", result, err) + } + + return nil + }) } diff --git a/pkg/os/smb/api.go b/pkg/os/smb/api.go index 20b9544e..9e60c9dd 100644 --- a/pkg/os/smb/api.go +++ b/pkg/os/smb/api.go @@ -3,15 +3,9 @@ package smb import ( "fmt" "strings" - "syscall" "github.com/kubernetes-csi/csi-proxy/pkg/cim" "github.com/kubernetes-csi/csi-proxy/pkg/utils" - "golang.org/x/sys/windows" -) - -const ( - credentialDelimiter = ":" ) type API interface { @@ -33,61 +27,28 @@ func New(requirePrivacy bool) *SmbAPI { } } -func remotePathForQuery(remotePath string) string { - return strings.ReplaceAll(remotePath, "\\", "\\\\") -} - -func escapeUserName(userName string) string { - // refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L169-L170 - escaped := strings.ReplaceAll(userName, "\\", "\\\\") - escaped = strings.ReplaceAll(escaped, credentialDelimiter, "\\"+credentialDelimiter) - return escaped -} - -func createSymlink(link, target string, isDir bool) error { - linkPtr, err := syscall.UTF16PtrFromString(link) - if err != nil { - return err - } - targetPtr, err := syscall.UTF16PtrFromString(target) - if err != nil { - return err - } - - var flags uint32 - if isDir { - flags = windows.SYMBOLIC_LINK_FLAG_DIRECTORY - } - - err = windows.CreateSymbolicLink( - linkPtr, - targetPtr, - flags, - ) - return err -} - func (*SmbAPI) IsSmbMapped(remotePath string) (bool, error) { - inst, err := cim.QuerySmbGlobalMappingByRemotePath(remotePathForQuery(remotePath)) - if err != nil { - return false, cim.IgnoreNotFound(err) - } - - status, err := inst.GetProperty("Status") - if err != nil { - return false, err - } - - return status.(int32) == cim.SmbMappingStatusOK, nil + var isMapped bool + err := cim.WithCOMThread(func() error { + inst, err := cim.QuerySmbGlobalMappingByRemotePath(remotePath) + if err != nil { + return err + } + + status, err := cim.GetSmbGlobalMappingStatus(inst) + if err != nil { + return err + } + + isMapped = status == cim.SmbMappingStatusOK + return nil + }) + return isMapped, cim.IgnoreNotFound(err) } // NewSmbLink - creates a directory symbolic link to the remote share. // The os.Symlink was having issue for cases where the destination was an SMB share - the container -// runtime would complain stating "Access Denied". Because of this, we had to perform -// this operation with powershell commandlet creating an directory softlink. -// Since os.Symlink is currently being used in working code paths, no attempt is made in -// alpha to merge the paths. -// TODO (for beta release): Merge the link paths - os.Symlink and Powershell link path. +// runtime would complain stating "Access Denied". func (*SmbAPI) NewSmbLink(remotePath, localPath string) error { if !strings.HasSuffix(remotePath, "\\") { // Golang has issues resolving paths mapped to file shares if they do not end in a trailing \ @@ -97,7 +58,7 @@ func (*SmbAPI) NewSmbLink(remotePath, localPath string) error { longRemotePath := utils.EnsureLongPath(remotePath) longLocalPath := utils.EnsureLongPath(localPath) - err := createSymlink(longLocalPath, longRemotePath, true) + err := utils.CreateSymlink(longLocalPath, longRemotePath, true) if err != nil { return fmt.Errorf("error linking %s to %s. err: %v", remotePath, localPath, err) } @@ -106,29 +67,21 @@ func (*SmbAPI) NewSmbLink(remotePath, localPath string) error { } func (api *SmbAPI) NewSmbGlobalMapping(remotePath, username, password string) error { - params := map[string]interface{}{ - "RemotePath": remotePath, - "RequirePrivacy": api.RequirePrivacy, - } - if username != "" { - // refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L166-L178 - // on how SMB credential is handled in PowerShell - params["Credential"] = escapeUserName(username) + credentialDelimiter + password - } - - result, _, err := cim.InvokeCimMethod(cim.WMINamespaceSmb, "MSFT_SmbGlobalMapping", "Create", params) - if err != nil { - return fmt.Errorf("NewSmbGlobalMapping failed. result: %d, err: %v", result, err) - } - - return nil + return cim.WithCOMThread(func() error { + result, err := cim.NewSmbGlobalMapping(remotePath, username, password, api.RequirePrivacy) + if err != nil { + return fmt.Errorf("NewSmbGlobalMapping failed. result: %d, err: %v", result, err) + } + return nil + }) } func (*SmbAPI) RemoveSmbGlobalMapping(remotePath string) error { - err := cim.RemoveSmbGlobalMappingByRemotePath(remotePathForQuery(remotePath)) - if err != nil { - return fmt.Errorf("error remove smb mapping '%s'. err: %v", remotePath, err) - } - - return nil + return cim.WithCOMThread(func() error { + err := cim.RemoveSmbGlobalMappingByRemotePath(remotePath) + if err != nil { + return fmt.Errorf("error remove smb mapping '%s'. err: %v", remotePath, err) + } + return nil + }) } diff --git a/pkg/os/system/api.go b/pkg/os/system/api.go index a09c83a4..128704ba 100644 --- a/pkg/os/system/api.go +++ b/pkg/os/system/api.go @@ -2,10 +2,12 @@ package system import ( "fmt" + "time" "github.com/kubernetes-csi/csi-proxy/pkg/cim" "github.com/kubernetes-csi/csi-proxy/pkg/server/system/impl" - "github.com/kubernetes-csi/csi-proxy/pkg/utils" + "github.com/pkg/errors" + "k8s.io/klog/v2" ) // Implements the System OS API calls. All code here should be very simple @@ -24,6 +26,29 @@ type ServiceInfo struct { Status uint32 `json:"Status"` } +type stateCheckFunc func(cim.ServiceInterface, string) (bool, string, error) +type stateTransitionFunc func(cim.ServiceInterface) error + +const ( + // startServiceErrorCodeAccepted indicates the request is accepted + startServiceErrorCodeAccepted = 0 + + // startServiceErrorCodeAlreadyRunning indicates a service is already running + startServiceErrorCodeAlreadyRunning = 10 + + // stopServiceErrorCodeAccepted indicates the request is accepted + stopServiceErrorCodeAccepted = 0 + + // stopServiceErrorCodeStopPending indicates the request cannot be sent to the service because the state of the service is 0,1,2 (pending) + stopServiceErrorCodeStopPending = 5 + + // stopServiceErrorCodeDependentRunning indicates a service cannot be stopped as its dependents may still be running + stopServiceErrorCodeDependentRunning = 3 + + serviceStateRunning = "Running" + serviceStateStopped = "Stopped" +) + var ( startModeMappings = map[string]uint32{ "Boot": impl.START_TYPE_BOOT, @@ -33,16 +58,20 @@ var ( "Disabled": impl.START_TYPE_DISABLED, } - statusMappings = map[string]uint32{ - "Unknown": impl.SERVICE_STATUS_UNKNOWN, - "Stopped": impl.SERVICE_STATUS_STOPPED, - "Start Pending": impl.SERVICE_STATUS_START_PENDING, - "Stop Pending": impl.SERVICE_STATUS_STOP_PENDING, - "Running": impl.SERVICE_STATUS_RUNNING, - "Continue Pending": impl.SERVICE_STATUS_CONTINUE_PENDING, - "Pause Pending": impl.SERVICE_STATUS_PAUSE_PENDING, - "Paused": impl.SERVICE_STATUS_PAUSED, + stateMappings = map[string]uint32{ + "Unknown": impl.SERVICE_STATUS_UNKNOWN, + serviceStateStopped: impl.SERVICE_STATUS_STOPPED, + "Start Pending": impl.SERVICE_STATUS_START_PENDING, + "Stop Pending": impl.SERVICE_STATUS_STOP_PENDING, + serviceStateRunning: impl.SERVICE_STATUS_RUNNING, + "Continue Pending": impl.SERVICE_STATUS_CONTINUE_PENDING, + "Pause Pending": impl.SERVICE_STATUS_PAUSE_PENDING, + "Paused": impl.SERVICE_STATUS_PAUSED, } + + serviceStateCheckInternal = 200 * time.Millisecond + serviceStateCheckTimeout = 30 * time.Second + errTimedOut = errors.New("Timed out") ) func serviceStartModeToStartType(startMode string) uint32 { @@ -50,75 +79,277 @@ func serviceStartModeToStartType(startMode string) uint32 { } func serviceState(status string) uint32 { - return statusMappings[status] + return stateMappings[status] +} + +type ServiceManager interface { + WaitUntilServiceState(cim.ServiceInterface, stateTransitionFunc, stateCheckFunc, time.Duration, time.Duration) (string, error) + GetDependentsForService(string) ([]string, error) } -type APIImplementor struct{} +type ServiceFactory interface { + GetService(name string) (cim.ServiceInterface, error) +} + +type APIImplementor struct { + serviceFactory ServiceFactory + serviceManager ServiceManager +} func New() APIImplementor { - return APIImplementor{} + serviceFactory := cim.Win32ServiceFactory{} + return APIImplementor{ + serviceFactory: serviceFactory, + serviceManager: ServiceManagerImpl{ + serviceFactory: serviceFactory, + }, + } } func (APIImplementor) GetBIOSSerialNumber() (string, error) { - bios, err := cim.QueryBIOSElement([]string{"SerialNumber"}) - if err != nil { - return "", fmt.Errorf("failed to get BIOS element: %w", err) + var sn string + err := cim.WithCOMThread(func() error { + bios, err := cim.QueryBIOSElement(cim.BIOSSelectorList) + if err != nil { + return fmt.Errorf("failed to get BIOS element: %w", err) + } + + sn, err = cim.GetBIOSSerialNumber(bios) + if err != nil { + return fmt.Errorf("failed to get BIOS serial number property: %w", err) + } + + return nil + }) + return sn, err +} + +func (impl APIImplementor) GetService(name string) (*ServiceInfo, error) { + var serviceInfo *ServiceInfo + err := cim.WithCOMThread(func() error { + service, err := impl.serviceFactory.GetService(name) + if err != nil { + return fmt.Errorf("failed to get service %s: %w", name, err) + } + + displayName, err := cim.GetServiceDisplayName(service) + if err != nil { + return fmt.Errorf("failed to get displayName property of service %s: %w", name, err) + } + + state, err := cim.GetServiceState(service) + if err != nil { + return fmt.Errorf("failed to get state property of service %s: %w", name, err) + } + + startMode, err := cim.GetServiceStartMode(service) + if err != nil { + return fmt.Errorf("failed to get startMode property of service %s: %w", name, err) + } + + serviceInfo = &ServiceInfo{ + DisplayName: displayName, + StartType: serviceStartModeToStartType(startMode), + Status: serviceState(state), + } + return nil + }) + return serviceInfo, err +} + +func (impl APIImplementor) StartService(name string) error { + startService := func(service cim.ServiceInterface) error { + retVal, err := service.StartService() + if err != nil || (retVal != startServiceErrorCodeAccepted && retVal != startServiceErrorCodeAlreadyRunning) { + return fmt.Errorf("error starting service name %s. return value: %d, error: %v", name, retVal, err) + } + return nil } + serviceRunningCheck := func(service cim.ServiceInterface, state string) (bool, string, error) { + err := service.Refresh() + if err != nil { + return false, "", err + } - sn, err := bios.GetPropertySerialNumber() - if err != nil { - return "", fmt.Errorf("failed to get BIOS serial number property: %w", err) + newState, err := cim.GetServiceState(service) + if err != nil { + return false, state, err + } + + klog.V(6).Infof("service (%v) state check: %s => %s", service, state, newState) + return state == serviceStateRunning, newState, err } - return sn, nil + return cim.WithCOMThread(func() error { + service, err := impl.serviceFactory.GetService(name) + if err != nil { + return err + } + + state, err := impl.serviceManager.WaitUntilServiceState(service, startService, serviceRunningCheck, serviceStateCheckInternal, serviceStateCheckTimeout) + if err != nil && !errors.Is(err, errTimedOut) { + return err + } + + if state != serviceStateRunning { + return fmt.Errorf("timed out waiting for service %s to become running", name) + } + + return nil + }) } -func (APIImplementor) GetService(name string) (*ServiceInfo, error) { - service, err := cim.QueryServiceByName(name, []string{"DisplayName", "State", "StartMode"}) - if err != nil { - return nil, fmt.Errorf("failed to get service %s: %w", name, err) +func (impl APIImplementor) stopSingleService(name string) (bool, error) { + var dependentRunning bool + stopService := func(service cim.ServiceInterface) error { + retVal, err := service.StopService() + if err != nil || (retVal != stopServiceErrorCodeAccepted && retVal != stopServiceErrorCodeStopPending) { + if retVal == stopServiceErrorCodeDependentRunning { + dependentRunning = true + return fmt.Errorf("error stopping service %s as dependent services are not stopped", name) + } + return fmt.Errorf("error stopping service %s. return value: %d, error: %v", name, retVal, err) + } + return nil } + serviceStoppedCheck := func(service cim.ServiceInterface, state string) (bool, string, error) { + err := service.Refresh() + if err != nil { + return false, "", err + } - displayName, err := service.GetPropertyDisplayName() - if err != nil { - return nil, fmt.Errorf("failed to get displayName property of service %s: %w", name, err) + newState, err := cim.GetServiceState(service) + if err != nil { + return false, state, err + } + + klog.V(6).Infof("service (%v) state check: %s => %s", service, state, newState) + return newState == serviceStateStopped, newState, err } - state, err := service.GetPropertyState() + service, err := impl.serviceFactory.GetService(name) if err != nil { - return nil, fmt.Errorf("failed to get state property of service %s: %w", name, err) + return dependentRunning, err } - startMode, err := service.GetPropertyStartMode() - if err != nil { - return nil, fmt.Errorf("failed to get startMode property of service %s: %w", name, err) + state, err := impl.serviceManager.WaitUntilServiceState(service, stopService, serviceStoppedCheck, serviceStateCheckInternal, serviceStateCheckTimeout) + if err != nil && !errors.Is(err, errTimedOut) { + return dependentRunning, fmt.Errorf("error stopping service name %s. current state: %s", name, state) + } + + if state != serviceStateStopped { + return dependentRunning, fmt.Errorf("timed out waiting for service %s to stop", name) } - return &ServiceInfo{ - DisplayName: displayName, - StartType: serviceStartModeToStartType(startMode), - Status: serviceState(state), - }, nil + return dependentRunning, nil } -func (APIImplementor) StartService(name string) error { - // Note: both StartService and StopService are not implemented by WMI - script := `Start-Service -Name $env:ServiceName` - cmdEnv := fmt.Sprintf("ServiceName=%s", name) - out, err := utils.RunPowershellCmd(script, cmdEnv) +func (impl APIImplementor) StopService(name string, force bool) error { + return cim.WithCOMThread(func() error { + dependentRunning, err := impl.stopSingleService(name) + if err == nil || !dependentRunning || !force { + return err + } + + serviceNames, err := impl.serviceManager.GetDependentsForService(name) + if err != nil { + return fmt.Errorf("error getting dependent services for service name %s", name) + } + + for _, serviceName := range serviceNames { + _, err = impl.stopSingleService(serviceName) + if err != nil { + return err + } + } + + return nil + }) +} + +type ServiceManagerImpl struct { + serviceFactory ServiceFactory +} + +func (impl ServiceManagerImpl) WaitUntilServiceState(service cim.ServiceInterface, stateTransition stateTransitionFunc, stateCheck stateCheckFunc, interval time.Duration, timeout time.Duration) (string, error) { + done, state, err := stateCheck(service, "") if err != nil { - return fmt.Errorf("error starting service name=%s. cmd: %s, output: %s, error: %v", name, script, string(out), err) + return state, err + } + if done { + return state, err } - return nil + // Perform transition if not already in desired state + if err := stateTransition(service); err != nil { + return state, err + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + timeoutChan := time.After(timeout) + + for { + select { + case <-ticker.C: + klog.V(6).Infof("Checking service (%v) state...", service) + done, state, err = stateCheck(service, state) + if err != nil { + return state, fmt.Errorf("check failed: %w", err) + } + if done { + klog.V(6).Infof("service (%v) state is %s and transition done.", service, state) + return state, nil + } + case <-timeoutChan: + done, state, err = stateCheck(service, state) + return state, errTimedOut + } + } } -func (APIImplementor) StopService(name string, force bool) error { - script := `Stop-Service -Name $env:ServiceName -Force:$([System.Convert]::ToBoolean($env:Force))` - out, err := utils.RunPowershellCmd(script, fmt.Sprintf("ServiceName=%s", name), fmt.Sprintf("Force=%t", force)) +func (impl ServiceManagerImpl) GetDependentsForService(name string) ([]string, error) { + var serviceNames []string + var servicesToCheck []cim.ServiceInterface + servicesByName := map[string]string{} + + service, err := impl.serviceFactory.GetService(name) if err != nil { - return fmt.Errorf("error stopping service name=%s. cmd: %s, output: %s, error: %v", name, script, string(out), err) + return serviceNames, err + } + + servicesToCheck = append(servicesToCheck, service) + i := 0 + for i < len(servicesToCheck) { + service = servicesToCheck[i] + i += 1 + + serviceName, err := cim.GetServiceName(service) + if err != nil { + return serviceNames, err + } + + currentState, err := cim.GetServiceState(service) + if err != nil { + return serviceNames, err + } + + if currentState != serviceStateRunning { + continue + } + + servicesByName[serviceName] = serviceName + // prepend the current service to the front + serviceNames = append([]string{serviceName}, serviceNames...) + + dependents, err := service.GetDependents() + if err != nil { + return serviceNames, err + } + + servicesToCheck = append(servicesToCheck, dependents...) } - return nil + return serviceNames, nil } diff --git a/pkg/os/system/api_test.go b/pkg/os/system/api_test.go new file mode 100644 index 00000000..b977c74c --- /dev/null +++ b/pkg/os/system/api_test.go @@ -0,0 +1,214 @@ +package system + +import ( + "fmt" + "testing" + "time" + + "github.com/kubernetes-csi/csi-proxy/pkg/cim" + "github.com/pkg/errors" +) + +type MockService struct { + Name string + DisplayName string + State string + StartMode string + Dependents []cim.ServiceInterface + + StartResult uint32 + StopResult uint32 + + Err error +} + +func (m *MockService) GetPropertyName() (string, error) { + return m.Name, m.Err +} + +func (m *MockService) GetPropertyDisplayName() (string, error) { + return m.DisplayName, m.Err +} + +func (m *MockService) GetPropertyState() (string, error) { + return m.State, m.Err +} + +func (m *MockService) GetPropertyStartMode() (string, error) { + return m.StartMode, m.Err +} + +func (m *MockService) GetDependents() ([]cim.ServiceInterface, error) { + return m.Dependents, m.Err +} + +func (m *MockService) StartService() (uint32, error) { + m.State = "Running" + return m.StartResult, m.Err +} + +func (m *MockService) StopService() (uint32, error) { + m.State = "Stopped" + return m.StopResult, m.Err +} + +func (m *MockService) Refresh() error { + return nil +} + +type MockServiceFactory struct { + Services map[string]cim.ServiceInterface + Err error +} + +func (f *MockServiceFactory) GetService(name string) (cim.ServiceInterface, error) { + svc, ok := f.Services[name] + if !ok { + return nil, fmt.Errorf("service not found: %s", name) + } + return svc, f.Err +} + +func TestWaitUntilServiceState_Success(t *testing.T) { + svc := &MockService{Name: "svc", State: "Stopped"} + + stateChanged := false + + stateCheck := func(s cim.ServiceInterface, state string) (bool, string, error) { + if stateChanged { + svc.State = serviceStateRunning + return true, svc.State, nil + } + return false, svc.State, nil + } + + stateTransition := func(s cim.ServiceInterface) error { + stateChanged = true + return nil + } + + impl := ServiceManagerImpl{} + state, err := impl.WaitUntilServiceState(svc, stateTransition, stateCheck, 10*time.Millisecond, 500*time.Millisecond) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if state != serviceStateRunning { + t.Fatalf("expected state %q, got %q", serviceStateRunning, state) + } +} + +func TestWaitUntilServiceState_Timeout(t *testing.T) { + svc := &MockService{Name: "svc", State: "Stopped"} + + stateCheck := func(s cim.ServiceInterface, state string) (bool, string, error) { + return false, svc.State, nil + } + + stateTransition := func(s cim.ServiceInterface) error { + return nil + } + + impl := ServiceManagerImpl{} + state, err := impl.WaitUntilServiceState(svc, stateTransition, stateCheck, 10*time.Millisecond, 50*time.Millisecond) + if !errors.Is(err, errTimedOut) { + t.Fatalf("expected timeout error, got %v", err) + } + if state != svc.State { + t.Fatalf("expected state %q, got %q", svc.State, state) + } +} + +func TestWaitUntilServiceState_TransitionFails(t *testing.T) { + svc := &MockService{Name: "svc", State: "Stopped"} + + stateCheck := func(s cim.ServiceInterface, state string) (bool, string, error) { + return false, svc.State, nil + } + + stateTransition := func(s cim.ServiceInterface) error { + return fmt.Errorf("transition failed") + } + + impl := ServiceManagerImpl{} + _, err := impl.WaitUntilServiceState(svc, stateTransition, stateCheck, 10*time.Millisecond, 50*time.Millisecond) + if err == nil || err.Error() != "transition failed" { + t.Fatalf("expected transition error, got %v", err) + } +} + +func TestGetDependentsForService(t *testing.T) { + // Construct the dependency tree + svcC := &MockService{Name: "C", State: serviceStateRunning} + svcB := &MockService{Name: "B", State: serviceStateRunning, Dependents: []cim.ServiceInterface{svcC}} + svcA := &MockService{Name: "A", State: serviceStateRunning, Dependents: []cim.ServiceInterface{svcB}} + + factory := &MockServiceFactory{ + Services: map[string]cim.ServiceInterface{ + "A": svcA, + "B": svcB, + "C": svcC, + }, + } + + impl := ServiceManagerImpl{ + serviceFactory: factory, + } + + names, err := impl.GetDependentsForService("A") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := []string{"C", "B", "A"} + if len(names) != len(expected) { + t.Fatalf("expected %d services, got %d", len(expected), len(names)) + } + for i, name := range expected { + if names[i] != name { + t.Errorf("expected %s at position %d, got %s", name, i, names[i]) + } + } +} + +func TestGetDependentsForService_SkipsNonRunning(t *testing.T) { + svcB := &MockService{Name: "B", State: "Stopped"} + svcA := &MockService{Name: "A", State: serviceStateRunning, Dependents: []cim.ServiceInterface{svcB}} + + factory := &MockServiceFactory{ + Services: map[string]cim.ServiceInterface{ + "A": svcA, + "B": svcB, + }, + } + + impl := ServiceManagerImpl{ + serviceFactory: factory, + } + + names, err := impl.GetDependentsForService("A") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := []string{"A"} // B is skipped due to stopped state + if len(names) != len(expected) { + t.Fatalf("expected %d services, got %d", len(expected), len(names)) + } +} + +func TestGetDependenciesForService_Winmgmt(t *testing.T) { + impl := ServiceManagerImpl{ + serviceFactory: cim.Win32ServiceFactory{}, + } + + serviceName := "Winmgmt" + names, err := impl.GetDependentsForService(serviceName) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := 4 + if len(names) != expected || names[len(names)-1] != serviceName { + t.Fatalf("expected %d services, got %d", expected, len(names)) + } +} diff --git a/pkg/os/volume/api.go b/pkg/os/volume/api.go index 5bdf0e04..87a10f0f 100644 --- a/pkg/os/volume/api.go +++ b/pkg/os/volume/api.go @@ -2,21 +2,21 @@ package volume import ( "fmt" - "os" "path/filepath" "regexp" - "strconv" "strings" - "github.com/go-ole/go-ole" "github.com/kubernetes-csi/csi-proxy/pkg/cim" "github.com/kubernetes-csi/csi-proxy/pkg/utils" - wmierrors "github.com/microsoft/wmi/pkg/errors" "github.com/pkg/errors" "golang.org/x/sys/windows" "k8s.io/klog/v2" ) +const ( + minimumResizeSize = 100 * 1024 * 1024 +) + // API exposes the internal volume operations available in the server type API interface { // ListVolumesOnDisk lists volumes on a disk identified by a `diskNumber` and optionally a partition identified by `partitionNumber`. @@ -69,57 +69,55 @@ func New() VolumeAPI { // ListVolumesOnDisk - returns back list of volumes(volumeIDs) in a disk and a partition. func (VolumeAPI) ListVolumesOnDisk(diskNumber uint32, partitionNumber uint32) (volumeIDs []string, err error) { - partitions, err := cim.ListPartitionsOnDisk(diskNumber, partitionNumber, []string{"ObjectId"}) - if err != nil { - return nil, errors.Wrapf(err, "failed to list partition on disk %d", diskNumber) - } - - volumes, err := cim.ListVolumes([]string{"ObjectId", "UniqueId"}) - if err != nil { - return nil, errors.Wrapf(err, "failed to list volumes") - } + err = cim.WithCOMThread(func() error { + partitions, err := cim.ListPartitionsOnDisk(diskNumber, partitionNumber, cim.PartitionSelectorListObjectID) + if err != nil { + return errors.Wrapf(err, "failed to list partition on disk %d", diskNumber) + } - filtered, err := cim.FindVolumesByPartition(volumes, partitions) - if cim.IgnoreNotFound(err) != nil { - return nil, errors.Wrapf(err, "failed to list volumes on disk %d", diskNumber) - } + volumes, err := cim.FindVolumesByPartition(partitions) + if cim.IgnoreNotFound(err) != nil { + return errors.Wrapf(err, "failed to list volumes on disk %d", diskNumber) + } - for _, volume := range filtered { - uniqueID, err := volume.GetPropertyUniqueId() - if err != nil { - return nil, errors.Wrapf(err, "failed to list volumes") + for _, volume := range volumes { + uniqueID, err := cim.GetVolumeUniqueID(volume) + if err != nil { + return errors.Wrapf(err, "failed to get unique ID for volume %v", volume) + } + volumeIDs = append(volumeIDs, uniqueID) } - volumeIDs = append(volumeIDs, uniqueID) - } - return volumeIDs, nil + return nil + }) + return } // FormatVolume - Formats a volume with the NTFS format. func (VolumeAPI) FormatVolume(volumeID string) (err error) { - volume, err := cim.QueryVolumeByUniqueID(volumeID, nil) - if err != nil { - return fmt.Errorf("error formatting volume (%s). error: %v", volumeID, err) - } + return cim.WithCOMThread(func() error { + volume, err := cim.QueryVolumeByUniqueID(volumeID, nil) + if err != nil { + return fmt.Errorf("error formatting volume (%s). error: %v", volumeID, err) + } - result, err := volume.InvokeMethodWithReturn( - "Format", - "NTFS", // Format, - "", // FileSystemLabel, - nil, // AllocationUnitSize, - false, // Full, - true, // Force - nil, // Compress, - nil, // ShortFileNameSupport, - nil, // SetIntegrityStreams, - nil, // UseLargeFRS, - nil, // DisableHeatGathering, - ) - if result != 0 || err != nil { - return fmt.Errorf("error formatting volume (%s). result: %d, error: %v", volumeID, result, err) - } - // TODO: Do we need to handle anything for len(out) == 0 - return nil + result, err := cim.FormatVolume(volume, + "NTFS", // Format, + "", // FileSystemLabel, + nil, // AllocationUnitSize, + false, // Full, + true, // Force + nil, // Compress, + nil, // ShortFileNameSupport, + nil, // SetIntegrityStreams, + nil, // UseLargeFRS, + nil, // DisableHeatGathering, + ) + if result != 0 || err != nil { + return fmt.Errorf("error formatting volume (%s). result: %d, error: %v", volumeID, result, err) + } + return nil + }) } // WriteVolumeCache - Writes the file system cache to disk with the given volume id @@ -129,18 +127,22 @@ func (VolumeAPI) WriteVolumeCache(volumeID string) (err error) { // IsVolumeFormatted - Check if the volume is formatted with the pre specified filesystem(typically ntfs). func (VolumeAPI) IsVolumeFormatted(volumeID string) (bool, error) { - volume, err := cim.QueryVolumeByUniqueID(volumeID, []string{"FileSystemType"}) - if err != nil { - return false, fmt.Errorf("error checking if volume (%s) is formatted. error: %v", volumeID, err) - } + var formatted bool + err := cim.WithCOMThread(func() error { + volume, err := cim.QueryVolumeByUniqueID(volumeID, cim.VolumeSelectorListForFileSystemType) + if err != nil { + return fmt.Errorf("error checking if volume (%s) is formatted. error: %v", volumeID, err) + } - fsType, err := volume.GetProperty("FileSystemType") - if err != nil { - return false, fmt.Errorf("failed to query volume file system type (%s): %w", volumeID, err) - } + fsType, err := cim.GetVolumeFileSystemType(volume) + if err != nil { + return fmt.Errorf("failed to query volume file system type (%s): %w", volumeID, err) + } - const FileSystemUnknown = 0 - return fsType.(int32) != FileSystemUnknown, nil + formatted = fsType != cim.FileSystemUnknown + return nil + }) + return formatted, err } // MountVolume - mounts a volume to a path. This is done using Win32 API SetVolumeMountPoint for presenting the volume via a path. @@ -190,124 +192,112 @@ func (VolumeAPI) UnmountVolume(volumeID, path string) error { // ResizeVolume - resizes a volume with the given size, if size == 0 then max supported size is used func (VolumeAPI) ResizeVolume(volumeID string, size int64) error { - var err error - var finalSize int64 - part, err := cim.GetPartitionByVolumeUniqueID(volumeID, nil) - if err != nil { - return err - } - - // If size is 0 then we will resize to the maximum size possible, otherwise just resize to size - if size == 0 { - var sizeMin, sizeMax ole.VARIANT - var status string - result, err := part.InvokeMethodWithReturn("GetSupportedSize", &sizeMin, &sizeMax, &status) - if result != 0 || err != nil { - return fmt.Errorf("error getting sizeMin, sizeMax from volume(%s). result: %d, status: %s, error: %v", volumeID, result, status, err) - } - klog.V(5).Infof("got sizeMin(%v) sizeMax(%v) from volume(%s), status: %s", sizeMin, sizeMax, volumeID, status) - - finalSizeStr := sizeMax.ToString() - finalSize, err = strconv.ParseInt(finalSizeStr, 10, 64) + return cim.WithCOMThread(func() error { + var err error + var finalSize int64 + part, err := cim.GetPartitionByVolumeUniqueID(volumeID) if err != nil { - return fmt.Errorf("error parsing the sizeMax of volume (%s) with error (%v)", volumeID, err) + return err } - } else { - finalSize = size - } - currentSizeVal, err := part.GetProperty("Size") - if err != nil { - return fmt.Errorf("error getting the current size of volume (%s) with error (%v)", volumeID, err) - } + // If size is 0 then we will resize to the maximum size possible, otherwise just resize to size + if size == 0 { + var result int + var status string + result, _, finalSize, status, err = cim.GetPartitionSupportedSize(part) + if result != 0 || err != nil { + return fmt.Errorf("error getting sizeMin, sizeMax from volume (%s). result: %d, status: %s, error: %v", volumeID, result, status, err) + } - currentSize, err := strconv.ParseInt(currentSizeVal.(string), 10, 64) - if err != nil { - return fmt.Errorf("error parsing the current size of volume (%s) with error (%v)", volumeID, err) - } + } else { + finalSize = size + } - // only resize if finalSize - currentSize is greater than 100MB - if finalSize-currentSize < 100*1024*1024 { - klog.V(2).Infof("minimum resize difference(1GB) not met, skipping resize. volumeID=%s currentSize=%d finalSize=%d", volumeID, currentSize, finalSize) - return nil - } + currentSize, err := cim.GetPartitionSize(part) + if err != nil { + return fmt.Errorf("error getting the current size of volume (%s) with error (%v)", volumeID, err) + } - //if the partition's size is already the size we want this is a noop, just return - if currentSize >= finalSize { - klog.V(2).Infof("Attempted to resize volume (%s) to a lower size, from currentBytes=%d wantedBytes=%d", volumeID, currentSize, finalSize) - return nil - } + // only resize if finalSize - currentSize is greater than 100MB + if finalSize-currentSize < minimumResizeSize { + klog.V(2).Infof("minimum resize difference (100MB) not met, skipping resize. volumeID=%s currentSize=%d finalSize=%d", volumeID, currentSize, finalSize) + return nil + } - var status string - result, err := part.InvokeMethodWithReturn("Resize", strconv.Itoa(int(finalSize)), &status) + //if the partition's size is already the size we want this is a noop, just return + if currentSize >= finalSize { + klog.V(2).Infof("Attempted to resize volume (%s) to a lower size, from currentBytes=%d wantedBytes=%d", volumeID, currentSize, finalSize) + return nil + } - if result != 0 || err != nil { - return fmt.Errorf("error resizing volume (%s). size:%v, finalSize %v, error: %v", volumeID, size, finalSize, err) - } + result, _, err := cim.ResizePartition(part, finalSize) + if result != 0 || err != nil { + return fmt.Errorf("error resizing volume (%s). size:%v, finalSize %v, error: %v", volumeID, size, finalSize, err) + } - diskNumber, err := cim.GetPartitionDiskNumber(part) - if err != nil { - return fmt.Errorf("error parsing disk number of volume (%s). error: %v", volumeID, err) - } + diskNumber, err := cim.GetPartitionDiskNumber(part) + if err != nil { + return fmt.Errorf("error parsing disk number of volume (%s). error: %v", volumeID, err) + } - disk, err := cim.QueryDiskByNumber(diskNumber, nil) - if err != nil { - return fmt.Errorf("error parsing disk number of volume (%s). error: %v", volumeID, err) - } + disk, err := cim.QueryDiskByNumber(diskNumber, nil) + if err != nil { + return fmt.Errorf("error query disk of volume (%s). error: %v", volumeID, err) + } - result, err = disk.InvokeMethodWithReturn("Refresh", &status) - if result != 0 || err != nil { - return fmt.Errorf("error rescan disk (%d). result %d, error: %v", diskNumber, result, err) - } + result, _, err = cim.RefreshDisk(disk) + if result != 0 || err != nil { + return fmt.Errorf("error rescan disk (%d). result %d, error: %v", diskNumber, result, err) + } - return nil + return nil + }) } // GetVolumeStats - retrieves the volume stats for a given volume -func (VolumeAPI) GetVolumeStats(volumeID string) (int64, int64, error) { - volume, err := cim.QueryVolumeByUniqueID(volumeID, []string{"UniqueId", "SizeRemaining", "Size"}) - if err != nil { - return -1, -1, fmt.Errorf("error getting capacity and used size of volume (%s). error: %v", volumeID, err) - } - - volumeSizeVal, err := volume.GetProperty("Size") - if err != nil { - return -1, -1, fmt.Errorf("failed to query volume size (%s): %w", volumeID, err) - } - - volumeSize, err := strconv.ParseInt(volumeSizeVal.(string), 10, 64) - if err != nil { - return -1, -1, fmt.Errorf("failed to parse volume size (%s): %w", volumeID, err) - } +func (VolumeAPI) GetVolumeStats(volumeID string) (volumeSize, volumeUsedSize int64, err error) { + volumeSize = -1 + volumeUsedSize = -1 + err = cim.WithCOMThread(func() error { + volume, err := cim.QueryVolumeByUniqueID(volumeID, cim.VolumeSelectorListForStats) + if err != nil { + return fmt.Errorf("error getting capacity and used size of volume (%s). error: %v", volumeID, err) + } - volumeSizeRemainingVal, err := volume.GetProperty("SizeRemaining") - if err != nil { - return -1, -1, fmt.Errorf("failed to query volume remaining size (%s): %w", volumeID, err) - } + volumeSize, err = cim.GetVolumeSize(volume) + if err != nil { + return fmt.Errorf("failed to query volume size (%s): %w", volumeID, err) + } - volumeSizeRemaining, err := strconv.ParseInt(volumeSizeRemainingVal.(string), 10, 64) - if err != nil { - return -1, -1, fmt.Errorf("failed to parse volume remaining size (%s): %w", volumeID, err) - } + volumeSizeRemaining, err := cim.GetVolumeSizeRemaining(volume) + if err != nil { + return fmt.Errorf("failed to query volume remaining size (%s): %w", volumeID, err) + } - volumeUsedSize := volumeSize - volumeSizeRemaining - return volumeSize, volumeUsedSize, nil + volumeUsedSize = volumeSize - volumeSizeRemaining + return nil + }) + return } // GetDiskNumberFromVolumeID - gets the disk number where the volume is. func (VolumeAPI) GetDiskNumberFromVolumeID(volumeID string) (uint32, error) { - // get the size and sizeRemaining for the volume - part, err := cim.GetPartitionByVolumeUniqueID(volumeID, []string{"DiskNumber"}) - if err != nil { - return 0, err - } + var diskNumber uint32 + err := cim.WithCOMThread(func() error { + // get the size and sizeRemaining for the volume + part, err := cim.GetPartitionByVolumeUniqueID(volumeID) + if err != nil { + return err + } - diskNumber, err := part.GetProperty("DiskNumber") - if err != nil { - return 0, fmt.Errorf("error query disk number of volume (%s). error: %v", volumeID, err) - } + diskNumber, err = cim.GetPartitionDiskNumber(part) + if err != nil { + return fmt.Errorf("error query disk number of volume (%s). error: %v", volumeID, err) + } - return uint32(diskNumber.(int32)), nil + return nil + }) + return diskNumber, err } // GetVolumeIDFromTargetPath - gets the volume ID given a mount point, the function is recursive until it find a volume or errors out @@ -321,7 +311,7 @@ func (VolumeAPI) GetVolumeIDFromTargetPath(mount string) (string, error) { } func getTarget(mount string) (string, error) { - mountedFolder, err := isMountedFolder(mount) + mountedFolder, err := utils.IsMountedFolder(mount) if err != nil { return "", err } @@ -361,7 +351,7 @@ func (VolumeAPI) GetClosestVolumeIDFromTargetPath(targetPath string) (string, er } // findClosestVolume finds the closest volume id for a given target path -// by following symlinks and moving up in the filesystem, if after moving up in the filesystem +// by following symlinks and moving up in the filesystem. if after moving up in the filesystem // we get to a DriveLetter then the volume corresponding to this drive letter is returned instead. func findClosestVolume(path string) (string, error) { candidatePath := path @@ -370,22 +360,20 @@ func findClosestVolume(path string) (string, error) { // while trying to follow symlinks // // The maximum path length in Windows is 260, it could be possible to end - // up in a sceneario where we do more than 256 iterations (e.g. by following symlinks from + // up in a scenario where we do more than 256 iterations (e.g. by following symlinks from // a place high in the hierarchy to a nested sibling location many times) // https://docs.microsoft.com/en-us/windows/win32/fileio/naming-a-file#:~:text=In%20editions%20of%20Windows%20before,required%20to%20remove%20the%20limit. // // The number of iterations is 256, which is similar to the number of iterations in filepath-securejoin // https://github.com/cyphar/filepath-securejoin/blob/64536a8a66ae59588c981e2199f1dcf410508e07/join.go#L51 for i := 0; i < 256; i += 1 { - fi, err := os.Lstat(candidatePath) + isSymlink, err := utils.IsPathSymlink(candidatePath) if err != nil { return "", err } - // for windows NTFS, check if the path is symlink instead of directory. - isSymlink := fi.Mode()&os.ModeSymlink != 0 || fi.Mode()&os.ModeIrregular != 0 // mounted folder created by SetVolumeMountPoint may still report ModeSymlink == 0 - mountedFolder, err := isMountedFolder(candidatePath) + mountedFolder, err := utils.IsMountedFolder(candidatePath) if err != nil { return "", err } @@ -422,67 +410,40 @@ func findClosestVolume(path string) (string, error) { return "", fmt.Errorf("failed to find the closest volume for path=%s", path) } -// isMountedFolder checks whether the `path` is a mounted folder. -func isMountedFolder(path string) (bool, error) { - // https://learn.microsoft.com/en-us/windows/win32/fileio/determining-whether-a-directory-is-a-volume-mount-point - utf16Path, _ := windows.UTF16PtrFromString(path) - attrs, err := windows.GetFileAttributes(utf16Path) - if err != nil { - return false, err - } - - if (attrs & windows.FILE_ATTRIBUTE_REPARSE_POINT) == 0 { - return false, nil - } - - var findData windows.Win32finddata - findHandle, err := windows.FindFirstFile(utf16Path, &findData) - if err != nil && !errors.Is(err, windows.ERROR_NO_MORE_FILES) { - return false, err - } - - for err == nil { - if findData.Reserved0&windows.IO_REPARSE_TAG_MOUNT_POINT != 0 { - return true, nil - } - - err = windows.FindNextFile(findHandle, &findData) - if err != nil && !errors.Is(err, windows.ERROR_NO_MORE_FILES) { - return false, err - } - } - - return false, nil -} - // getVolumeForDriveLetter gets a volume from a drive letter (e.g. C:/). func getVolumeForDriveLetter(path string) (string, error) { if len(path) != 1 { return "", fmt.Errorf("the path %s is not a valid drive letter", path) } - volume, err := cim.GetVolumeByDriveLetter(path, []string{"UniqueId"}) - if err != nil { - return "", nil - } + var uniqueID string + err := cim.WithCOMThread(func() error { + volume, err := cim.GetVolumeByDriveLetter(path, cim.VolumeSelectorListUniqueID) + if err != nil { + return err + } - uniqueID, err := volume.GetPropertyUniqueId() - if err != nil { - return "", fmt.Errorf("error query unique ID of volume (%v). error: %v", volume, err) - } + uniqueID, err = cim.GetVolumeUniqueID(volume) + if err != nil { + return fmt.Errorf("error query unique ID of volume (%v). error: %v", volume, err) + } - return uniqueID, nil + return nil + }) + return uniqueID, err } func writeCache(volumeID string) error { - volume, err := cim.QueryVolumeByUniqueID(volumeID, []string{}) - if err != nil && !wmierrors.IsNotFound(err) { - return fmt.Errorf("error writing volume (%s) cache. error: %v", volumeID, err) - } + return cim.WithCOMThread(func() error { + volume, err := cim.QueryVolumeByUniqueID(volumeID, nil) + if err != nil { + return fmt.Errorf("error writing volume (%s) cache. error: %v", volumeID, err) + } - result, err := volume.Flush() - if result != 0 || err != nil { - return fmt.Errorf("error writing volume (%s) cache. result: %d, error: %v", volumeID, result, err) - } - return nil + result, err := cim.FlushVolume(volume) + if result != 0 || err != nil { + return fmt.Errorf("error writing volume (%s) cache. result: %d, error: %v", volumeID, result, err) + } + return nil + }) } diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 102675ac..f7a69492 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -1,10 +1,12 @@ package utils import ( + "errors" + "fmt" "os" - "os/exec" "strings" + "golang.org/x/sys/windows" "k8s.io/klog/v2" ) @@ -22,10 +24,88 @@ func EnsureLongPath(path string) string { return path } -func RunPowershellCmd(command string, envs ...string) ([]byte, error) { - cmd := exec.Command("powershell", "-Mta", "-NoProfile", "-Command", command) - cmd.Env = append(os.Environ(), envs...) - klog.V(8).Infof("Executing command: %q", cmd.String()) - out, err := cmd.CombinedOutput() - return out, err +func IsPathValid(path string) (bool, error) { + pathString, err := windows.UTF16PtrFromString(path) + if err != nil { + return false, fmt.Errorf("invalid path: %w", err) + } + + attrs, err := windows.GetFileAttributes(pathString) + if err != nil { + if errors.Is(err, windows.ERROR_PATH_NOT_FOUND) || errors.Is(err, windows.ERROR_FILE_NOT_FOUND) || errors.Is(err, windows.ERROR_INVALID_NAME) { + return false, nil + } + + // GetFileAttribute returns user or password incorrect for a disconnected SMB connection after the password is changed + return false, fmt.Errorf("failed to get path %s attribute: %w", path, err) + } + + klog.V(6).Infof("Path %s attribute: %d", path, attrs) + return attrs != windows.INVALID_FILE_ATTRIBUTES, nil +} + +// IsMountedFolder checks whether the `path` is a mounted folder. +func IsMountedFolder(path string) (bool, error) { + // https://learn.microsoft.com/en-us/windows/win32/fileio/determining-whether-a-directory-is-a-volume-mount-point + utf16Path, _ := windows.UTF16PtrFromString(path) + attrs, err := windows.GetFileAttributes(utf16Path) + if err != nil { + return false, err + } + + if (attrs & windows.FILE_ATTRIBUTE_REPARSE_POINT) == 0 { + return false, nil + } + + var findData windows.Win32finddata + findHandle, err := windows.FindFirstFile(utf16Path, &findData) + if err != nil && !errors.Is(err, windows.ERROR_NO_MORE_FILES) { + return false, err + } + + for err == nil { + if findData.Reserved0&windows.IO_REPARSE_TAG_MOUNT_POINT != 0 { + return true, nil + } + + err = windows.FindNextFile(findHandle, &findData) + if err != nil && !errors.Is(err, windows.ERROR_NO_MORE_FILES) { + return false, err + } + } + + return false, nil +} + +func IsPathSymlink(path string) (bool, error) { + fi, err := os.Lstat(path) + if err != nil { + return false, err + } + // for windows NTFS, check if the path is symlink instead of directory. + isSymlink := fi.Mode()&os.ModeSymlink != 0 || fi.Mode()&os.ModeIrregular != 0 + return isSymlink, nil +} + +func CreateSymlink(link, target string, isDir bool) error { + linkPtr, err := windows.UTF16PtrFromString(link) + if err != nil { + return err + } + targetPtr, err := windows.UTF16PtrFromString(target) + if err != nil { + return err + } + + var flags uint32 + if isDir { + flags = windows.SYMBOLIC_LINK_FLAG_DIRECTORY + } + + err = windows.CreateSymbolicLink( + linkPtr, + targetPtr, + flags, + ) + return err }