Skip to content

cleanup: Move WMI related functions to cim package #393

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions pkg/cim/disk.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,17 @@ const (
// GPTPartitionTypeMicrosoftReserved is the GUID for Microsoft Reserved Partition (MSR)
// Reserved by Windows for system use
GPTPartitionTypeMicrosoftReserved = "{e3c9e316-0b5c-4db8-817d-f92df00215ae}"

// ErrorCodeCreatePartitionAccessPathAlreadyInUse is the error code (42002) returned when the driver letter failed to assign after partition created
ErrorCodeCreatePartitionAccessPathAlreadyInUse = 42002
)

var (
DiskSelectorListForDiskNumberAndLocation = []string{"Number", "Location"}
DiskSelectorListForPartitionStyle = []string{"PartitionStyle"}
DiskSelectorListForPathAndSerialNumber = []string{"Path", "SerialNumber"}
DiskSelectorListForIsOffline = []string{"IsOffline"}
DiskSelectorListForSize = []string{"Size"}
)

// QueryDiskByNumber retrieves disk information for a specific disk identified by its number.
Expand Down Expand Up @@ -77,6 +88,60 @@ func ListDisks(selectorList []string) ([]*storage.MSFT_Disk, error) {
return disks, nil
}

// InitializeDisk initializes a RAW disk with a particular partition style.
//
// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/initialize-msft-disk
// for the WMI method definition.
func InitializeDisk(disk *storage.MSFT_Disk, partitionStyle int) (int, error) {
result, err := disk.InvokeMethodWithReturn("Initialize", int32(partitionStyle))
return int(result), err
}

// RefreshDisk Refreshes the cached disk layout information.
//
// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-disk-refresh
// for the WMI method definition.
func RefreshDisk(disk *storage.MSFT_Disk) (int, string, error) {
var status string
result, err := disk.InvokeMethodWithReturn("Refresh", &status)
return int(result), status, err
}

// CreatePartition creates a partition on a disk.
//
// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/createpartition-msft-disk
// for the WMI method definition.
func CreatePartition(disk *storage.MSFT_Disk, params ...interface{}) (int, error) {
result, err := disk.InvokeMethodWithReturn("CreatePartition", params...)
return int(result), err
}

// SetDiskState takes a disk online or offline.
//
// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-disk-online and
// https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-disk-offline
// for the WMI method definition.
func SetDiskState(disk *storage.MSFT_Disk, online bool) (int, string, error) {
method := "Offline"
if online {
method = "Online"
}

var status string
result, err := disk.InvokeMethodWithReturn(method, &status)
return int(result), status, err
}

// RescanDisks rescans all changes by updating the internal cache of software objects (that is, Disks, Partitions, Volumes)
// for the storage setting.
//
// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-storagesetting-updatehoststoragecache
// for the WMI method definition.
func RescanDisks() (int, error) {
result, _, err := InvokeCimMethod(WMINamespaceStorage, "MSFT_StorageSetting", "UpdateHostStorageCache", nil)
return result, err
}

// GetDiskNumber returns the number of a disk.
func GetDiskNumber(disk *storage.MSFT_Disk) (uint32, error) {
number, err := disk.GetProperty("Number")
Expand All @@ -85,3 +150,41 @@ func GetDiskNumber(disk *storage.MSFT_Disk) (uint32, error) {
}
return uint32(number.(int32)), err
}

// GetDiskLocation returns the location of a disk.
func GetDiskLocation(disk *storage.MSFT_Disk) (string, error) {
return disk.GetPropertyLocation()
}

// GetDiskPartitionStyle returns the partition style of a disk.
func GetDiskPartitionStyle(disk *storage.MSFT_Disk) (int32, error) {
retValue, err := disk.GetProperty("PartitionStyle")
if err != nil {
return 0, err
}
return retValue.(int32), err
}

// IsDiskOffline returns whether a disk is offline.
func IsDiskOffline(disk *storage.MSFT_Disk) (bool, error) {
return disk.GetPropertyIsOffline()
}

// GetDiskSize returns the size of a disk.
func GetDiskSize(disk *storage.MSFT_Disk) (int64, error) {
sz, err := disk.GetProperty("Size")
if err != nil {
return -1, err
}
return strconv.ParseInt(sz.(string), 10, 64)
}

// GetDiskPath returns the path of a disk.
func GetDiskPath(disk *storage.MSFT_Disk) (string, error) {
return disk.GetPropertyPath()
}

// GetDiskSerialNumber returns the serial number of a disk.
func GetDiskSerialNumber(disk *storage.MSFT_Disk) (string, error) {
return disk.GetPropertySerialNumber()
}
53 changes: 50 additions & 3 deletions pkg/cim/smb.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
package cim

