Skip to content

Commit 47e900d

Browse files
committed
Ensure COM threading apartment in Volume APIs
1 parent 0e62804 commit 47e900d

File tree

1 file changed

+162
-135
lines changed

1 file changed

+162
-135
lines changed

pkg/os/volume/api.go

Lines changed: 162 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -69,50 +69,55 @@ func New() VolumeAPI {
6969

7070
// ListVolumesOnDisk - returns back list of volumes(volumeIDs) in a disk and a partition.
7171
func (VolumeAPI) ListVolumesOnDisk(diskNumber uint32, partitionNumber uint32) (volumeIDs []string, err error) {
72-
partitions, err := cim.ListPartitionsOnDisk(diskNumber, partitionNumber, cim.PartitionSelectorListObjectID)
73-
if err != nil {
74-
return nil, errors.Wrapf(err, "failed to list partition on disk %d", diskNumber)
75-
}
72+
err = cim.WithCOMThread(func() error {
73+
partitions, err := cim.ListPartitionsOnDisk(diskNumber, partitionNumber, cim.PartitionSelectorListObjectID)
74+
if err != nil {
75+
return errors.Wrapf(err, "failed to list partition on disk %d", diskNumber)
76+
}
7677

77-
volumes, err := cim.FindVolumesByPartition(partitions)
78-
if cim.IgnoreNotFound(err) != nil {
79-
return nil, errors.Wrapf(err, "failed to list volumes on disk %d", diskNumber)
80-
}
78+
volumes, err := cim.FindVolumesByPartition(partitions)
79+
if cim.IgnoreNotFound(err) != nil {
80+
return errors.Wrapf(err, "failed to list volumes on disk %d", diskNumber)
81+
}
8182

82-
for _, volume := range volumes {
83-
uniqueID, err := cim.GetVolumeUniqueID(volume)
84-
if err != nil {
85-
return nil, errors.Wrapf(err, "failed to get unique ID for volume %v", volume)
83+
for _, volume := range volumes {
84+
uniqueID, err := cim.GetVolumeUniqueID(volume)
85+
if err != nil {
86+
return errors.Wrapf(err, "failed to get unique ID for volume %v", volume)
87+
}
88+
volumeIDs = append(volumeIDs, uniqueID)
8689
}
87-
volumeIDs = append(volumeIDs, uniqueID)
88-
}
8990

90-
return volumeIDs, nil
91+
return nil
92+
})
93+
return
9194
}
9295

9396
// FormatVolume - Formats a volume with the NTFS format.
9497
func (VolumeAPI) FormatVolume(volumeID string) (err error) {
95-
volume, err := cim.QueryVolumeByUniqueID(volumeID, nil)
96-
if err != nil {
97-
return fmt.Errorf("error formatting volume (%s). error: %v", volumeID, err)
98-
}
98+
return cim.WithCOMThread(func() error {
99+
volume, err := cim.QueryVolumeByUniqueID(volumeID, nil)
100+
if err != nil {
101+
return fmt.Errorf("error formatting volume (%s). error: %v", volumeID, err)
102+
}
99103

100-
result, err := cim.FormatVolume(volume,
101-
"NTFS", // Format,
102-
"", // FileSystemLabel,
103-
nil, // AllocationUnitSize,
104-
false, // Full,
105-
true, // Force
106-
nil, // Compress,
107-
nil, // ShortFileNameSupport,
108-
nil, // SetIntegrityStreams,
109-
nil, // UseLargeFRS,
110-
nil, // DisableHeatGathering,
111-
)
112-
if result != 0 || err != nil {
113-
return fmt.Errorf("error formatting volume (%s). result: %d, error: %v", volumeID, result, err)
114-
}
115-
return nil
104+
result, err := cim.FormatVolume(volume,
105+
"NTFS", // Format,
106+
"", // FileSystemLabel,
107+
nil, // AllocationUnitSize,
108+
false, // Full,
109+
true, // Force
110+
nil, // Compress,
111+
nil, // ShortFileNameSupport,
112+
nil, // SetIntegrityStreams,
113+
nil, // UseLargeFRS,
114+
nil, // DisableHeatGathering,
115+
)
116+
if result != 0 || err != nil {
117+
return fmt.Errorf("error formatting volume (%s). result: %d, error: %v", volumeID, result, err)
118+
}
119+
return nil
120+
})
116121
}
117122

118123
// WriteVolumeCache - Writes the file system cache to disk with the given volume id
@@ -122,17 +127,22 @@ func (VolumeAPI) WriteVolumeCache(volumeID string) (err error) {
122127

123128
// IsVolumeFormatted - Check if the volume is formatted with the pre specified filesystem(typically ntfs).
124129
func (VolumeAPI) IsVolumeFormatted(volumeID string) (bool, error) {
125-
volume, err := cim.QueryVolumeByUniqueID(volumeID, cim.VolumeSelectorListForFileSystemType)
126-
if err != nil {
127-
return false, fmt.Errorf("error checking if volume (%s) is formatted. error: %v", volumeID, err)
128-
}
130+
var formatted bool
131+
err := cim.WithCOMThread(func() error {
132+
volume, err := cim.QueryVolumeByUniqueID(volumeID, cim.VolumeSelectorListForFileSystemType)
133+
if err != nil {
134+
return fmt.Errorf("error checking if volume (%s) is formatted. error: %v", volumeID, err)
135+
}
129136

130-
fsType, err := cim.GetVolumeFileSystemType(volume)
131-
if err != nil {
132-
return false, fmt.Errorf("failed to query volume file system type (%s): %w", volumeID, err)
133-
}
137+
fsType, err := cim.GetVolumeFileSystemType(volume)
138+
if err != nil {
139+
return fmt.Errorf("failed to query volume file system type (%s): %w", volumeID, err)
140+
}
134141

135-
return fsType != cim.FileSystemUnknown, nil
142+
formatted = fsType != cim.FileSystemUnknown
143+
return nil
144+
})
145+
return formatted, err
136146
}
137147

138148
// MountVolume - mounts a volume to a path. This is done using Win32 API SetVolumeMountPoint for presenting the volume via a path.
@@ -182,101 +192,112 @@ func (VolumeAPI) UnmountVolume(volumeID, path string) error {
182192

183193
// ResizeVolume - resizes a volume with the given size, if size == 0 then max supported size is used
184194
func (VolumeAPI) ResizeVolume(volumeID string, size int64) error {
185-
var err error
186-
var finalSize int64
187-
part, err := cim.GetPartitionByVolumeUniqueID(volumeID)
188-
if err != nil {
189-
return err
190-
}
191-
192-
// If size is 0 then we will resize to the maximum size possible, otherwise just resize to size
193-
if size == 0 {
194-
var result int
195-
var status string
196-
result, _, finalSize, status, err = cim.GetPartitionSupportedSize(part)
197-
if result != 0 || err != nil {
198-
return fmt.Errorf("error getting sizeMin, sizeMax from volume (%s). result: %d, status: %s, error: %v", volumeID, result, status, err)
195+
return cim.WithCOMThread(func() error {
196+
var err error
197+
var finalSize int64
198+
part, err := cim.GetPartitionByVolumeUniqueID(volumeID)
199+
if err != nil {
200+
return err
199201
}
200202

201-
} else {
202-
finalSize = size
203-
}
203+
// If size is 0 then we will resize to the maximum size possible, otherwise just resize to size
204+
if size == 0 {
205+
var result int
206+
var status string
207+
result, _, finalSize, status, err = cim.GetPartitionSupportedSize(part)
208+
if result != 0 || err != nil {
209+
return fmt.Errorf("error getting sizeMin, sizeMax from volume (%s). result: %d, status: %s, error: %v", volumeID, result, status, err)
210+
}
204211

205-
currentSize, err := cim.GetPartitionSize(part)
206-
if err != nil {
207-
return fmt.Errorf("error getting the current size of volume (%s) with error (%v)", volumeID, err)
208-
}
212+
} else {
213+
finalSize = size
214+
}
209215

210-
// only resize if finalSize - currentSize is greater than 100MB
211-
if finalSize-currentSize < minimumResizeSize {
212-
klog.V(2).Infof("minimum resize difference (100MB) not met, skipping resize. volumeID=%s currentSize=%d finalSize=%d", volumeID, currentSize, finalSize)
213-
return nil
214-
}
216+
currentSize, err := cim.GetPartitionSize(part)
217+
if err != nil {
218+
return fmt.Errorf("error getting the current size of volume (%s) with error (%v)", volumeID, err)
219+
}
215220

216-
//if the partition's size is already the size we want this is a noop, just return
217-
if currentSize >= finalSize {
218-
klog.V(2).Infof("Attempted to resize volume (%s) to a lower size, from currentBytes=%d wantedBytes=%d", volumeID, currentSize, finalSize)
219-
return nil
220-
}
221+
// only resize if finalSize - currentSize is greater than 100MB
222+
if finalSize-currentSize < minimumResizeSize {
223+
klog.V(2).Infof("minimum resize difference (100MB) not met, skipping resize. volumeID=%s currentSize=%d finalSize=%d", volumeID, currentSize, finalSize)
224+
return nil
225+
}
221226

222-
result, _, err := cim.ResizePartition(part, finalSize)
223-
if result != 0 || err != nil {
224-
return fmt.Errorf("error resizing volume (%s). size:%v, finalSize %v, error: %v", volumeID, size, finalSize, err)
225-
}
227+
//if the partition's size is already the size we want this is a noop, just return
228+
if currentSize >= finalSize {
229+
klog.V(2).Infof("Attempted to resize volume (%s) to a lower size, from currentBytes=%d wantedBytes=%d", volumeID, currentSize, finalSize)
230+
return nil
231+
}
226232

227-
diskNumber, err := cim.GetPartitionDiskNumber(part)
228-
if err != nil {
229-
return fmt.Errorf("error parsing disk number of volume (%s). error: %v", volumeID, err)
230-
}
233+
result, _, err := cim.ResizePartition(part, finalSize)
234+
if result != 0 || err != nil {
235+
return fmt.Errorf("error resizing volume (%s). size:%v, finalSize %v, error: %v", volumeID, size, finalSize, err)
236+
}
231237

232-
disk, err := cim.QueryDiskByNumber(diskNumber, nil)
233-
if err != nil {
234-
return fmt.Errorf("error query disk of volume (%s). error: %v", volumeID, err)
235-
}
238+
diskNumber, err := cim.GetPartitionDiskNumber(part)
239+
if err != nil {
240+
return fmt.Errorf("error parsing disk number of volume (%s). error: %v", volumeID, err)
241+
}
236242

237-
result, _, err = cim.RefreshDisk(disk)
238-
if result != 0 || err != nil {
239-
return fmt.Errorf("error rescan disk (%d). result %d, error: %v", diskNumber, result, err)
240-
}
243+
disk, err := cim.QueryDiskByNumber(diskNumber, nil)
244+
if err != nil {
245+
return fmt.Errorf("error query disk of volume (%s). error: %v", volumeID, err)
246+
}
241247

242-
return nil
248+
result, _, err = cim.RefreshDisk(disk)
249+
if result != 0 || err != nil {
250+
return fmt.Errorf("error rescan disk (%d). result %d, error: %v", diskNumber, result, err)
251+
}
252+
253+
return nil
254+
})
243255
}
244256

245257
// GetVolumeStats - retrieves the volume stats for a given volume
246-
func (VolumeAPI) GetVolumeStats(volumeID string) (int64, int64, error) {
247-
volume, err := cim.QueryVolumeByUniqueID(volumeID, cim.VolumeSelectorListForStats)
248-
if err != nil {
249-
return -1, -1, fmt.Errorf("error getting capacity and used size of volume (%s). error: %v", volumeID, err)
250-
}
258+
func (VolumeAPI) GetVolumeStats(volumeID string) (volumeSize, volumeUsedSize int64, err error) {
259+
volumeSize = -1
260+
volumeUsedSize = -1
261+
err = cim.WithCOMThread(func() error {
262+
volume, err := cim.QueryVolumeByUniqueID(volumeID, cim.VolumeSelectorListForStats)
263+
if err != nil {
264+
return fmt.Errorf("error getting capacity and used size of volume (%s). error: %v", volumeID, err)
265+
}
251266

252-
volumeSize, err := cim.GetVolumeSize(volume)
253-
if err != nil {
254-
return -1, -1, fmt.Errorf("failed to query volume size (%s): %w", volumeID, err)
255-
}
267+
volumeSize, err = cim.GetVolumeSize(volume)
268+
if err != nil {
269+
return fmt.Errorf("failed to query volume size (%s): %w", volumeID, err)
270+
}
256271

257-
volumeSizeRemaining, err := cim.GetVolumeSizeRemaining(volume)
258-
if err != nil {
259-
return -1, -1, fmt.Errorf("failed to query volume remaining size (%s): %w", volumeID, err)
260-
}
272+
volumeSizeRemaining, err := cim.GetVolumeSizeRemaining(volume)
273+
if err != nil {
274+
return fmt.Errorf("failed to query volume remaining size (%s): %w", volumeID, err)
275+
}
261276

262-
volumeUsedSize := volumeSize - volumeSizeRemaining
263-
return volumeSize, volumeUsedSize, nil
277+
volumeUsedSize = volumeSize - volumeSizeRemaining
278+
return nil
279+
})
280+
return
264281
}
265282

266283
// GetDiskNumberFromVolumeID - gets the disk number where the volume is.
267284
func (VolumeAPI) GetDiskNumberFromVolumeID(volumeID string) (uint32, error) {
268-
// get the size and sizeRemaining for the volume
269-
part, err := cim.GetPartitionByVolumeUniqueID(volumeID)
270-
if err != nil {
271-
return 0, err
272-
}
285+
var diskNumber uint32
286+
err := cim.WithCOMThread(func() error {
287+
// get the size and sizeRemaining for the volume
288+
part, err := cim.GetPartitionByVolumeUniqueID(volumeID)
289+
if err != nil {
290+
return err
291+
}
273292

274-
diskNumber, err := cim.GetPartitionDiskNumber(part)
275-
if err != nil {
276-
return 0, fmt.Errorf("error query disk number of volume (%s). error: %v", volumeID, err)
277-
}
293+
diskNumber, err = cim.GetPartitionDiskNumber(part)
294+
if err != nil {
295+
return fmt.Errorf("error query disk number of volume (%s). error: %v", volumeID, err)
296+
}
278297

279-
return diskNumber, nil
298+
return nil
299+
})
300+
return diskNumber, err
280301
}
281302

282303
// GetVolumeIDFromTargetPath - gets the volume ID given a mount point, the function is recursive until it find a volume or errors out
@@ -395,28 +416,34 @@ func getVolumeForDriveLetter(path string) (string, error) {
395416
return "", fmt.Errorf("the path %s is not a valid drive letter", path)
396417
}
397418

398-
volume, err := cim.GetVolumeByDriveLetter(path, cim.VolumeSelectorListUniqueID)
399-
if err != nil {
400-
return "", nil
401-
}
419+
var uniqueID string
420+
err := cim.WithCOMThread(func() error {
421+
volume, err := cim.GetVolumeByDriveLetter(path, cim.VolumeSelectorListUniqueID)
422+
if err != nil {
423+
return err
424+
}
402425

403-
uniqueID, err := cim.GetVolumeUniqueID(volume)
404-
if err != nil {
405-
return "", fmt.Errorf("error query unique ID of volume (%v). error: %v", volume, err)
406-
}
426+
uniqueID, err = cim.GetVolumeUniqueID(volume)
427+
if err != nil {
428+
return fmt.Errorf("error query unique ID of volume (%v). error: %v", volume, err)
429+
}
407430

408-
return uniqueID, nil
431+
return nil
432+
})
433+
return uniqueID, err
409434
}
410435

411436
func writeCache(volumeID string) error {
412-
volume, err := cim.QueryVolumeByUniqueID(volumeID, nil)
413-
if err != nil {
414-
return fmt.Errorf("error writing volume (%s) cache. error: %v", volumeID, err)
415-
}
437+
return cim.WithCOMThread(func() error {
438+
volume, err := cim.QueryVolumeByUniqueID(volumeID, nil)
439+
if err != nil {
440+
return fmt.Errorf("error writing volume (%s) cache. error: %v", volumeID, err)
441+
}
416442

417-
result, err := cim.FlushVolume(volume)
418-
if result != 0 || err != nil {
419-
return fmt.Errorf("error writing volume (%s) cache. result: %d, error: %v", volumeID, result, err)
420-
}
421-
return nil
443+
result, err := cim.FlushVolume(volume)
444+
if result != 0 || err != nil {
445+
return fmt.Errorf("error writing volume (%s) cache. result: %d, error: %v", volumeID, result, err)
446+
}
447+
return nil
448+
})
422449
}

0 commit comments

Comments
 (0)