From a63a4ec358b7da2a6a3f4b932d9d73d9f9afbe08 Mon Sep 17 00:00:00 2001 From: Zhongcheng Lao Date: Fri, 20 Jun 2025 12:56:48 +0800 Subject: [PATCH 1/2] Move WMI functions to cim package --- pkg/cim/disk.go | 103 +++++++++++++++++++++++++++++ pkg/cim/smb.go | 53 ++++++++++++++- pkg/cim/volume.go | 151 +++++++++++++++++++++++++++++++++++++++++++ pkg/cim/wmi.go | 13 ++-- pkg/os/disk/api.go | 82 +++++++++-------------- pkg/os/smb/api.go | 68 ++----------------- pkg/os/volume/api.go | 131 ++++++++++--------------------------- pkg/utils/utils.go | 75 ++++++++++++++++++--- 8 files changed, 453 insertions(+), 223 deletions(-) diff --git a/pkg/cim/disk.go b/pkg/cim/disk.go index 0298b793..58c8f376 100644 --- a/pkg/cim/disk.go +++ b/pkg/cim/disk.go @@ -23,6 +23,17 @@ const ( // GPTPartitionTypeMicrosoftReserved is the GUID for Microsoft Reserved Partition (MSR) // Reserved by Windows for system use GPTPartitionTypeMicrosoftReserved = "{e3c9e316-0b5c-4db8-817d-f92df00215ae}" + + // ErrorCodeCreatePartitionAccessPathAlreadyInUse is the error code (42002) returned when the driver letter failed to assign after partition created + ErrorCodeCreatePartitionAccessPathAlreadyInUse = 42002 +) + +var ( + DiskSelectorListForDiskNumberAndLocation = []string{"Number", "Location"} + DiskSelectorListForPartitionStyle = []string{"PartitionStyle"} + DiskSelectorListForPathAndSerialNumber = []string{"Path", "SerialNumber"} + DiskSelectorListForIsOffline = []string{"IsOffline"} + DiskSelectorListForSize = []string{"Size"} ) // QueryDiskByNumber retrieves disk information for a specific disk identified by its number. @@ -77,6 +88,60 @@ func ListDisks(selectorList []string) ([]*storage.MSFT_Disk, error) { return disks, nil } +// InitializeDisk initializes a RAW disk with a particular partition style. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/initialize-msft-disk +// for the WMI method definition. +func InitializeDisk(disk *storage.MSFT_Disk, partitionStyle int) (int, error) { + result, err := disk.InvokeMethodWithReturn("Initialize", int32(partitionStyle)) + return int(result), err +} + +// RefreshDisk Refreshes the cached disk layout information. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-disk-refresh +// for the WMI method definition. +func RefreshDisk(disk *storage.MSFT_Disk) (int, string, error) { + var status string + result, err := disk.InvokeMethodWithReturn("Refresh", &status) + return int(result), status, err +} + +// CreatePartition creates a partition on a disk. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/createpartition-msft-disk +// for the WMI method definition. +func CreatePartition(disk *storage.MSFT_Disk, params ...interface{}) (int, error) { + result, err := disk.InvokeMethodWithReturn("CreatePartition", params...) + return int(result), err +} + +// SetDiskState takes a disk online or offline. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-disk-online and +// https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-disk-offline +// for the WMI method definition. +func SetDiskState(disk *storage.MSFT_Disk, online bool) (int, string, error) { + method := "Offline" + if online { + method = "Online" + } + + var status string + result, err := disk.InvokeMethodWithReturn(method, &status) + return int(result), status, err +} + +// RescanDisks rescans all changes by updating the internal cache of software objects (that is, Disks, Partitions, Volumes) +// for the storage setting. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-storagesetting-updatehoststoragecache +// for the WMI method definition. +func RescanDisks() (int, error) { + result, _, err := InvokeCimMethod(WMINamespaceStorage, "MSFT_StorageSetting", "UpdateHostStorageCache", nil) + return result, err +} + // GetDiskNumber returns the number of a disk. func GetDiskNumber(disk *storage.MSFT_Disk) (uint32, error) { number, err := disk.GetProperty("Number") @@ -85,3 +150,41 @@ func GetDiskNumber(disk *storage.MSFT_Disk) (uint32, error) { } return uint32(number.(int32)), err } + +// GetDiskLocation returns the location of a disk. +func GetDiskLocation(disk *storage.MSFT_Disk) (string, error) { + return disk.GetPropertyLocation() +} + +// GetDiskPartitionStyle returns the partition style of a disk. +func GetDiskPartitionStyle(disk *storage.MSFT_Disk) (int32, error) { + retValue, err := disk.GetProperty("PartitionStyle") + if err != nil { + return 0, err + } + return retValue.(int32), err +} + +// IsDiskOffline returns whether a disk is offline. +func IsDiskOffline(disk *storage.MSFT_Disk) (bool, error) { + return disk.GetPropertyIsOffline() +} + +// GetDiskSize returns the size of a disk. +func GetDiskSize(disk *storage.MSFT_Disk) (int64, error) { + sz, err := disk.GetProperty("Size") + if err != nil { + return -1, err + } + return strconv.ParseInt(sz.(string), 10, 64) +} + +// GetDiskPath returns the path of a disk. +func GetDiskPath(disk *storage.MSFT_Disk) (string, error) { + return disk.GetPropertyPath() +} + +// GetDiskSerialNumber returns the serial number of a disk. +func GetDiskSerialNumber(disk *storage.MSFT_Disk) (string, error) { + return disk.GetPropertySerialNumber() +} diff --git a/pkg/cim/smb.go b/pkg/cim/smb.go index 5868d456..2850ab78 100644 --- a/pkg/cim/smb.go +++ b/pkg/cim/smb.go @@ -4,6 +4,8 @@ package cim import ( + "strings" + "github.com/microsoft/wmi/pkg/base/query" cim "github.com/microsoft/wmi/pkg/wmiinstance" ) @@ -17,8 +19,24 @@ const ( SmbMappingStatusConnecting SmbMappingStatusReconnecting SmbMappingStatusUnavailable + + credentialDelimiter = ":" ) +// escapeQueryParameter escapes a parameter for WMI Queries +func escapeQueryParameter(s string) string { + s = strings.ReplaceAll(s, "'", "''") + s = strings.ReplaceAll(s, "\\", "\\\\") + return s +} + +func escapeUserName(userName string) string { + // refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L169-L170 + userName = strings.ReplaceAll(userName, "\\", "\\\\") + userName = strings.ReplaceAll(userName, credentialDelimiter, "\\"+credentialDelimiter) + return userName +} + // QuerySmbGlobalMappingByRemotePath retrieves the SMB global mapping from its remote path. // // The equivalent WMI query is: @@ -28,7 +46,7 @@ const ( // Refer to https://pkg.go.dev/github.com/microsoft/wmi/server2019/root/microsoft/windows/smb#MSFT_SmbGlobalMapping // for the WMI class definition. func QuerySmbGlobalMappingByRemotePath(remotePath string) (*cim.WmiInstance, error) { - smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", remotePath) + smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", escapeQueryParameter(remotePath)) instances, err := QueryInstances(WMINamespaceSmb, smbQuery) if err != nil { return nil, err @@ -37,12 +55,22 @@ func QuerySmbGlobalMappingByRemotePath(remotePath string) (*cim.WmiInstance, err return instances[0], err } -// RemoveSmbGlobalMappingByRemotePath removes a SMB global mapping matching to the remote path. +// GetSmbGlobalMappingStatus returns the status of an SMB global mapping. +func GetSmbGlobalMappingStatus(inst *cim.WmiInstance) (int32, error) { + statusProp, err := inst.GetProperty("Status") + if err != nil { + return SmbMappingStatusUnavailable, err + } + + return statusProp.(int32), nil +} + +// RemoveSmbGlobalMappingByRemotePath removes an SMB global mapping matching to the remote path. // // Refer to https://pkg.go.dev/github.com/microsoft/wmi/server2019/root/microsoft/windows/smb#MSFT_SmbGlobalMapping // for the WMI class definition. func RemoveSmbGlobalMappingByRemotePath(remotePath string) error { - smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", remotePath) + smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", escapeQueryParameter(remotePath)) instances, err := QueryInstances(WMINamespaceSmb, smbQuery) if err != nil { return err @@ -51,3 +79,22 @@ func RemoveSmbGlobalMappingByRemotePath(remotePath string) error { _, err = instances[0].InvokeMethod("Remove", true) return err } + +// NewSmbGlobalMapping creates a new SMB global mapping to the remote path. +// +// Refer to https://pkg.go.dev/github.com/microsoft/wmi/server2019/root/microsoft/windows/smb#MSFT_SmbGlobalMapping +// for the WMI class definition. +func NewSmbGlobalMapping(remotePath, username, password string, requirePrivacy bool) (int, error) { + params := map[string]interface{}{ + "RemotePath": remotePath, + "RequirePrivacy": requirePrivacy, + } + if username != "" { + // refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L166-L178 + // on how SMB credential is handled in PowerShell + params["Credential"] = escapeUserName(username) + credentialDelimiter + password + } + + result, _, err := InvokeCimMethod(WMINamespaceSmb, "MSFT_SmbGlobalMapping", "Create", params) + return result, err +} diff --git a/pkg/cim/volume.go b/pkg/cim/volume.go index b77fd6d6..0c880fe1 100644 --- a/pkg/cim/volume.go +++ b/pkg/cim/volume.go @@ -7,9 +7,23 @@ import ( "fmt" "strconv" + "github.com/go-ole/go-ole" "github.com/microsoft/wmi/pkg/base/query" "github.com/microsoft/wmi/pkg/errors" "github.com/microsoft/wmi/server2019/root/microsoft/windows/storage" + "k8s.io/klog/v2" +) + +const ( + FileSystemUnknown = 0 +) + +var ( + VolumeSelectorListForFileSystemType = []string{"FileSystemType"} + VolumeSelectorListForStats = []string{"UniqueId", "SizeRemaining", "Size"} + VolumeSelectorListUniqueID = []string{"UniqueId"} + + PartitionSelectorListObjectID = []string{"ObjectId"} ) // QueryVolumeByUniqueID retrieves a specific volume by its unique identifier, @@ -78,6 +92,68 @@ func ListVolumes(selectorList []string) ([]*storage.MSFT_Volume, error) { return volumes, nil } +// FormatVolume formats the specified volume. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/format-msft-volume +// for the WMI method definition. +func FormatVolume(volume *storage.MSFT_Volume, params ...interface{}) (int, error) { + result, err := volume.InvokeMethodWithReturn("Format", params...) + return int(result), err +} + +// FlushVolume flushes the cached data in the volume's file system to disk. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-volume-flush +// for the WMI method definition. +func FlushVolume(volume *storage.MSFT_Volume) (int, error) { + result, err := volume.Flush() + return int(result), err +} + +// GetVolumeUniqueID returns the unique ID (object ID) of a volume. +func GetVolumeUniqueID(volume *storage.MSFT_Volume) (string, error) { + return volume.GetPropertyUniqueId() +} + +// GetVolumeFileSystemType returns the file system type of a volume. +func GetVolumeFileSystemType(volume *storage.MSFT_Volume) (int32, error) { + fsType, err := volume.GetProperty("FileSystemType") + if err != nil { + return 0, err + } + return fsType.(int32), nil +} + +// GetVolumeSize returns the size of a volume. +func GetVolumeSize(volume *storage.MSFT_Volume) (int64, error) { + volumeSizeVal, err := volume.GetProperty("Size") + if err != nil { + return -1, err + } + + volumeSize, err := strconv.ParseInt(volumeSizeVal.(string), 10, 64) + if err != nil { + return -1, err + } + + return volumeSize, err +} + +// GetVolumeSizeRemaining returns the remaining size of a volume. +func GetVolumeSizeRemaining(volume *storage.MSFT_Volume) (int64, error) { + volumeSizeRemainingVal, err := volume.GetProperty("SizeRemaining") + if err != nil { + return -1, err + } + + volumeSizeRemaining, err := strconv.ParseInt(volumeSizeRemainingVal.(string), 10, 64) + if err != nil { + return -1, err + } + + return volumeSizeRemaining, err +} + // ListPartitionsOnDisk retrieves all partitions or a partition with the specified number on a disk. // // The equivalent WMI query is: @@ -245,3 +321,78 @@ func GetPartitionDiskNumber(part *storage.MSFT_Partition) (uint32, error) { return uint32(diskNumber.(int32)), nil } + +// SetPartitionState takes a partition online or offline. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partition-online and +// https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partition-offline +// for the WMI method definition. +func SetPartitionState(part *storage.MSFT_Partition, online bool) (int, string, error) { + method := "Offline" + if online { + method = "Online" + } + + var status string + result, err := part.InvokeMethodWithReturn(method, &status) + return int(result), status, err +} + +// GetPartitionSupportedSize retrieves the minimum and maximum sizes that the partition can be resized to using the ResizePartition method. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partition-getsupportedsizes +// for the WMI method definition. +func GetPartitionSupportedSize(part *storage.MSFT_Partition) (result int, sizeMin, sizeMax int64, status string, err error) { + sizeMin = -1 + sizeMax = -1 + + var sizeMinVar, sizeMaxVar ole.VARIANT + invokeResult, err := part.InvokeMethodWithReturn("GetSupportedSize", &sizeMinVar, &sizeMaxVar, &status) + if invokeResult != 0 || err != nil { + result = int(invokeResult) + } + klog.V(5).Infof("got sizeMin (%v) sizeMax (%v) from partition (%v), status: %s", sizeMinVar, sizeMaxVar, part, status) + + sizeMin, err = strconv.ParseInt(sizeMinVar.ToString(), 10, 64) + if err != nil { + return + } + + sizeMax, err = strconv.ParseInt(sizeMaxVar.ToString(), 10, 64) + return +} + +// ResizePartition resizes a partition. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partition-resize +// for the WMI method definition. +func ResizePartition(part *storage.MSFT_Partition, size int64) (int, string, error) { + var status string + result, err := part.InvokeMethodWithReturn("Resize", strconv.Itoa(int(size)), &status) + return int(result), status, err +} + +// GetPartitionSize returns the size of a partition. +func GetPartitionSize(part *storage.MSFT_Partition) (int64, error) { + sizeProp, err := part.GetProperty("Size") + if err != nil { + return -1, err + } + + size, err := strconv.ParseInt(sizeProp.(string), 10, 64) + if err != nil { + return -1, err + } + + return size, err +} + +// FilterForPartitionOnDisk creates a WMI query filter to query a disk by its number. +func FilterForPartitionOnDisk(diskNumber uint32) *query.WmiQueryFilter { + return query.NewWmiQueryFilter("DiskNumber", strconv.Itoa(int(diskNumber)), query.Equals) +} + +// FilterForPartitionsOfTypeNormal creates a WMI query filter for all non-reserved partitions. +func FilterForPartitionsOfTypeNormal() *query.WmiQueryFilter { + return query.NewWmiQueryFilter("GptType", GPTPartitionTypeMicrosoftReserved, query.NotEquals) +} diff --git a/pkg/cim/wmi.go b/pkg/cim/wmi.go index 1dacce8a..ba75f747 100644 --- a/pkg/cim/wmi.go +++ b/pkg/cim/wmi.go @@ -9,7 +9,7 @@ import ( "github.com/go-ole/go-ole" "github.com/go-ole/go-ole/oleutil" "github.com/microsoft/wmi/pkg/base/query" - "github.com/microsoft/wmi/pkg/errors" + wmierrors "github.com/microsoft/wmi/pkg/errors" cim "github.com/microsoft/wmi/pkg/wmiinstance" "k8s.io/klog/v2" ) @@ -61,7 +61,7 @@ func QueryFromWMI(namespace string, query *query.WmiQuery, handler InstanceHandl } if len(instances) == 0 { - return errors.NotFound + return wmierrors.NotFound } var cont bool @@ -95,7 +95,7 @@ func executeClassMethodParam(classInst *cim.WmiInstance, method *cim.WmiMethod, iDispatchInstance := classInst.GetIDispatch() if iDispatchInstance == nil { - return nil, errors.Wrapf(errors.InvalidInput, "InvalidInstance") + return nil, wmierrors.Wrapf(wmierrors.InvalidInput, "InvalidInstance") } rawResult, err := iDispatchInstance.GetProperty("Methods_") if err != nil { @@ -235,10 +235,15 @@ func InvokeCimMethod(namespace, class, methodName string, inputParameters map[st return int(result.ReturnValue), outputParameters, nil } +// IsNotFound returns true if it's a "not found" error. +func IsNotFound(err error) bool { + return wmierrors.IsNotFound(err) +} + // IgnoreNotFound returns nil if the error is nil or a "not found" error, // otherwise returns the original error. func IgnoreNotFound(err error) error { - if err == nil || errors.IsNotFound(err) { + if err == nil || IsNotFound(err) { return nil } return err diff --git a/pkg/os/disk/api.go b/pkg/os/disk/api.go index dc8637fd..b366b977 100644 --- a/pkg/os/disk/api.go +++ b/pkg/os/disk/api.go @@ -3,14 +3,12 @@ package disk import ( "encoding/hex" "fmt" - "strconv" "strings" "syscall" "unsafe" "github.com/kubernetes-csi/csi-proxy/pkg/cim" shared "github.com/kubernetes-csi/csi-proxy/pkg/shared/disk" - "github.com/microsoft/wmi/pkg/base/query" "k8s.io/klog/v2" ) @@ -67,19 +65,19 @@ func New() DiskAPI { // as the value. The DiskLocation struct has various fields like the Adapter, Bus, Target and LUNID. func (imp DiskAPI) ListDiskLocations() (map[uint32]shared.DiskLocation, error) { // "location": "PCI Slot 3 : Adapter 0 : Port 0 : Target 1 : LUN 0" - disks, err := cim.ListDisks([]string{"Number", "Location"}) + disks, err := cim.ListDisks(cim.DiskSelectorListForDiskNumberAndLocation) if err != nil { return nil, fmt.Errorf("could not query disk locations") } m := make(map[uint32]shared.DiskLocation) for _, disk := range disks { - num, err := disk.GetProperty("Number") + num, err := cim.GetDiskNumber(disk) if err != nil { return m, fmt.Errorf("failed to query disk number: %v, %w", disk, err) } - location, err := disk.GetPropertyLocation() + location, err := cim.GetDiskLocation(disk) if err != nil { return m, fmt.Errorf("failed to query disk location: %v, %w", disk, err) } @@ -107,7 +105,7 @@ func (imp DiskAPI) ListDiskLocations() (map[uint32]shared.DiskLocation, error) { } if found { - m[uint32(num.(int32))] = d + m[num] = d } } } @@ -116,7 +114,7 @@ func (imp DiskAPI) ListDiskLocations() (map[uint32]shared.DiskLocation, error) { } func (imp DiskAPI) Rescan() error { - result, _, err := cim.InvokeCimMethod(cim.WMINamespaceStorage, "MSFT_StorageSetting", "UpdateHostStorageCache", nil) + result, err := cim.RescanDisks() if err != nil { return fmt.Errorf("error updating host storage cache output. result: %d, err: %v", result, err) } @@ -124,18 +122,16 @@ func (imp DiskAPI) Rescan() error { } func (imp DiskAPI) IsDiskInitialized(diskNumber uint32) (bool, error) { - var partitionStyle int32 - disk, err := cim.QueryDiskByNumber(diskNumber, []string{"PartitionStyle"}) + disk, err := cim.QueryDiskByNumber(diskNumber, cim.DiskSelectorListForPartitionStyle) if err != nil { - return false, fmt.Errorf("error checking initialized status of disk %d. %v", diskNumber, err) + return false, fmt.Errorf("error checking initialized status of disk %d: %v", diskNumber, err) } - retValue, err := disk.GetProperty("PartitionStyle") + partitionStyle, err := cim.GetDiskPartitionStyle(disk) if err != nil { - return false, fmt.Errorf("failed to query partition style of disk %d: %w", diskNumber, err) + return false, fmt.Errorf("failed to query partition style of disk %d: %v", diskNumber, err) } - partitionStyle = retValue.(int32) return partitionStyle != cim.PartitionStyleUnknown, nil } @@ -145,7 +141,7 @@ func (imp DiskAPI) InitializeDisk(diskNumber uint32) error { return fmt.Errorf("failed to initializing disk %d. error: %w", diskNumber, err) } - result, err := disk.InvokeMethodWithReturn("Initialize", int32(cim.PartitionStyleGPT)) + result, err := cim.InitializeDisk(disk, cim.PartitionStyleGPT) if result != 0 || err != nil { return fmt.Errorf("failed to initializing disk %d: result %d, error: %w", diskNumber, result, err) } @@ -154,9 +150,7 @@ func (imp DiskAPI) InitializeDisk(diskNumber uint32) error { } func (imp DiskAPI) BasicPartitionsExist(diskNumber uint32) (bool, error) { - partitions, err := cim.ListPartitionsWithFilters(nil, - query.NewWmiQueryFilter("DiskNumber", strconv.Itoa(int(diskNumber)), query.Equals), - query.NewWmiQueryFilter("GptType", cim.GPTPartitionTypeMicrosoftReserved, query.NotEquals)) + partitions, err := cim.ListPartitionsWithFilters(nil, cim.FilterForPartitionOnDisk(diskNumber), cim.FilterForPartitionsOfTypeNormal()) if cim.IgnoreNotFound(err) != nil { return false, fmt.Errorf("error checking presence of partitions on disk %d:, %v", diskNumber, err) } @@ -170,8 +164,8 @@ func (imp DiskAPI) CreateBasicPartition(diskNumber uint32) error { return err } - result, err := disk.InvokeMethodWithReturn( - "CreatePartition", + result, err := cim.CreatePartition( + disk, nil, // Size true, // UseMaximumSize nil, // Offset @@ -183,20 +177,16 @@ func (imp DiskAPI) CreateBasicPartition(diskNumber uint32) error { false, // IsHidden false, // IsActive, ) - // 42002 is returned by driver letter failed to assign after partition - if (result != 0 && result != 42002) || err != nil { + if (result != 0 && result != cim.ErrorCodeCreatePartitionAccessPathAlreadyInUse) || err != nil { return fmt.Errorf("error creating partition on disk %d. result: %d, err: %v", diskNumber, result, err) } - var status string - result, err = disk.InvokeMethodWithReturn("Refresh", &status) + result, _, err = cim.RefreshDisk(disk) if result != 0 || err != nil { return fmt.Errorf("error rescan disk (%d). result %d, error: %v", diskNumber, result, err) } - partitions, err := cim.ListPartitionsWithFilters(nil, - query.NewWmiQueryFilter("DiskNumber", strconv.Itoa(int(diskNumber)), query.Equals), - query.NewWmiQueryFilter("GptType", cim.GPTPartitionTypeMicrosoftReserved, query.NotEquals)) + partitions, err := cim.ListPartitionsWithFilters(nil, cim.FilterForPartitionOnDisk(diskNumber), cim.FilterForPartitionsOfTypeNormal()) if err != nil { return fmt.Errorf("error query basic partition on disk %d:, %v", diskNumber, err) } @@ -206,13 +196,12 @@ func (imp DiskAPI) CreateBasicPartition(diskNumber uint32) error { } partition := partitions[0] - result, err = partition.InvokeMethodWithReturn("Online", status) + result, status, err := cim.SetPartitionState(partition, true) if result != 0 || err != nil { return fmt.Errorf("error bring partition %v on disk %d online. result: %d, status %s, err: %v", partition, diskNumber, result, status, err) } - err = partition.Refresh() - return err + return nil } func (imp DiskAPI) GetDiskNumberByName(page83ID string) (uint32, error) { @@ -272,13 +261,13 @@ func (imp DiskAPI) GetDiskPage83ID(disk syscall.Handle) (string, error) { } func (imp DiskAPI) GetDiskNumberWithID(page83ID string) (uint32, error) { - disks, err := cim.ListDisks([]string{"Path", "SerialNumber"}) + disks, err := cim.ListDisks(cim.DiskSelectorListForPathAndSerialNumber) if err != nil { return 0, err } for _, disk := range disks { - path, err := disk.GetPropertyPath() + path, err := cim.GetDiskPath(disk) if err != nil { return 0, fmt.Errorf("failed to query disk path: %v, %w", disk, err) } @@ -319,19 +308,19 @@ func (imp DiskAPI) GetDiskNumberAndPage83ID(path string) (uint32, string, error) // ListDiskIDs - constructs a map with the disk number as the key and the DiskID structure // as the value. The DiskID struct has a field for the page83 ID. func (imp DiskAPI) ListDiskIDs() (map[uint32]shared.DiskIDs, error) { - disks, err := cim.ListDisks([]string{"Path", "SerialNumber"}) + disks, err := cim.ListDisks(cim.DiskSelectorListForPathAndSerialNumber) if err != nil { return nil, err } m := make(map[uint32]shared.DiskIDs) for _, disk := range disks { - path, err := disk.GetPropertyPath() + path, err := cim.GetDiskPath(disk) if err != nil { return m, fmt.Errorf("failed to query disk path: %v, %w", disk, err) } - sn, err := disk.GetPropertySerialNumber() + sn, err := cim.GetDiskSerialNumber(disk) if err != nil { return m, fmt.Errorf("failed to query disk serial number: %v, %w", disk, err) } @@ -351,56 +340,49 @@ func (imp DiskAPI) ListDiskIDs() (map[uint32]shared.DiskIDs, error) { func (imp DiskAPI) GetDiskStats(diskNumber uint32) (int64, error) { // TODO: change to uint64 as it does not make sense to use int64 for size - var size int64 - disk, err := cim.QueryDiskByNumber(diskNumber, []string{"Size"}) + disk, err := cim.QueryDiskByNumber(diskNumber, cim.DiskSelectorListForSize) if err != nil { return -1, err } - sz, err := disk.GetProperty("Size") + size, err := cim.GetDiskSize(disk) if err != nil { return -1, fmt.Errorf("failed to query size of disk %d. %v", diskNumber, err) } - size, err = strconv.ParseInt(sz.(string), 10, 64) return size, err } func (imp DiskAPI) SetDiskState(diskNumber uint32, isOnline bool) error { - disk, err := cim.QueryDiskByNumber(diskNumber, []string{"IsOffline"}) + disk, err := cim.QueryDiskByNumber(diskNumber, cim.DiskSelectorListForIsOffline) if err != nil { return err } - offline, err := disk.GetPropertyIsOffline() + isOffline, err := cim.IsDiskOffline(disk) if err != nil { return fmt.Errorf("error setting disk %d attach state. error: %v", diskNumber, err) } - if isOnline == !offline { + if isOnline == !isOffline { return nil } - method := "Offline" - if isOnline { - method = "Online" - } - - result, err := disk.InvokeMethodWithReturn(method) + result, _, err := cim.SetDiskState(disk, isOnline) if result != 0 || err != nil { - return fmt.Errorf("setting disk %d attach state %s: result %d, error: %w", diskNumber, method, result, err) + return fmt.Errorf("setting disk %d attach state (isOnline: %v): result %d, error: %w", diskNumber, isOnline, result, err) } return nil } func (imp DiskAPI) GetDiskState(diskNumber uint32) (bool, error) { - disk, err := cim.QueryDiskByNumber(diskNumber, []string{"IsOffline"}) + disk, err := cim.QueryDiskByNumber(diskNumber, cim.DiskSelectorListForIsOffline) if err != nil { return false, err } - isOffline, err := disk.GetPropertyIsOffline() + isOffline, err := cim.IsDiskOffline(disk) if err != nil { return false, fmt.Errorf("error parsing disk %d state. error: %v", diskNumber, err) } diff --git a/pkg/os/smb/api.go b/pkg/os/smb/api.go index 20b9544e..f0d28da4 100644 --- a/pkg/os/smb/api.go +++ b/pkg/os/smb/api.go @@ -3,15 +3,9 @@ package smb import ( "fmt" "strings" - "syscall" "github.com/kubernetes-csi/csi-proxy/pkg/cim" "github.com/kubernetes-csi/csi-proxy/pkg/utils" - "golang.org/x/sys/windows" -) - -const ( - credentialDelimiter = ":" ) type API interface { @@ -33,61 +27,23 @@ func New(requirePrivacy bool) *SmbAPI { } } -func remotePathForQuery(remotePath string) string { - return strings.ReplaceAll(remotePath, "\\", "\\\\") -} - -func escapeUserName(userName string) string { - // refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L169-L170 - escaped := strings.ReplaceAll(userName, "\\", "\\\\") - escaped = strings.ReplaceAll(escaped, credentialDelimiter, "\\"+credentialDelimiter) - return escaped -} - -func createSymlink(link, target string, isDir bool) error { - linkPtr, err := syscall.UTF16PtrFromString(link) - if err != nil { - return err - } - targetPtr, err := syscall.UTF16PtrFromString(target) - if err != nil { - return err - } - - var flags uint32 - if isDir { - flags = windows.SYMBOLIC_LINK_FLAG_DIRECTORY - } - - err = windows.CreateSymbolicLink( - linkPtr, - targetPtr, - flags, - ) - return err -} - func (*SmbAPI) IsSmbMapped(remotePath string) (bool, error) { - inst, err := cim.QuerySmbGlobalMappingByRemotePath(remotePathForQuery(remotePath)) + inst, err := cim.QuerySmbGlobalMappingByRemotePath(remotePath) if err != nil { return false, cim.IgnoreNotFound(err) } - status, err := inst.GetProperty("Status") + status, err := cim.GetSmbGlobalMappingStatus(inst) if err != nil { return false, err } - return status.(int32) == cim.SmbMappingStatusOK, nil + return status == cim.SmbMappingStatusOK, nil } // NewSmbLink - creates a directory symbolic link to the remote share. // The os.Symlink was having issue for cases where the destination was an SMB share - the container -// runtime would complain stating "Access Denied". Because of this, we had to perform -// this operation with powershell commandlet creating an directory softlink. -// Since os.Symlink is currently being used in working code paths, no attempt is made in -// alpha to merge the paths. -// TODO (for beta release): Merge the link paths - os.Symlink and Powershell link path. +// runtime would complain stating "Access Denied". func (*SmbAPI) NewSmbLink(remotePath, localPath string) error { if !strings.HasSuffix(remotePath, "\\") { // Golang has issues resolving paths mapped to file shares if they do not end in a trailing \ @@ -97,7 +53,7 @@ func (*SmbAPI) NewSmbLink(remotePath, localPath string) error { longRemotePath := utils.EnsureLongPath(remotePath) longLocalPath := utils.EnsureLongPath(localPath) - err := createSymlink(longLocalPath, longRemotePath, true) + err := utils.CreateSymlink(longLocalPath, longRemotePath, true) if err != nil { return fmt.Errorf("error linking %s to %s. err: %v", remotePath, localPath, err) } @@ -106,17 +62,7 @@ func (*SmbAPI) NewSmbLink(remotePath, localPath string) error { } func (api *SmbAPI) NewSmbGlobalMapping(remotePath, username, password string) error { - params := map[string]interface{}{ - "RemotePath": remotePath, - "RequirePrivacy": api.RequirePrivacy, - } - if username != "" { - // refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L166-L178 - // on how SMB credential is handled in PowerShell - params["Credential"] = escapeUserName(username) + credentialDelimiter + password - } - - result, _, err := cim.InvokeCimMethod(cim.WMINamespaceSmb, "MSFT_SmbGlobalMapping", "Create", params) + result, err := cim.NewSmbGlobalMapping(remotePath, username, password, api.RequirePrivacy) if err != nil { return fmt.Errorf("NewSmbGlobalMapping failed. result: %d, err: %v", result, err) } @@ -125,7 +71,7 @@ func (api *SmbAPI) NewSmbGlobalMapping(remotePath, username, password string) er } func (*SmbAPI) RemoveSmbGlobalMapping(remotePath string) error { - err := cim.RemoveSmbGlobalMappingByRemotePath(remotePathForQuery(remotePath)) + err := cim.RemoveSmbGlobalMappingByRemotePath(remotePath) if err != nil { return fmt.Errorf("error remove smb mapping '%s'. err: %v", remotePath, err) } diff --git a/pkg/os/volume/api.go b/pkg/os/volume/api.go index fcd2e6f8..ea667798 100644 --- a/pkg/os/volume/api.go +++ b/pkg/os/volume/api.go @@ -2,21 +2,21 @@ package volume import ( "fmt" - "os" "path/filepath" "regexp" - "strconv" "strings" - "github.com/go-ole/go-ole" "github.com/kubernetes-csi/csi-proxy/pkg/cim" "github.com/kubernetes-csi/csi-proxy/pkg/utils" - wmierrors "github.com/microsoft/wmi/pkg/errors" "github.com/pkg/errors" "golang.org/x/sys/windows" "k8s.io/klog/v2" ) +const ( + minimumResizeSize = 100 * 1024 * 1024 +) + // API exposes the internal volume operations available in the server type API interface { // ListVolumesOnDisk lists volumes on a disk identified by a `diskNumber` and optionally a partition identified by `partitionNumber`. @@ -69,7 +69,7 @@ func New() VolumeAPI { // ListVolumesOnDisk - returns back list of volumes(volumeIDs) in a disk and a partition. func (VolumeAPI) ListVolumesOnDisk(diskNumber uint32, partitionNumber uint32) (volumeIDs []string, err error) { - partitions, err := cim.ListPartitionsOnDisk(diskNumber, partitionNumber, []string{"ObjectId"}) + partitions, err := cim.ListPartitionsOnDisk(diskNumber, partitionNumber, cim.PartitionSelectorListObjectID) if err != nil { return nil, errors.Wrapf(err, "failed to list partition on disk %d", diskNumber) } @@ -80,9 +80,9 @@ func (VolumeAPI) ListVolumesOnDisk(diskNumber uint32, partitionNumber uint32) (v } for _, volume := range volumes { - uniqueID, err := volume.GetPropertyUniqueId() + uniqueID, err := cim.GetVolumeUniqueID(volume) if err != nil { - return nil, errors.Wrapf(err, "failed to list volumes") + return nil, errors.Wrapf(err, "failed to get unique ID for volume %v", volume) } volumeIDs = append(volumeIDs, uniqueID) } @@ -97,8 +97,7 @@ func (VolumeAPI) FormatVolume(volumeID string) (err error) { return fmt.Errorf("error formatting volume (%s). error: %v", volumeID, err) } - result, err := volume.InvokeMethodWithReturn( - "Format", + result, err := cim.FormatVolume(volume, "NTFS", // Format, "", // FileSystemLabel, nil, // AllocationUnitSize, @@ -113,7 +112,6 @@ func (VolumeAPI) FormatVolume(volumeID string) (err error) { if result != 0 || err != nil { return fmt.Errorf("error formatting volume (%s). result: %d, error: %v", volumeID, result, err) } - // TODO: Do we need to handle anything for len(out) == 0 return nil } @@ -124,18 +122,17 @@ func (VolumeAPI) WriteVolumeCache(volumeID string) (err error) { // IsVolumeFormatted - Check if the volume is formatted with the pre specified filesystem(typically ntfs). func (VolumeAPI) IsVolumeFormatted(volumeID string) (bool, error) { - volume, err := cim.QueryVolumeByUniqueID(volumeID, []string{"FileSystemType"}) + volume, err := cim.QueryVolumeByUniqueID(volumeID, cim.VolumeSelectorListForFileSystemType) if err != nil { return false, fmt.Errorf("error checking if volume (%s) is formatted. error: %v", volumeID, err) } - fsType, err := volume.GetProperty("FileSystemType") + fsType, err := cim.GetVolumeFileSystemType(volume) if err != nil { return false, fmt.Errorf("failed to query volume file system type (%s): %w", volumeID, err) } - const FileSystemUnknown = 0 - return fsType.(int32) != FileSystemUnknown, nil + return fsType != cim.FileSystemUnknown, nil } // MountVolume - mounts a volume to a path. This is done using Win32 API SetVolumeMountPoint for presenting the volume via a path. @@ -194,36 +191,25 @@ func (VolumeAPI) ResizeVolume(volumeID string, size int64) error { // If size is 0 then we will resize to the maximum size possible, otherwise just resize to size if size == 0 { - var sizeMin, sizeMax ole.VARIANT + var result int var status string - result, err := part.InvokeMethodWithReturn("GetSupportedSize", &sizeMin, &sizeMax, &status) + result, _, finalSize, status, err = cim.GetPartitionSupportedSize(part) if result != 0 || err != nil { - return fmt.Errorf("error getting sizeMin, sizeMax from volume(%s). result: %d, status: %s, error: %v", volumeID, result, status, err) + return fmt.Errorf("error getting sizeMin, sizeMax from volume (%s). result: %d, status: %s, error: %v", volumeID, result, status, err) } - klog.V(5).Infof("got sizeMin(%v) sizeMax(%v) from volume(%s), status: %s", sizeMin, sizeMax, volumeID, status) - finalSizeStr := sizeMax.ToString() - finalSize, err = strconv.ParseInt(finalSizeStr, 10, 64) - if err != nil { - return fmt.Errorf("error parsing the sizeMax of volume (%s) with error (%v)", volumeID, err) - } } else { finalSize = size } - currentSizeVal, err := part.GetProperty("Size") + currentSize, err := cim.GetPartitionSize(part) if err != nil { return fmt.Errorf("error getting the current size of volume (%s) with error (%v)", volumeID, err) } - currentSize, err := strconv.ParseInt(currentSizeVal.(string), 10, 64) - if err != nil { - return fmt.Errorf("error parsing the current size of volume (%s) with error (%v)", volumeID, err) - } - // only resize if finalSize - currentSize is greater than 100MB - if finalSize-currentSize < 100*1024*1024 { - klog.V(2).Infof("minimum resize difference(1GB) not met, skipping resize. volumeID=%s currentSize=%d finalSize=%d", volumeID, currentSize, finalSize) + if finalSize-currentSize < minimumResizeSize { + klog.V(2).Infof("minimum resize difference (100MB) not met, skipping resize. volumeID=%s currentSize=%d finalSize=%d", volumeID, currentSize, finalSize) return nil } @@ -233,9 +219,7 @@ func (VolumeAPI) ResizeVolume(volumeID string, size int64) error { return nil } - var status string - result, err := part.InvokeMethodWithReturn("Resize", strconv.Itoa(int(finalSize)), &status) - + result, _, err := cim.ResizePartition(part, finalSize) if result != 0 || err != nil { return fmt.Errorf("error resizing volume (%s). size:%v, finalSize %v, error: %v", volumeID, size, finalSize, err) } @@ -247,10 +231,10 @@ func (VolumeAPI) ResizeVolume(volumeID string, size int64) error { disk, err := cim.QueryDiskByNumber(diskNumber, nil) if err != nil { - return fmt.Errorf("error parsing disk number of volume (%s). error: %v", volumeID, err) + return fmt.Errorf("error query disk of volume (%s). error: %v", volumeID, err) } - result, err = disk.InvokeMethodWithReturn("Refresh", &status) + result, _, err = cim.RefreshDisk(disk) if result != 0 || err != nil { return fmt.Errorf("error rescan disk (%d). result %d, error: %v", diskNumber, result, err) } @@ -260,31 +244,21 @@ func (VolumeAPI) ResizeVolume(volumeID string, size int64) error { // GetVolumeStats - retrieves the volume stats for a given volume func (VolumeAPI) GetVolumeStats(volumeID string) (int64, int64, error) { - volume, err := cim.QueryVolumeByUniqueID(volumeID, []string{"UniqueId", "SizeRemaining", "Size"}) + volume, err := cim.QueryVolumeByUniqueID(volumeID, cim.VolumeSelectorListForStats) if err != nil { return -1, -1, fmt.Errorf("error getting capacity and used size of volume (%s). error: %v", volumeID, err) } - volumeSizeVal, err := volume.GetProperty("Size") + volumeSize, err := cim.GetVolumeSize(volume) if err != nil { return -1, -1, fmt.Errorf("failed to query volume size (%s): %w", volumeID, err) } - volumeSize, err := strconv.ParseInt(volumeSizeVal.(string), 10, 64) - if err != nil { - return -1, -1, fmt.Errorf("failed to parse volume size (%s): %w", volumeID, err) - } - - volumeSizeRemainingVal, err := volume.GetProperty("SizeRemaining") + volumeSizeRemaining, err := cim.GetVolumeSizeRemaining(volume) if err != nil { return -1, -1, fmt.Errorf("failed to query volume remaining size (%s): %w", volumeID, err) } - volumeSizeRemaining, err := strconv.ParseInt(volumeSizeRemainingVal.(string), 10, 64) - if err != nil { - return -1, -1, fmt.Errorf("failed to parse volume remaining size (%s): %w", volumeID, err) - } - volumeUsedSize := volumeSize - volumeSizeRemaining return volumeSize, volumeUsedSize, nil } @@ -297,12 +271,12 @@ func (VolumeAPI) GetDiskNumberFromVolumeID(volumeID string) (uint32, error) { return 0, err } - diskNumber, err := part.GetProperty("DiskNumber") + diskNumber, err := cim.GetPartitionDiskNumber(part) if err != nil { return 0, fmt.Errorf("error query disk number of volume (%s). error: %v", volumeID, err) } - return uint32(diskNumber.(int32)), nil + return diskNumber, nil } // GetVolumeIDFromTargetPath - gets the volume ID given a mount point, the function is recursive until it find a volume or errors out @@ -316,7 +290,7 @@ func (VolumeAPI) GetVolumeIDFromTargetPath(mount string) (string, error) { } func getTarget(mount string) (string, error) { - mountedFolder, err := isMountedFolder(mount) + mountedFolder, err := utils.IsMountedFolder(mount) if err != nil { return "", err } @@ -356,7 +330,7 @@ func (VolumeAPI) GetClosestVolumeIDFromTargetPath(targetPath string) (string, er } // findClosestVolume finds the closest volume id for a given target path -// by following symlinks and moving up in the filesystem, if after moving up in the filesystem +// by following symlinks and moving up in the filesystem. if after moving up in the filesystem // we get to a DriveLetter then the volume corresponding to this drive letter is returned instead. func findClosestVolume(path string) (string, error) { candidatePath := path @@ -365,22 +339,20 @@ func findClosestVolume(path string) (string, error) { // while trying to follow symlinks // // The maximum path length in Windows is 260, it could be possible to end - // up in a sceneario where we do more than 256 iterations (e.g. by following symlinks from + // up in a scenario where we do more than 256 iterations (e.g. by following symlinks from // a place high in the hierarchy to a nested sibling location many times) // https://docs.microsoft.com/en-us/windows/win32/fileio/naming-a-file#:~:text=In%20editions%20of%20Windows%20before,required%20to%20remove%20the%20limit. // // The number of iterations is 256, which is similar to the number of iterations in filepath-securejoin // https://github.com/cyphar/filepath-securejoin/blob/64536a8a66ae59588c981e2199f1dcf410508e07/join.go#L51 for i := 0; i < 256; i += 1 { - fi, err := os.Lstat(candidatePath) + isSymlink, err := utils.IsPathSymlink(candidatePath) if err != nil { return "", err } - // for windows NTFS, check if the path is symlink instead of directory. - isSymlink := fi.Mode()&os.ModeSymlink != 0 || fi.Mode()&os.ModeIrregular != 0 // mounted folder created by SetVolumeMountPoint may still report ModeSymlink == 0 - mountedFolder, err := isMountedFolder(candidatePath) + mountedFolder, err := utils.IsMountedFolder(candidatePath) if err != nil { return "", err } @@ -417,51 +389,18 @@ func findClosestVolume(path string) (string, error) { return "", fmt.Errorf("failed to find the closest volume for path=%s", path) } -// isMountedFolder checks whether the `path` is a mounted folder. -func isMountedFolder(path string) (bool, error) { - // https://learn.microsoft.com/en-us/windows/win32/fileio/determining-whether-a-directory-is-a-volume-mount-point - utf16Path, _ := windows.UTF16PtrFromString(path) - attrs, err := windows.GetFileAttributes(utf16Path) - if err != nil { - return false, err - } - - if (attrs & windows.FILE_ATTRIBUTE_REPARSE_POINT) == 0 { - return false, nil - } - - var findData windows.Win32finddata - findHandle, err := windows.FindFirstFile(utf16Path, &findData) - if err != nil && !errors.Is(err, windows.ERROR_NO_MORE_FILES) { - return false, err - } - - for err == nil { - if findData.Reserved0&windows.IO_REPARSE_TAG_MOUNT_POINT != 0 { - return true, nil - } - - err = windows.FindNextFile(findHandle, &findData) - if err != nil && !errors.Is(err, windows.ERROR_NO_MORE_FILES) { - return false, err - } - } - - return false, nil -} - // getVolumeForDriveLetter gets a volume from a drive letter (e.g. C:/). func getVolumeForDriveLetter(path string) (string, error) { if len(path) != 1 { return "", fmt.Errorf("the path %s is not a valid drive letter", path) } - volume, err := cim.GetVolumeByDriveLetter(path, []string{"UniqueId"}) + volume, err := cim.GetVolumeByDriveLetter(path, cim.VolumeSelectorListUniqueID) if err != nil { return "", nil } - uniqueID, err := volume.GetPropertyUniqueId() + uniqueID, err := cim.GetVolumeUniqueID(volume) if err != nil { return "", fmt.Errorf("error query unique ID of volume (%v). error: %v", volume, err) } @@ -470,12 +409,12 @@ func getVolumeForDriveLetter(path string) (string, error) { } func writeCache(volumeID string) error { - volume, err := cim.QueryVolumeByUniqueID(volumeID, []string{}) - if err != nil && !wmierrors.IsNotFound(err) { + volume, err := cim.QueryVolumeByUniqueID(volumeID, nil) + if err != nil { return fmt.Errorf("error writing volume (%s) cache. error: %v", volumeID, err) } - result, err := volume.Flush() + result, err := cim.FlushVolume(volume) if result != 0 || err != nil { return fmt.Errorf("error writing volume (%s) cache. result: %d, error: %v", volumeID, result, err) } diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index bfe446f7..90c21125 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -3,7 +3,6 @@ package utils import ( "fmt" "os" - "os/exec" "strings" "github.com/pkg/errors" @@ -25,14 +24,6 @@ func EnsureLongPath(path string) string { return path } -func RunPowershellCmd(command string, envs ...string) ([]byte, error) { - cmd := exec.Command("powershell", "-Mta", "-NoProfile", "-Command", command) - cmd.Env = append(os.Environ(), envs...) - klog.V(8).Infof("Executing command: %q", cmd.String()) - out, err := cmd.CombinedOutput() - return out, err -} - func IsPathValid(path string) (bool, error) { pathString, err := windows.UTF16PtrFromString(path) if err != nil { @@ -52,3 +43,69 @@ func IsPathValid(path string) (bool, error) { klog.V(6).Infof("Path %s attribute: %d", path, attrs) return attrs != windows.INVALID_FILE_ATTRIBUTES, nil } + +// IsMountedFolder checks whether the `path` is a mounted folder. +func IsMountedFolder(path string) (bool, error) { + // https://learn.microsoft.com/en-us/windows/win32/fileio/determining-whether-a-directory-is-a-volume-mount-point + utf16Path, _ := windows.UTF16PtrFromString(path) + attrs, err := windows.GetFileAttributes(utf16Path) + if err != nil { + return false, err + } + + if (attrs & windows.FILE_ATTRIBUTE_REPARSE_POINT) == 0 { + return false, nil + } + + var findData windows.Win32finddata + findHandle, err := windows.FindFirstFile(utf16Path, &findData) + if err != nil && !errors.Is(err, windows.ERROR_NO_MORE_FILES) { + return false, err + } + + for err == nil { + if findData.Reserved0&windows.IO_REPARSE_TAG_MOUNT_POINT != 0 { + return true, nil + } + + err = windows.FindNextFile(findHandle, &findData) + if err != nil && !errors.Is(err, windows.ERROR_NO_MORE_FILES) { + return false, err + } + } + + return false, nil +} + +func IsPathSymlink(path string) (bool, error) { + fi, err := os.Lstat(path) + if err != nil { + return false, err + } + // for windows NTFS, check if the path is symlink instead of directory. + isSymlink := fi.Mode()&os.ModeSymlink != 0 || fi.Mode()&os.ModeIrregular != 0 + return isSymlink, nil +} + +func CreateSymlink(link, target string, isDir bool) error { + linkPtr, err := windows.UTF16PtrFromString(link) + if err != nil { + return err + } + targetPtr, err := windows.UTF16PtrFromString(target) + if err != nil { + return err + } + + var flags uint32 + if isDir { + flags = windows.SYMBOLIC_LINK_FLAG_DIRECTORY + } + + err = windows.CreateSymbolicLink( + linkPtr, + targetPtr, + flags, + ) + return err +} From db83ecfb41baceb87899d2d280edc670855fbed8 Mon Sep 17 00:00:00 2001 From: Zhongcheng Lao Date: Tue, 17 Jun 2025 15:41:05 +0800 Subject: [PATCH 2/2] Ensure COM threading apartment for API calls --- pkg/cim/wmi.go | 39 +++++ pkg/os/disk/api.go | 379 ++++++++++++++++++++++++------------------- pkg/os/iscsi/api.go | 288 +++++++++++++++++--------------- pkg/os/smb/api.go | 53 +++--- pkg/os/system/api.go | 131 ++++++++------- pkg/os/volume/api.go | 297 ++++++++++++++++++--------------- 6 files changed, 665 insertions(+), 522 deletions(-) diff --git a/pkg/cim/wmi.go b/pkg/cim/wmi.go index ba75f747..ec9c8f08 100644 --- a/pkg/cim/wmi.go +++ b/pkg/cim/wmi.go @@ -4,13 +4,16 @@ package cim import ( + "errors" "fmt" + "runtime" "github.com/go-ole/go-ole" "github.com/go-ole/go-ole/oleutil" "github.com/microsoft/wmi/pkg/base/query" wmierrors "github.com/microsoft/wmi/pkg/errors" cim "github.com/microsoft/wmi/pkg/wmiinstance" + "golang.org/x/sys/windows" "k8s.io/klog/v2" ) @@ -248,3 +251,39 @@ func IgnoreNotFound(err error) error { } return err } + +// WithCOMThread runs the given function `fn` on a locked OS thread +// with COM initialized using COINIT_MULTITHREADED. +// +// This is necessary for using COM/OLE APIs directly (e.g., via go-ole), +// because COM requires that initialization and usage occur on the same thread. +// +// It performs the following steps: +// - Locks the current goroutine to its OS thread +// - Calls ole.CoInitializeEx with COINIT_MULTITHREADED +// - Executes the user-provided function +// - Uninitializes COM +// - Unlocks the thread +// +// If COM initialization fails, or if the user's function returns an error, +// that error is returned by WithCOMThread. +func WithCOMThread(fn func() error) error { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + if err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED); err != nil { + var oleError *ole.OleError + if errors.As(err, &oleError) && oleError != nil && oleError.Code() == uintptr(windows.S_FALSE) { + klog.V(10).Infof("COM library has been already initialized for the calling thread, proceeding to the function with no error") + err = nil + } + if err != nil { + return err + } + } else { + klog.V(10).Infof("COM library is initialized for the calling thread") + } + defer ole.CoUninitialize() + + return fn() +} diff --git a/pkg/os/disk/api.go b/pkg/os/disk/api.go index b366b977..5b46992c 100644 --- a/pkg/os/disk/api.go +++ b/pkg/os/disk/api.go @@ -64,144 +64,162 @@ func New() DiskAPI { // ListDiskLocations - constructs a map with the disk number as the key and the DiskLocation structure // as the value. The DiskLocation struct has various fields like the Adapter, Bus, Target and LUNID. func (imp DiskAPI) ListDiskLocations() (map[uint32]shared.DiskLocation, error) { - // "location": "PCI Slot 3 : Adapter 0 : Port 0 : Target 1 : LUN 0" - disks, err := cim.ListDisks(cim.DiskSelectorListForDiskNumberAndLocation) - if err != nil { - return nil, fmt.Errorf("could not query disk locations") - } - m := make(map[uint32]shared.DiskLocation) - for _, disk := range disks { - num, err := cim.GetDiskNumber(disk) + err := cim.WithCOMThread(func() error { + // "location": "PCI Slot 3 : Adapter 0 : Port 0 : Target 1 : LUN 0" + disks, err := cim.ListDisks(cim.DiskSelectorListForDiskNumberAndLocation) if err != nil { - return m, fmt.Errorf("failed to query disk number: %v, %w", disk, err) + return fmt.Errorf("could not query disk locations") } - location, err := cim.GetDiskLocation(disk) - if err != nil { - return m, fmt.Errorf("failed to query disk location: %v, %w", disk, err) - } + for _, disk := range disks { + num, err := cim.GetDiskNumber(disk) + if err != nil { + return fmt.Errorf("failed to query disk number: %v, %w", disk, err) + } + + location, err := cim.GetDiskLocation(disk) + if err != nil { + return fmt.Errorf("failed to query disk location: %v, %w", disk, err) + } - found := false - s := strings.Split(location, ":") - if len(s) >= 5 { - var d shared.DiskLocation - for _, item := range s { - item = strings.TrimSpace(item) - itemSplit := strings.Split(item, " ") - if len(itemSplit) == 2 { - found = true - switch strings.TrimSpace(itemSplit[0]) { - case "Adapter": - d.Adapter = strings.TrimSpace(itemSplit[1]) - case "Target": - d.Target = strings.TrimSpace(itemSplit[1]) - case "LUN": - d.LUNID = strings.TrimSpace(itemSplit[1]) - default: - klog.Warningf("Got unknown field : %s=%s", itemSplit[0], itemSplit[1]) + found := false + s := strings.Split(location, ":") + if len(s) >= 5 { + var d shared.DiskLocation + for _, item := range s { + item = strings.TrimSpace(item) + itemSplit := strings.Split(item, " ") + if len(itemSplit) == 2 { + found = true + switch strings.TrimSpace(itemSplit[0]) { + case "Adapter": + d.Adapter = strings.TrimSpace(itemSplit[1]) + case "Target": + d.Target = strings.TrimSpace(itemSplit[1]) + case "LUN": + d.LUNID = strings.TrimSpace(itemSplit[1]) + default: + klog.Warningf("Got unknown field : %s=%s", itemSplit[0], itemSplit[1]) + } } } - } - if found { - m[num] = d + if found { + m[num] = d + } } } - } - return m, nil + return nil + }) + return m, err } func (imp DiskAPI) Rescan() error { - result, err := cim.RescanDisks() - if err != nil { - return fmt.Errorf("error updating host storage cache output. result: %d, err: %v", result, err) - } - return nil + return cim.WithCOMThread(func() error { + result, err := cim.RescanDisks() + if err != nil { + return fmt.Errorf("error updating host storage cache output. result: %d, err: %v", result, err) + } + return nil + }) } func (imp DiskAPI) IsDiskInitialized(diskNumber uint32) (bool, error) { - disk, err := cim.QueryDiskByNumber(diskNumber, cim.DiskSelectorListForPartitionStyle) - if err != nil { - return false, fmt.Errorf("error checking initialized status of disk %d: %v", diskNumber, err) - } + var partitionStyle int32 + err := cim.WithCOMThread(func() error { + disk, err := cim.QueryDiskByNumber(diskNumber, cim.DiskSelectorListForPartitionStyle) + if err != nil { + return fmt.Errorf("error checking initialized status of disk %d: %v", diskNumber, err) + } - partitionStyle, err := cim.GetDiskPartitionStyle(disk) - if err != nil { - return false, fmt.Errorf("failed to query partition style of disk %d: %v", diskNumber, err) - } + partitionStyle, err = cim.GetDiskPartitionStyle(disk) + if err != nil { + return fmt.Errorf("failed to query partition style of disk %d: %v", diskNumber, err) + } - return partitionStyle != cim.PartitionStyleUnknown, nil + return nil + }) + return partitionStyle != cim.PartitionStyleUnknown, err } func (imp DiskAPI) InitializeDisk(diskNumber uint32) error { - disk, err := cim.QueryDiskByNumber(diskNumber, nil) - if err != nil { - return fmt.Errorf("failed to initializing disk %d. error: %w", diskNumber, err) - } + return cim.WithCOMThread(func() error { + disk, err := cim.QueryDiskByNumber(diskNumber, nil) + if err != nil { + return fmt.Errorf("failed to initializing disk %d. error: %w", diskNumber, err) + } - result, err := cim.InitializeDisk(disk, cim.PartitionStyleGPT) - if result != 0 || err != nil { - return fmt.Errorf("failed to initializing disk %d: result %d, error: %w", diskNumber, result, err) - } + result, err := cim.InitializeDisk(disk, cim.PartitionStyleGPT) + if result != 0 || err != nil { + return fmt.Errorf("failed to initializing disk %d: result %d, error: %w", diskNumber, result, err) + } - return nil + return nil + }) } func (imp DiskAPI) BasicPartitionsExist(diskNumber uint32) (bool, error) { - partitions, err := cim.ListPartitionsWithFilters(nil, cim.FilterForPartitionOnDisk(diskNumber), cim.FilterForPartitionsOfTypeNormal()) - if cim.IgnoreNotFound(err) != nil { - return false, fmt.Errorf("error checking presence of partitions on disk %d:, %v", diskNumber, err) - } + var exist bool + err := cim.WithCOMThread(func() error { + partitions, err := cim.ListPartitionsWithFilters(nil, cim.FilterForPartitionOnDisk(diskNumber), cim.FilterForPartitionsOfTypeNormal()) + if cim.IgnoreNotFound(err) != nil { + return fmt.Errorf("error checking presence of partitions on disk %d:, %v", diskNumber, err) + } - return len(partitions) > 0, nil + exist = len(partitions) > 0 + return nil + }) + return exist, err } func (imp DiskAPI) CreateBasicPartition(diskNumber uint32) error { - disk, err := cim.QueryDiskByNumber(diskNumber, nil) - if err != nil { - return err - } + return cim.WithCOMThread(func() error { + disk, err := cim.QueryDiskByNumber(diskNumber, nil) + if err != nil { + return err + } - result, err := cim.CreatePartition( - disk, - nil, // Size - true, // UseMaximumSize - nil, // Offset - nil, // Alignment - nil, // DriveLetter - false, // AssignDriveLetter - nil, // MbrType, - cim.GPTPartitionTypeBasicData, // GPT Type - false, // IsHidden - false, // IsActive, - ) - if (result != 0 && result != cim.ErrorCodeCreatePartitionAccessPathAlreadyInUse) || err != nil { - return fmt.Errorf("error creating partition on disk %d. result: %d, err: %v", diskNumber, result, err) - } + result, err := cim.CreatePartition( + disk, + nil, // Size + true, // UseMaximumSize + nil, // Offset + nil, // Alignment + nil, // DriveLetter + false, // AssignDriveLetter + nil, // MbrType, + cim.GPTPartitionTypeBasicData, // GPT Type + false, // IsHidden + false, // IsActive, + ) + if (result != 0 && result != cim.ErrorCodeCreatePartitionAccessPathAlreadyInUse) || err != nil { + return fmt.Errorf("error creating partition on disk %d. result: %d, err: %v", diskNumber, result, err) + } - result, _, err = cim.RefreshDisk(disk) - if result != 0 || err != nil { - return fmt.Errorf("error rescan disk (%d). result %d, error: %v", diskNumber, result, err) - } + result, _, err = cim.RefreshDisk(disk) + if result != 0 || err != nil { + return fmt.Errorf("error rescan disk (%d). result %d, error: %v", diskNumber, result, err) + } - partitions, err := cim.ListPartitionsWithFilters(nil, cim.FilterForPartitionOnDisk(diskNumber), cim.FilterForPartitionsOfTypeNormal()) - if err != nil { - return fmt.Errorf("error query basic partition on disk %d:, %v", diskNumber, err) - } + partitions, err := cim.ListPartitionsWithFilters(nil, cim.FilterForPartitionOnDisk(diskNumber), cim.FilterForPartitionsOfTypeNormal()) + if err != nil { + return fmt.Errorf("error query basic partition on disk %d:, %v", diskNumber, err) + } - if len(partitions) == 0 { - return fmt.Errorf("failed to create basic partition on disk %d:, %v", diskNumber, err) - } + if len(partitions) == 0 { + return fmt.Errorf("failed to create basic partition on disk %d:, %v", diskNumber, err) + } - partition := partitions[0] - result, status, err := cim.SetPartitionState(partition, true) - if result != 0 || err != nil { - return fmt.Errorf("error bring partition %v on disk %d online. result: %d, status %s, err: %v", partition, diskNumber, result, status, err) - } + partition := partitions[0] + result, status, err := cim.SetPartitionState(partition, true) + if result != 0 || err != nil { + return fmt.Errorf("error bring partition %v on disk %d online. result: %d, status %s, err: %v", partition, diskNumber, result, status, err) + } - return nil + return nil + }) } func (imp DiskAPI) GetDiskNumberByName(page83ID string) (uint32, error) { @@ -261,28 +279,33 @@ func (imp DiskAPI) GetDiskPage83ID(disk syscall.Handle) (string, error) { } func (imp DiskAPI) GetDiskNumberWithID(page83ID string) (uint32, error) { - disks, err := cim.ListDisks(cim.DiskSelectorListForPathAndSerialNumber) - if err != nil { - return 0, err - } - - for _, disk := range disks { - path, err := cim.GetDiskPath(disk) + var diskNumberResult uint32 + err := cim.WithCOMThread(func() error { + disks, err := cim.ListDisks(cim.DiskSelectorListForPathAndSerialNumber) if err != nil { - return 0, fmt.Errorf("failed to query disk path: %v, %w", disk, err) + return err } - diskNumber, diskPage83ID, err := imp.GetDiskNumberAndPage83ID(path) - if err != nil { - return 0, err - } + for _, disk := range disks { + path, err := cim.GetDiskPath(disk) + if err != nil { + return fmt.Errorf("failed to query disk path: %v, %w", disk, err) + } - if diskPage83ID == page83ID { - return diskNumber, nil + diskNumber, diskPage83ID, err := imp.GetDiskNumberAndPage83ID(path) + if err != nil { + return err + } + + if diskPage83ID == page83ID { + diskNumberResult = diskNumber + return nil + } } - } - return 0, fmt.Errorf("could not find disk with Page83 ID %s", page83ID) + return fmt.Errorf("could not find disk with Page83 ID %s", page83ID) + }) + return diskNumberResult, err } func (imp DiskAPI) GetDiskNumberAndPage83ID(path string) (uint32, string, error) { @@ -308,84 +331,98 @@ func (imp DiskAPI) GetDiskNumberAndPage83ID(path string) (uint32, string, error) // ListDiskIDs - constructs a map with the disk number as the key and the DiskID structure // as the value. The DiskID struct has a field for the page83 ID. func (imp DiskAPI) ListDiskIDs() (map[uint32]shared.DiskIDs, error) { - disks, err := cim.ListDisks(cim.DiskSelectorListForPathAndSerialNumber) - if err != nil { - return nil, err - } - m := make(map[uint32]shared.DiskIDs) - for _, disk := range disks { - path, err := cim.GetDiskPath(disk) + err := cim.WithCOMThread(func() error { + disks, err := cim.ListDisks(cim.DiskSelectorListForPathAndSerialNumber) if err != nil { - return m, fmt.Errorf("failed to query disk path: %v, %w", disk, err) + return err } - sn, err := cim.GetDiskSerialNumber(disk) - if err != nil { - return m, fmt.Errorf("failed to query disk serial number: %v, %w", disk, err) - } + for _, disk := range disks { + path, err := cim.GetDiskPath(disk) + if err != nil { + return fmt.Errorf("failed to query disk path: %v, %w", disk, err) + } - diskNumber, page83, err := imp.GetDiskNumberAndPage83ID(path) - if err != nil { - return m, err - } + sn, err := cim.GetDiskSerialNumber(disk) + if err != nil { + return fmt.Errorf("failed to query disk serial number: %v, %w", disk, err) + } - m[diskNumber] = shared.DiskIDs{ - Page83: page83, - SerialNumber: sn, + diskNumber, page83, err := imp.GetDiskNumberAndPage83ID(path) + if err != nil { + return err + } + + m[diskNumber] = shared.DiskIDs{ + Page83: page83, + SerialNumber: sn, + } } - } - return m, nil + + return nil + }) + return m, err } func (imp DiskAPI) GetDiskStats(diskNumber uint32) (int64, error) { // TODO: change to uint64 as it does not make sense to use int64 for size - disk, err := cim.QueryDiskByNumber(diskNumber, cim.DiskSelectorListForSize) - if err != nil { - return -1, err - } + size := int64(-1) + err := cim.WithCOMThread(func() error { + disk, err := cim.QueryDiskByNumber(diskNumber, cim.DiskSelectorListForSize) + if err != nil { + return err + } - size, err := cim.GetDiskSize(disk) - if err != nil { - return -1, fmt.Errorf("failed to query size of disk %d. %v", diskNumber, err) - } + size, err = cim.GetDiskSize(disk) + if err != nil { + return fmt.Errorf("failed to query size of disk %d. %v", diskNumber, err) + } + return nil + }) return size, err } func (imp DiskAPI) SetDiskState(diskNumber uint32, isOnline bool) error { - disk, err := cim.QueryDiskByNumber(diskNumber, cim.DiskSelectorListForIsOffline) - if err != nil { - return err - } + return cim.WithCOMThread(func() error { + disk, err := cim.QueryDiskByNumber(diskNumber, cim.DiskSelectorListForIsOffline) + if err != nil { + return err + } - isOffline, err := cim.IsDiskOffline(disk) - if err != nil { - return fmt.Errorf("error setting disk %d attach state. error: %v", diskNumber, err) - } + isOffline, err := cim.IsDiskOffline(disk) + if err != nil { + return fmt.Errorf("error setting disk %d attach state. error: %v", diskNumber, err) + } - if isOnline == !isOffline { - return nil - } + if isOnline == !isOffline { + return nil + } - result, _, err := cim.SetDiskState(disk, isOnline) - if result != 0 || err != nil { - return fmt.Errorf("setting disk %d attach state (isOnline: %v): result %d, error: %w", diskNumber, isOnline, result, err) - } + result, _, err := cim.SetDiskState(disk, isOnline) + if result != 0 || err != nil { + return fmt.Errorf("setting disk %d attach state (isOnline: %v): result %d, error: %w", diskNumber, isOnline, result, err) + } - return nil + return nil + }) } func (imp DiskAPI) GetDiskState(diskNumber uint32) (bool, error) { - disk, err := cim.QueryDiskByNumber(diskNumber, cim.DiskSelectorListForIsOffline) - if err != nil { - return false, err - } + var isOffline bool + err := cim.WithCOMThread(func() error { + disk, err := cim.QueryDiskByNumber(diskNumber, cim.DiskSelectorListForIsOffline) + if err != nil { + return err + } - isOffline, err := cim.IsDiskOffline(disk) - if err != nil { - return false, fmt.Errorf("error parsing disk %d state. error: %v", diskNumber, err) - } + isOffline, err = cim.IsDiskOffline(disk) + if err != nil { + return fmt.Errorf("error parsing disk %d state. error: %v", diskNumber, err) + } - return !isOffline, nil + return nil + }) + return !isOffline, err } diff --git a/pkg/os/iscsi/api.go b/pkg/os/iscsi/api.go index 1dd385db..9508ac9c 100644 --- a/pkg/os/iscsi/api.go +++ b/pkg/os/iscsi/api.go @@ -21,190 +21,210 @@ func New() APIImplementor { } func (APIImplementor) AddTargetPortal(portal *TargetPortal) error { - existing, err := cim.QueryISCSITargetPortal(portal.Address, portal.Port, nil) - if cim.IgnoreNotFound(err) != nil { - return fmt.Errorf("error query target portal at (%s:%d). err: %v", portal.Address, portal.Port, err) - } + return cim.WithCOMThread(func() error { + existing, err := cim.QueryISCSITargetPortal(portal.Address, portal.Port, nil) + if cim.IgnoreNotFound(err) != nil { + return fmt.Errorf("error query target portal at (%s:%d). err: %v", portal.Address, portal.Port, err) + } - if existing != nil { - klog.V(2).Infof("target portal at (%s:%d) already exists", portal.Address, portal.Port) - return nil - } + if existing != nil { + klog.V(2).Infof("target portal at (%s:%d) already exists", portal.Address, portal.Port) + return nil + } - _, err = cim.NewISCSITargetPortal(portal.Address, portal.Port, nil, nil, nil, nil) - if err != nil { - return fmt.Errorf("error adding target portal at (%s:%d). err: %v", portal.Address, portal.Port, err) - } + _, err = cim.NewISCSITargetPortal(portal.Address, portal.Port, nil, nil, nil, nil) + if err != nil { + return fmt.Errorf("error adding target portal at (%s:%d). err: %v", portal.Address, portal.Port, err) + } - return nil + return nil + }) } func (APIImplementor) DiscoverTargetPortal(portal *TargetPortal) ([]string, error) { - targets, err := cim.ListISCSITargetsByTargetPortalAddressAndPort(portal.Address, portal.Port, nil) - if err != nil { - return nil, fmt.Errorf("error list targets by target portal at (%s:%d). err: %v", portal.Address, portal.Port, err) - } - var iqns []string - for _, target := range targets { - iqn, err := cim.GetISCSITargetNodeAddress(target) + err := cim.WithCOMThread(func() error { + targets, err := cim.ListISCSITargetsByTargetPortalAddressAndPort(portal.Address, portal.Port, nil) if err != nil { - return nil, fmt.Errorf("failed parsing node address of target %v to target portal at (%s:%d). err: %w", target, portal.Address, portal.Port, err) + return fmt.Errorf("error list targets by target portal at (%s:%d). err: %v", portal.Address, portal.Port, err) } - iqns = append(iqns, iqn) - } + for _, target := range targets { + iqn, err := cim.GetISCSITargetNodeAddress(target) + if err != nil { + return fmt.Errorf("failed parsing node address of target %v to target portal at (%s:%d). err: %w", target, portal.Address, portal.Port, err) + } + + iqns = append(iqns, iqn) + } - return iqns, nil + return nil + }) + return iqns, err } func (APIImplementor) ListTargetPortals() ([]TargetPortal, error) { - instances, err := cim.ListISCSITargetPortals(cim.ISCSITargetPortalDefaultSelectorList) - if err != nil { - return nil, fmt.Errorf("error list target portals. err: %v", err) - } - var portals []TargetPortal - for _, instance := range instances { - address, port, err := cim.ParseISCSITargetPortal(instance) + err := cim.WithCOMThread(func() error { + instances, err := cim.ListISCSITargetPortals(cim.ISCSITargetPortalDefaultSelectorList) if err != nil { - return nil, fmt.Errorf("failed parsing target portal %v. err: %w", instance, err) + return fmt.Errorf("error list target portals. err: %v", err) } - portals = append(portals, TargetPortal{ - Address: address, - Port: port, - }) - } + for _, instance := range instances { + address, port, err := cim.ParseISCSITargetPortal(instance) + if err != nil { + return fmt.Errorf("failed parsing target portal %v. err: %w", instance, err) + } + + portals = append(portals, TargetPortal{ + Address: address, + Port: port, + }) + } - return portals, nil + return nil + }) + return portals, err } func (APIImplementor) RemoveTargetPortal(portal *TargetPortal) error { - instance, err := cim.QueryISCSITargetPortal(portal.Address, portal.Port, nil) - if err != nil { - return fmt.Errorf("error query target portal at (%s:%d). err: %v", portal.Address, portal.Port, err) - } + return cim.WithCOMThread(func() error { + instance, err := cim.QueryISCSITargetPortal(portal.Address, portal.Port, nil) + if err != nil { + return fmt.Errorf("error query target portal at (%s:%d). err: %v", portal.Address, portal.Port, err) + } - result, err := cim.RemoveISCSITargetPortal(instance) - if result != 0 || err != nil { - return fmt.Errorf("error removing target portal at (%s:%d). result: %d, err: %w", portal.Address, portal.Port, result, err) - } + result, err := cim.RemoveISCSITargetPortal(instance) + if result != 0 || err != nil { + return fmt.Errorf("error removing target portal at (%s:%d). result: %d, err: %w", portal.Address, portal.Port, result, err) + } - return nil + return nil + }) } func (APIImplementor) ConnectTarget(portal *TargetPortal, iqn string, authType string, chapUser string, chapSecret string) error { - target, err := cim.QueryISCSITarget(portal.Address, portal.Port, iqn) - if err != nil { - return fmt.Errorf("error query target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err) - } - - connected, err := cim.IsISCSITargetConnected(target) - if err != nil { - return fmt.Errorf("error query connected of target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err) - } - - if connected { - klog.V(2).Infof("target %s from target portal at (%s:%d) is connected.", iqn, portal.Address, portal.Port) - return nil - } + return cim.WithCOMThread(func() error { + target, err := cim.QueryISCSITarget(portal.Address, portal.Port, iqn) + if err != nil { + return fmt.Errorf("error query target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err) + } - targetAuthType := strings.ToUpper(strings.ReplaceAll(authType, "_", "")) + connected, err := cim.IsISCSITargetConnected(target) + if err != nil { + return fmt.Errorf("error query connected of target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err) + } - result, err := cim.ConnectISCSITarget(portal.Address, portal.Port, iqn, targetAuthType, &chapUser, &chapSecret) - if err != nil { - return fmt.Errorf("error connecting to target portal. result: %d, err: %w", result, err) - } + if connected { + klog.V(2).Infof("target %s from target portal at (%s:%d) is connected.", iqn, portal.Address, portal.Port) + return nil + } - return nil + targetAuthType := strings.ToUpper(strings.ReplaceAll(authType, "_", "")) + + result, err := cim.ConnectISCSITarget(portal.Address, portal.Port, iqn, targetAuthType, &chapUser, &chapSecret) + if err != nil { + return fmt.Errorf("error connecting to target portal. result: %d, err: %w", result, err) + } + + return nil + }) } func (APIImplementor) DisconnectTarget(portal *TargetPortal, iqn string) error { - target, err := cim.QueryISCSITarget(portal.Address, portal.Port, iqn) - if err != nil { - return fmt.Errorf("error query target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err) - } - - connected, err := cim.IsISCSITargetConnected(target) - if err != nil { - return fmt.Errorf("error query connected of target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err) - } - - if !connected { - klog.V(2).Infof("target %s from target portal at (%s:%d) is not connected.", iqn, portal.Address, portal.Port) - return nil - } + return cim.WithCOMThread(func() error { + target, err := cim.QueryISCSITarget(portal.Address, portal.Port, iqn) + if err != nil { + return fmt.Errorf("error query target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err) + } - // get session - session, err := cim.QueryISCSISessionByTarget(target) - if err != nil { - return fmt.Errorf("error query session of target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err) - } + connected, err := cim.IsISCSITargetConnected(target) + if err != nil { + return fmt.Errorf("error query connected of target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err) + } - sessionIdentifier, err := cim.GetISCSISessionIdentifier(session) - if err != nil { - return fmt.Errorf("error query session identifier of target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err) - } + if !connected { + klog.V(2).Infof("target %s from target portal at (%s:%d) is not connected.", iqn, portal.Address, portal.Port) + return nil + } - persistent, err := cim.IsISCSISessionPersistent(session) - if err != nil { - return fmt.Errorf("error query session persistency of target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err) - } + // get session + session, err := cim.QueryISCSISessionByTarget(target) + if err != nil { + return fmt.Errorf("error query session of target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err) + } - if persistent { - result, err := cim.UnregisterISCSISession(session) + sessionIdentifier, err := cim.GetISCSISessionIdentifier(session) if err != nil { - return fmt.Errorf("error unregister session on target %s from target portal at (%s:%d). result: %d, err: %w", iqn, portal.Address, portal.Port, result, err) + return fmt.Errorf("error query session identifier of target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err) } - } - result, err := cim.DisconnectISCSITarget(target, sessionIdentifier) - if err != nil { - return fmt.Errorf("error disconnecting target %s from target portal at (%s:%d). result: %d, err: %w", iqn, portal.Address, portal.Port, result, err) - } + persistent, err := cim.IsISCSISessionPersistent(session) + if err != nil { + return fmt.Errorf("error query session persistency of target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err) + } - return nil + if persistent { + result, err := cim.UnregisterISCSISession(session) + if err != nil { + return fmt.Errorf("error unregister session on target %s from target portal at (%s:%d). result: %d, err: %w", iqn, portal.Address, portal.Port, result, err) + } + } + + result, err := cim.DisconnectISCSITarget(target, sessionIdentifier) + if err != nil { + return fmt.Errorf("error disconnecting target %s from target portal at (%s:%d). result: %d, err: %w", iqn, portal.Address, portal.Port, result, err) + } + + return nil + }) } func (APIImplementor) GetTargetDisks(portal *TargetPortal, iqn string) ([]string, error) { - target, err := cim.QueryISCSITarget(portal.Address, portal.Port, iqn) - if err != nil { - return nil, fmt.Errorf("error query target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err) - } - - connected, err := cim.IsISCSITargetConnected(target) - if err != nil { - return nil, fmt.Errorf("error query connected of target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err) - } - - if !connected { - klog.V(2).Infof("target %s from target portal at (%s:%d) is not connected.", iqn, portal.Address, portal.Port) - return nil, nil - } - - disks, err := cim.ListDisksByTarget(target) - if err != nil { - return nil, fmt.Errorf("error getting target disks on target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err) - } - var ids []string - for _, disk := range disks { - number, err := cim.GetDiskNumber(disk) + err := cim.WithCOMThread(func() error { + target, err := cim.QueryISCSITarget(portal.Address, portal.Port, iqn) if err != nil { - return nil, fmt.Errorf("error getting number of disk %v on target %s from target portal at (%s:%d). err: %w", disk, iqn, portal.Address, portal.Port, err) + return fmt.Errorf("error query target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err) } - ids = append(ids, strconv.Itoa(int(number))) - } - return ids, nil + connected, err := cim.IsISCSITargetConnected(target) + if err != nil { + return fmt.Errorf("error query connected of target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err) + } + + if !connected { + klog.V(2).Infof("target %s from target portal at (%s:%d) is not connected.", iqn, portal.Address, portal.Port) + return nil + } + + disks, err := cim.ListDisksByTarget(target) + if err != nil { + return fmt.Errorf("error getting target disks on target %s from target portal at (%s:%d). err: %w", iqn, portal.Address, portal.Port, err) + } + + for _, disk := range disks { + number, err := cim.GetDiskNumber(disk) + if err != nil { + return fmt.Errorf("error getting number of disk %v on target %s from target portal at (%s:%d). err: %w", disk, iqn, portal.Address, portal.Port, err) + } + + ids = append(ids, strconv.Itoa(int(number))) + } + + return nil + }) + return ids, err } func (APIImplementor) SetMutualChapSecret(mutualChapSecret string) error { - result, err := cim.SetISCSISessionChapSecret(mutualChapSecret) - if err != nil { - return fmt.Errorf("error setting mutual chap secret. result: %d, err: %v", result, err) - } + return cim.WithCOMThread(func() error { + result, err := cim.SetISCSISessionChapSecret(mutualChapSecret) + if err != nil { + return fmt.Errorf("error setting mutual chap secret. result: %d, err: %v", result, err) + } - return nil + return nil + }) } diff --git a/pkg/os/smb/api.go b/pkg/os/smb/api.go index f0d28da4..9e60c9dd 100644 --- a/pkg/os/smb/api.go +++ b/pkg/os/smb/api.go @@ -28,17 +28,22 @@ func New(requirePrivacy bool) *SmbAPI { } func (*SmbAPI) IsSmbMapped(remotePath string) (bool, error) { - inst, err := cim.QuerySmbGlobalMappingByRemotePath(remotePath) - if err != nil { - return false, cim.IgnoreNotFound(err) - } - - status, err := cim.GetSmbGlobalMappingStatus(inst) - if err != nil { - return false, err - } - - return status == cim.SmbMappingStatusOK, nil + var isMapped bool + err := cim.WithCOMThread(func() error { + inst, err := cim.QuerySmbGlobalMappingByRemotePath(remotePath) + if err != nil { + return err + } + + status, err := cim.GetSmbGlobalMappingStatus(inst) + if err != nil { + return err + } + + isMapped = status == cim.SmbMappingStatusOK + return nil + }) + return isMapped, cim.IgnoreNotFound(err) } // NewSmbLink - creates a directory symbolic link to the remote share. @@ -62,19 +67,21 @@ func (*SmbAPI) NewSmbLink(remotePath, localPath string) error { } func (api *SmbAPI) NewSmbGlobalMapping(remotePath, username, password string) error { - result, err := cim.NewSmbGlobalMapping(remotePath, username, password, api.RequirePrivacy) - if err != nil { - return fmt.Errorf("NewSmbGlobalMapping failed. result: %d, err: %v", result, err) - } - - return nil + return cim.WithCOMThread(func() error { + result, err := cim.NewSmbGlobalMapping(remotePath, username, password, api.RequirePrivacy) + if err != nil { + return fmt.Errorf("NewSmbGlobalMapping failed. result: %d, err: %v", result, err) + } + return nil + }) } func (*SmbAPI) RemoveSmbGlobalMapping(remotePath string) error { - err := cim.RemoveSmbGlobalMappingByRemotePath(remotePath) - if err != nil { - return fmt.Errorf("error remove smb mapping '%s'. err: %v", remotePath, err) - } - - return nil + return cim.WithCOMThread(func() error { + err := cim.RemoveSmbGlobalMappingByRemotePath(remotePath) + if err != nil { + return fmt.Errorf("error remove smb mapping '%s'. err: %v", remotePath, err) + } + return nil + }) } diff --git a/pkg/os/system/api.go b/pkg/os/system/api.go index 2a8ddaf8..d209b8e5 100644 --- a/pkg/os/system/api.go +++ b/pkg/os/system/api.go @@ -107,45 +107,54 @@ func New() APIImplementor { } func (APIImplementor) GetBIOSSerialNumber() (string, error) { - bios, err := cim.QueryBIOSElement(cim.BIOSSelectorList) - if err != nil { - return "", fmt.Errorf("failed to get BIOS element: %w", err) - } + var sn string + err := cim.WithCOMThread(func() error { + bios, err := cim.QueryBIOSElement(cim.BIOSSelectorList) + if err != nil { + return fmt.Errorf("failed to get BIOS element: %w", err) + } - sn, err := cim.GetBIOSSerialNumber(bios) - if err != nil { - return "", fmt.Errorf("failed to get BIOS serial number property: %w", err) - } + sn, err = cim.GetBIOSSerialNumber(bios) + if err != nil { + return fmt.Errorf("failed to get BIOS serial number property: %w", err) + } - return sn, nil + return nil + }) + return sn, err } func (impl APIImplementor) GetService(name string) (*ServiceInfo, error) { - service, err := impl.serviceFactory.GetService(name) - if err != nil { - return nil, fmt.Errorf("failed to get service %s. error: %w", name, err) - } + var serviceInfo *ServiceInfo + err := cim.WithCOMThread(func() error { + service, err := impl.serviceFactory.GetService(name) + if err != nil { + return fmt.Errorf("failed to get service %s. error: %w", name, err) + } - displayName, err := cim.GetServiceDisplayName(service) - if err != nil { - return nil, fmt.Errorf("failed to get displayName property of service %s: %w", name, err) - } + displayName, err := cim.GetServiceDisplayName(service) + if err != nil { + return fmt.Errorf("failed to get displayName property of service %s: %w", name, err) + } - state, err := cim.GetServiceState(service) - if err != nil { - return nil, fmt.Errorf("failed to get state property of service %s: %w", name, err) - } + state, err := cim.GetServiceState(service) + if err != nil { + return fmt.Errorf("failed to get state property of service %s: %w", name, err) + } - startMode, err := cim.GetServiceStartMode(service) - if err != nil { - return nil, fmt.Errorf("failed to get startMode property of service %s: %w", name, err) - } + startMode, err := cim.GetServiceStartMode(service) + if err != nil { + return fmt.Errorf("failed to get startMode property of service %s: %w", name, err) + } - return &ServiceInfo{ - DisplayName: displayName, - StartType: serviceStartModeToStartType(startMode), - Status: serviceState(state), - }, nil + serviceInfo = &ServiceInfo{ + DisplayName: displayName, + StartType: serviceStartModeToStartType(startMode), + Status: serviceState(state), + } + return nil + }) + return serviceInfo, err } func (impl APIImplementor) StartService(name string) error { @@ -171,21 +180,23 @@ func (impl APIImplementor) StartService(name string) error { return state == serviceStateRunning, newState, err } - service, err := impl.serviceFactory.GetService(name) - if err != nil { - return fmt.Errorf("failed to get service %s. error: %w", name, err) - } + return cim.WithCOMThread(func() error { + service, err := impl.serviceFactory.GetService(name) + if err != nil { + return fmt.Errorf("failed to get service %s. error: %w", name, err) + } - state, err := impl.serviceManager.WaitUntilServiceState(service, startService, serviceRunningCheck, serviceStateCheckInternal, serviceStateCheckTimeout) - if err != nil && !errors.Is(err, errTimedOut) { - return fmt.Errorf("failed to wait for service %s state change. error: %w", name, err) - } + state, err := impl.serviceManager.WaitUntilServiceState(service, startService, serviceRunningCheck, serviceStateCheckInternal, serviceStateCheckTimeout) + if err != nil && !errors.Is(err, errTimedOut) { + return fmt.Errorf("failed to wait for service %s state change. error: %w", name, err) + } - if state != serviceStateRunning { - return fmt.Errorf("timed out waiting for service %s to become running", name) - } + if state != serviceStateRunning { + return fmt.Errorf("timed out waiting for service %s to become running", name) + } - return nil + return nil + }) } func (impl APIImplementor) stopSingleService(name string) (bool, error) { @@ -234,27 +245,29 @@ func (impl APIImplementor) stopSingleService(name string) (bool, error) { } func (impl APIImplementor) StopService(name string, force bool) error { - dependentRunning, err := impl.stopSingleService(name) - if err == nil { - return nil - } - if !dependentRunning || !force { - return fmt.Errorf("failed to stop service %s. error: %w", name, err) - } - - serviceNames, err := impl.serviceManager.GetDependentsForService(name) - if err != nil { - return fmt.Errorf("error getting dependent services for service name %s", name) - } + return cim.WithCOMThread(func() error { + dependentRunning, err := impl.stopSingleService(name) + if err == nil { + return nil + } + if !dependentRunning || !force { + return fmt.Errorf("failed to stop service %s. error: %w", name, err) + } - for _, serviceName := range serviceNames { - _, err = impl.stopSingleService(serviceName) + serviceNames, err := impl.serviceManager.GetDependentsForService(name) if err != nil { - return fmt.Errorf("failed to stop service %s. error: %w", name, err) + return fmt.Errorf("error getting dependent services for service name %s", name) + } + + for _, serviceName := range serviceNames { + _, err = impl.stopSingleService(serviceName) + if err != nil { + return fmt.Errorf("failed to stop service %s. error: %w", name, err) + } } - } - return nil + return nil + }) } type ServiceManagerImpl struct { diff --git a/pkg/os/volume/api.go b/pkg/os/volume/api.go index ea667798..87a10f0f 100644 --- a/pkg/os/volume/api.go +++ b/pkg/os/volume/api.go @@ -69,50 +69,55 @@ func New() VolumeAPI { // ListVolumesOnDisk - returns back list of volumes(volumeIDs) in a disk and a partition. func (VolumeAPI) ListVolumesOnDisk(diskNumber uint32, partitionNumber uint32) (volumeIDs []string, err error) { - partitions, err := cim.ListPartitionsOnDisk(diskNumber, partitionNumber, cim.PartitionSelectorListObjectID) - if err != nil { - return nil, errors.Wrapf(err, "failed to list partition on disk %d", diskNumber) - } + err = cim.WithCOMThread(func() error { + partitions, err := cim.ListPartitionsOnDisk(diskNumber, partitionNumber, cim.PartitionSelectorListObjectID) + if err != nil { + return errors.Wrapf(err, "failed to list partition on disk %d", diskNumber) + } - volumes, err := cim.FindVolumesByPartition(partitions) - if cim.IgnoreNotFound(err) != nil { - return nil, errors.Wrapf(err, "failed to list volumes on disk %d", diskNumber) - } + volumes, err := cim.FindVolumesByPartition(partitions) + if cim.IgnoreNotFound(err) != nil { + return errors.Wrapf(err, "failed to list volumes on disk %d", diskNumber) + } - for _, volume := range volumes { - uniqueID, err := cim.GetVolumeUniqueID(volume) - if err != nil { - return nil, errors.Wrapf(err, "failed to get unique ID for volume %v", volume) + for _, volume := range volumes { + uniqueID, err := cim.GetVolumeUniqueID(volume) + if err != nil { + return errors.Wrapf(err, "failed to get unique ID for volume %v", volume) + } + volumeIDs = append(volumeIDs, uniqueID) } - volumeIDs = append(volumeIDs, uniqueID) - } - return volumeIDs, nil + return nil + }) + return } // FormatVolume - Formats a volume with the NTFS format. func (VolumeAPI) FormatVolume(volumeID string) (err error) { - volume, err := cim.QueryVolumeByUniqueID(volumeID, nil) - if err != nil { - return fmt.Errorf("error formatting volume (%s). error: %v", volumeID, err) - } + return cim.WithCOMThread(func() error { + volume, err := cim.QueryVolumeByUniqueID(volumeID, nil) + if err != nil { + return fmt.Errorf("error formatting volume (%s). error: %v", volumeID, err) + } - result, err := cim.FormatVolume(volume, - "NTFS", // Format, - "", // FileSystemLabel, - nil, // AllocationUnitSize, - false, // Full, - true, // Force - nil, // Compress, - nil, // ShortFileNameSupport, - nil, // SetIntegrityStreams, - nil, // UseLargeFRS, - nil, // DisableHeatGathering, - ) - if result != 0 || err != nil { - return fmt.Errorf("error formatting volume (%s). result: %d, error: %v", volumeID, result, err) - } - return nil + result, err := cim.FormatVolume(volume, + "NTFS", // Format, + "", // FileSystemLabel, + nil, // AllocationUnitSize, + false, // Full, + true, // Force + nil, // Compress, + nil, // ShortFileNameSupport, + nil, // SetIntegrityStreams, + nil, // UseLargeFRS, + nil, // DisableHeatGathering, + ) + if result != 0 || err != nil { + return fmt.Errorf("error formatting volume (%s). result: %d, error: %v", volumeID, result, err) + } + return nil + }) } // WriteVolumeCache - Writes the file system cache to disk with the given volume id @@ -122,17 +127,22 @@ func (VolumeAPI) WriteVolumeCache(volumeID string) (err error) { // IsVolumeFormatted - Check if the volume is formatted with the pre specified filesystem(typically ntfs). func (VolumeAPI) IsVolumeFormatted(volumeID string) (bool, error) { - volume, err := cim.QueryVolumeByUniqueID(volumeID, cim.VolumeSelectorListForFileSystemType) - if err != nil { - return false, fmt.Errorf("error checking if volume (%s) is formatted. error: %v", volumeID, err) - } + var formatted bool + err := cim.WithCOMThread(func() error { + volume, err := cim.QueryVolumeByUniqueID(volumeID, cim.VolumeSelectorListForFileSystemType) + if err != nil { + return fmt.Errorf("error checking if volume (%s) is formatted. error: %v", volumeID, err) + } - fsType, err := cim.GetVolumeFileSystemType(volume) - if err != nil { - return false, fmt.Errorf("failed to query volume file system type (%s): %w", volumeID, err) - } + fsType, err := cim.GetVolumeFileSystemType(volume) + if err != nil { + return fmt.Errorf("failed to query volume file system type (%s): %w", volumeID, err) + } - return fsType != cim.FileSystemUnknown, nil + formatted = fsType != cim.FileSystemUnknown + return nil + }) + return formatted, err } // MountVolume - mounts a volume to a path. This is done using Win32 API SetVolumeMountPoint for presenting the volume via a path. @@ -182,101 +192,112 @@ func (VolumeAPI) UnmountVolume(volumeID, path string) error { // ResizeVolume - resizes a volume with the given size, if size == 0 then max supported size is used func (VolumeAPI) ResizeVolume(volumeID string, size int64) error { - var err error - var finalSize int64 - part, err := cim.GetPartitionByVolumeUniqueID(volumeID) - if err != nil { - return err - } - - // If size is 0 then we will resize to the maximum size possible, otherwise just resize to size - if size == 0 { - var result int - var status string - result, _, finalSize, status, err = cim.GetPartitionSupportedSize(part) - if result != 0 || err != nil { - return fmt.Errorf("error getting sizeMin, sizeMax from volume (%s). result: %d, status: %s, error: %v", volumeID, result, status, err) + return cim.WithCOMThread(func() error { + var err error + var finalSize int64 + part, err := cim.GetPartitionByVolumeUniqueID(volumeID) + if err != nil { + return err } - } else { - finalSize = size - } + // If size is 0 then we will resize to the maximum size possible, otherwise just resize to size + if size == 0 { + var result int + var status string + result, _, finalSize, status, err = cim.GetPartitionSupportedSize(part) + if result != 0 || err != nil { + return fmt.Errorf("error getting sizeMin, sizeMax from volume (%s). result: %d, status: %s, error: %v", volumeID, result, status, err) + } - currentSize, err := cim.GetPartitionSize(part) - if err != nil { - return fmt.Errorf("error getting the current size of volume (%s) with error (%v)", volumeID, err) - } + } else { + finalSize = size + } - // only resize if finalSize - currentSize is greater than 100MB - if finalSize-currentSize < minimumResizeSize { - klog.V(2).Infof("minimum resize difference (100MB) not met, skipping resize. volumeID=%s currentSize=%d finalSize=%d", volumeID, currentSize, finalSize) - return nil - } + currentSize, err := cim.GetPartitionSize(part) + if err != nil { + return fmt.Errorf("error getting the current size of volume (%s) with error (%v)", volumeID, err) + } - //if the partition's size is already the size we want this is a noop, just return - if currentSize >= finalSize { - klog.V(2).Infof("Attempted to resize volume (%s) to a lower size, from currentBytes=%d wantedBytes=%d", volumeID, currentSize, finalSize) - return nil - } + // only resize if finalSize - currentSize is greater than 100MB + if finalSize-currentSize < minimumResizeSize { + klog.V(2).Infof("minimum resize difference (100MB) not met, skipping resize. volumeID=%s currentSize=%d finalSize=%d", volumeID, currentSize, finalSize) + return nil + } - result, _, err := cim.ResizePartition(part, finalSize) - if result != 0 || err != nil { - return fmt.Errorf("error resizing volume (%s). size:%v, finalSize %v, error: %v", volumeID, size, finalSize, err) - } + //if the partition's size is already the size we want this is a noop, just return + if currentSize >= finalSize { + klog.V(2).Infof("Attempted to resize volume (%s) to a lower size, from currentBytes=%d wantedBytes=%d", volumeID, currentSize, finalSize) + return nil + } - diskNumber, err := cim.GetPartitionDiskNumber(part) - if err != nil { - return fmt.Errorf("error parsing disk number of volume (%s). error: %v", volumeID, err) - } + result, _, err := cim.ResizePartition(part, finalSize) + if result != 0 || err != nil { + return fmt.Errorf("error resizing volume (%s). size:%v, finalSize %v, error: %v", volumeID, size, finalSize, err) + } - disk, err := cim.QueryDiskByNumber(diskNumber, nil) - if err != nil { - return fmt.Errorf("error query disk of volume (%s). error: %v", volumeID, err) - } + diskNumber, err := cim.GetPartitionDiskNumber(part) + if err != nil { + return fmt.Errorf("error parsing disk number of volume (%s). error: %v", volumeID, err) + } - result, _, err = cim.RefreshDisk(disk) - if result != 0 || err != nil { - return fmt.Errorf("error rescan disk (%d). result %d, error: %v", diskNumber, result, err) - } + disk, err := cim.QueryDiskByNumber(diskNumber, nil) + if err != nil { + return fmt.Errorf("error query disk of volume (%s). error: %v", volumeID, err) + } - return nil + result, _, err = cim.RefreshDisk(disk) + if result != 0 || err != nil { + return fmt.Errorf("error rescan disk (%d). result %d, error: %v", diskNumber, result, err) + } + + return nil + }) } // GetVolumeStats - retrieves the volume stats for a given volume -func (VolumeAPI) GetVolumeStats(volumeID string) (int64, int64, error) { - volume, err := cim.QueryVolumeByUniqueID(volumeID, cim.VolumeSelectorListForStats) - if err != nil { - return -1, -1, fmt.Errorf("error getting capacity and used size of volume (%s). error: %v", volumeID, err) - } +func (VolumeAPI) GetVolumeStats(volumeID string) (volumeSize, volumeUsedSize int64, err error) { + volumeSize = -1 + volumeUsedSize = -1 + err = cim.WithCOMThread(func() error { + volume, err := cim.QueryVolumeByUniqueID(volumeID, cim.VolumeSelectorListForStats) + if err != nil { + return fmt.Errorf("error getting capacity and used size of volume (%s). error: %v", volumeID, err) + } - volumeSize, err := cim.GetVolumeSize(volume) - if err != nil { - return -1, -1, fmt.Errorf("failed to query volume size (%s): %w", volumeID, err) - } + volumeSize, err = cim.GetVolumeSize(volume) + if err != nil { + return fmt.Errorf("failed to query volume size (%s): %w", volumeID, err) + } - volumeSizeRemaining, err := cim.GetVolumeSizeRemaining(volume) - if err != nil { - return -1, -1, fmt.Errorf("failed to query volume remaining size (%s): %w", volumeID, err) - } + volumeSizeRemaining, err := cim.GetVolumeSizeRemaining(volume) + if err != nil { + return fmt.Errorf("failed to query volume remaining size (%s): %w", volumeID, err) + } - volumeUsedSize := volumeSize - volumeSizeRemaining - return volumeSize, volumeUsedSize, nil + volumeUsedSize = volumeSize - volumeSizeRemaining + return nil + }) + return } // GetDiskNumberFromVolumeID - gets the disk number where the volume is. func (VolumeAPI) GetDiskNumberFromVolumeID(volumeID string) (uint32, error) { - // get the size and sizeRemaining for the volume - part, err := cim.GetPartitionByVolumeUniqueID(volumeID) - if err != nil { - return 0, err - } + var diskNumber uint32 + err := cim.WithCOMThread(func() error { + // get the size and sizeRemaining for the volume + part, err := cim.GetPartitionByVolumeUniqueID(volumeID) + if err != nil { + return err + } - diskNumber, err := cim.GetPartitionDiskNumber(part) - if err != nil { - return 0, fmt.Errorf("error query disk number of volume (%s). error: %v", volumeID, err) - } + diskNumber, err = cim.GetPartitionDiskNumber(part) + if err != nil { + return fmt.Errorf("error query disk number of volume (%s). error: %v", volumeID, err) + } - return diskNumber, nil + return nil + }) + return diskNumber, err } // GetVolumeIDFromTargetPath - gets the volume ID given a mount point, the function is recursive until it find a volume or errors out @@ -395,28 +416,34 @@ func getVolumeForDriveLetter(path string) (string, error) { return "", fmt.Errorf("the path %s is not a valid drive letter", path) } - volume, err := cim.GetVolumeByDriveLetter(path, cim.VolumeSelectorListUniqueID) - if err != nil { - return "", nil - } + var uniqueID string + err := cim.WithCOMThread(func() error { + volume, err := cim.GetVolumeByDriveLetter(path, cim.VolumeSelectorListUniqueID) + if err != nil { + return err + } - uniqueID, err := cim.GetVolumeUniqueID(volume) - if err != nil { - return "", fmt.Errorf("error query unique ID of volume (%v). error: %v", volume, err) - } + uniqueID, err = cim.GetVolumeUniqueID(volume) + if err != nil { + return fmt.Errorf("error query unique ID of volume (%v). error: %v", volume, err) + } - return uniqueID, nil + return nil + }) + return uniqueID, err } func writeCache(volumeID string) error { - volume, err := cim.QueryVolumeByUniqueID(volumeID, nil) - if err != nil { - return fmt.Errorf("error writing volume (%s) cache. error: %v", volumeID, err) - } + return cim.WithCOMThread(func() error { + volume, err := cim.QueryVolumeByUniqueID(volumeID, nil) + if err != nil { + return fmt.Errorf("error writing volume (%s) cache. error: %v", volumeID, err) + } - result, err := cim.FlushVolume(volume) - if result != 0 || err != nil { - return fmt.Errorf("error writing volume (%s) cache. result: %d, error: %v", volumeID, result, err) - } - return nil + result, err := cim.FlushVolume(volume) + if result != 0 || err != nil { + return fmt.Errorf("error writing volume (%s) cache. result: %d, error: %v", volumeID, result, err) + } + return nil + }) }