import (
"strings"

"github.com/microsoft/wmi/pkg/base/query"
cim "github.com/microsoft/wmi/pkg/wmiinstance"
)
Expand All @@ -17,8 +19,24 @@ const (
SmbMappingStatusConnecting
SmbMappingStatusReconnecting
SmbMappingStatusUnavailable

credentialDelimiter = ":"
)

// escapeQueryParameter escapes a parameter for WMI Queries
func escapeQueryParameter(s string) string {
s = strings.ReplaceAll(s, "'", "''")
s = strings.ReplaceAll(s, "\\", "\\\\")
return s
}

func escapeUserName(userName string) string {
// refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L169-L170
userName = strings.ReplaceAll(userName, "\\", "\\\\")
userName = strings.ReplaceAll(userName, credentialDelimiter, "\\"+credentialDelimiter)
return userName
}

// QuerySmbGlobalMappingByRemotePath retrieves the SMB global mapping from its remote path.
//
// The equivalent WMI query is:
Expand All @@ -28,7 +46,7 @@ const (
// Refer to https://pkg.go.dev/github.com/microsoft/wmi/server2019/root/microsoft/windows/smb#MSFT_SmbGlobalMapping
// for the WMI class definition.
func QuerySmbGlobalMappingByRemotePath(remotePath string) (*cim.WmiInstance, error) {
smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", remotePath)
smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", escapeQueryParameter(remotePath))
instances, err := QueryInstances(WMINamespaceSmb, smbQuery)
if err != nil {
return nil, err
Expand All @@ -37,12 +55,22 @@ func QuerySmbGlobalMappingByRemotePath(remotePath string) (*cim.WmiInstance, err
return instances[0], err
}

// RemoveSmbGlobalMappingByRemotePath removes a SMB global mapping matching to the remote path.
// GetSmbGlobalMappingStatus returns the status of an SMB global mapping.
func GetSmbGlobalMappingStatus(inst *cim.WmiInstance) (int32, error) {
statusProp, err := inst.GetProperty("Status")
if err != nil {
return SmbMappingStatusUnavailable, err
}

return statusProp.(int32), nil
}

// RemoveSmbGlobalMappingByRemotePath removes an SMB global mapping matching to the remote path.
//
// Refer to https://pkg.go.dev/github.com/microsoft/wmi/server2019/root/microsoft/windows/smb#MSFT_SmbGlobalMapping
// for the WMI class definition.
func RemoveSmbGlobalMappingByRemotePath(remotePath string) error {
smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", remotePath)
smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", escapeQueryParameter(remotePath))
instances, err := QueryInstances(WMINamespaceSmb, smbQuery)
if err != nil {
return err
Expand All @@ -51,3 +79,22 @@ func RemoveSmbGlobalMappingByRemotePath(remotePath string) error {
_, err = instances[0].InvokeMethod("Remove", true)
return err
}

// NewSmbGlobalMapping creates a new SMB global mapping to the remote path.
//
// Refer to https://pkg.go.dev/github.com/microsoft/wmi/server2019/root/microsoft/windows/smb#MSFT_SmbGlobalMapping
// for the WMI class definition.
func NewSmbGlobalMapping(remotePath, username, password string, requirePrivacy bool) (int, error) {
params := map[string]interface{}{
"RemotePath": remotePath,
"RequirePrivacy": requirePrivacy,
}
if username != "" {
// refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L166-L178
// on how SMB credential is handled in PowerShell
params["Credential"] = escapeUserName(username) + credentialDelimiter + password
}

result, _, err := InvokeCimMethod(WMINamespaceSmb, "MSFT_SmbGlobalMapping", "Create", params)
return result, err
}
151 changes: 151 additions & 0 deletions pkg/cim/volume.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,23 @@ import (
"fmt"
"strconv"

"github.com/go-ole/go-ole"
"github.com/microsoft/wmi/pkg/base/query"
"github.com/microsoft/wmi/pkg/errors"
"github.com/microsoft/wmi/server2019/root/microsoft/windows/storage"
"k8s.io/klog/v2"
)

const (
FileSystemUnknown = 0
)

var (
VolumeSelectorListForFileSystemType = []string{"FileSystemType"}
VolumeSelectorListForStats = []string{"UniqueId", "SizeRemaining", "Size"}
VolumeSelectorListUniqueID = []string{"UniqueId"}

PartitionSelectorListObjectID = []string{"ObjectId"}
)

// QueryVolumeByUniqueID retrieves a specific volume by its unique identifier,
Expand Down Expand Up @@ -78,6 +92,68 @@ func ListVolumes(selectorList []string) ([]*storage.MSFT_Volume, error) {
return volumes, nil
}

// FormatVolume formats the specified volume.
//
// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/format-msft-volume
// for the WMI method definition.
func FormatVolume(volume *storage.MSFT_Volume, params ...interface{}) (int, error) {
result, err := volume.InvokeMethodWithReturn("Format", params...)
return int(result), err
}

// FlushVolume flushes the cached data in the volume's file system to disk.
//
// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-volume-flush
// for the WMI method definition.
func FlushVolume(volume *storage.MSFT_Volume) (int, error) {
result, err := volume.Flush()
return int(result), err
}

// GetVolumeUniqueID returns the unique ID (object ID) of a volume.
func GetVolumeUniqueID(volume *storage.MSFT_Volume) (string, error) {
return volume.GetPropertyUniqueId()
}

// GetVolumeFileSystemType returns the file system type of a volume.
func GetVolumeFileSystemType(volume *storage.MSFT_Volume) (int32, error) {
fsType, err := volume.GetProperty("FileSystemType")
if err != nil {
return 0, err
}
return fsType.(int32), nil
}

// GetVolumeSize returns the size of a volume.
func GetVolumeSize(volume *storage.MSFT_Volume) (int64, error) {
volumeSizeVal, err := volume.GetProperty("Size")
if err != nil {
return -1, err
}

volumeSize, err := strconv.ParseInt(volumeSizeVal.(string), 10, 64)
if err != nil {
return -1, err
}

return volumeSize, err
}

// GetVolumeSizeRemaining returns the remaining size of a volume.
func GetVolumeSizeRemaining(volume *storage.MSFT_Volume) (int64, error) {
volumeSizeRemainingVal, err := volume.GetProperty("SizeRemaining")
if err != nil {
return -1, err
}

volumeSizeRemaining, err := strconv.ParseInt(volumeSizeRemainingVal.(string), 10, 64)
if err != nil {
return -1, err
}

return volumeSizeRemaining, err
}

// ListPartitionsOnDisk retrieves all partitions or a partition with the specified number on a disk.
//
// The equivalent WMI query is:
Expand Down Expand Up @@ -245,3 +321,78 @@ func GetPartitionDiskNumber(part *storage.MSFT_Partition) (uint32, error) {

return uint32(diskNumber.(int32)), nil
}

// SetPartitionState takes a partition online or offline.
//
// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partition-online and
// https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partition-offline
// for the WMI method definition.
func SetPartitionState(part *storage.MSFT_Partition, online bool) (int, string, error) {
method := "Offline"
if online {
method = "Online"
}

var status string
result, err := part.InvokeMethodWithReturn(method, &status)
return int(result), status, err
}

// GetPartitionSupportedSize retrieves the minimum and maximum sizes that the partition can be resized to using the ResizePartition method.
//
// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partition-getsupportedsizes
// for the WMI method definition.
func GetPartitionSupportedSize(part *storage.MSFT_Partition) (result int, sizeMin, sizeMax int64, status string, err error) {
sizeMin = -1
sizeMax = -1

var sizeMinVar, sizeMaxVar ole.VARIANT
invokeResult, err := part.InvokeMethodWithReturn("GetSupportedSize", &sizeMinVar, &sizeMaxVar, &status)
if invokeResult != 0 || err != nil {
result = int(invokeResult)
}
klog.V(5).Infof("got sizeMin (%v) sizeMax (%v) from partition (%v), status: %s", sizeMinVar, sizeMaxVar, part, status)

sizeMin, err = strconv.ParseInt(sizeMinVar.ToString(), 10, 64)
if err != nil {
return
}

sizeMax, err = strconv.ParseInt(sizeMaxVar.ToString(), 10, 64)
return
}

// ResizePartition resizes a partition.
//
// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partition-resize
// for the WMI method definition.
func ResizePartition(part *storage.MSFT_Partition, size int64) (int, string, error) {
var status string
result, err := part.InvokeMethodWithReturn("Resize", strconv.Itoa(int(size)), &status)
return int(result), status, err
}

// GetPartitionSize returns the size of a partition.
func GetPartitionSize(part *storage.MSFT_Partition) (int64, error) {
sizeProp, err := part.GetProperty("Size")
if err != nil {
return -1, err
}

size, err := strconv.ParseInt(sizeProp.(string), 10, 64)
if err != nil {
return -1, err
}

return size, err
}

// FilterForPartitionOnDisk creates a WMI query filter to query a disk by its number.
func FilterForPartitionOnDisk(diskNumber uint32) *query.WmiQueryFilter {
return query.NewWmiQueryFilter("DiskNumber", strconv.Itoa(int(diskNumber)), query.Equals)
}

// FilterForPartitionsOfTypeNormal creates a WMI query filter for all non-reserved partitions.
func FilterForPartitionsOfTypeNormal() *query.WmiQueryFilter {
return query.NewWmiQueryFilter("GptType", GPTPartitionTypeMicrosoftReserved, query.NotEquals)
}
Loading