Skip to content

Commit 49f3d9e

Browse files
committed
Ensure COM threading apartment in Service APIs
1 parent 7259733 commit 49f3d9e

File tree

1 file changed

+82
-69
lines changed

1 file changed

+82
-69
lines changed

pkg/os/system/api.go

Lines changed: 82 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -106,79 +106,90 @@ func New() APIImplementor {
106106
}
107107

108108
func (APIImplementor) GetBIOSSerialNumber() (string, error) {
109-
bios, err := cim.QueryBIOSElement(cim.BIOSSelectorList)
110-
if err != nil {
111-
return "", fmt.Errorf("failed to get BIOS element: %w", err)
112-
}
109+
var sn string
110+
err := cim.WithCOMThread(func() error {
111+
bios, err := cim.QueryBIOSElement(cim.BIOSSelectorList)
112+
if err != nil {
113+
return fmt.Errorf("failed to get BIOS element: %w", err)
114+
}
113115

114-
sn, err := cim.GetBIOSSerialNumber(bios)
115-
if err != nil {
116-
return "", fmt.Errorf("failed to get BIOS serial number property: %w", err)
117-
}
116+
sn, err = cim.GetBIOSSerialNumber(bios)
117+
if err != nil {
118+
return fmt.Errorf("failed to get BIOS serial number property: %w", err)
119+
}
118120

119-
return sn, nil
121+
return nil
122+
})
123+
return sn, err
120124
}
121125

122126
func (impl APIImplementor) GetService(name string) (*ServiceInfo, error) {
123-
service, err := impl.serviceFactory.GetService(name)
124-
if err != nil {
125-
return nil, fmt.Errorf("failed to get service %s: %w", name, err)
126-
}
127+
var serviceInfo *ServiceInfo
128+
err := cim.WithCOMThread(func() error {
129+
service, err := impl.serviceFactory.GetService(name)
130+
if err != nil {
131+
return fmt.Errorf("failed to get service %s: %w", name, err)
132+
}
127133

128-
displayName, err := cim.GetServiceDisplayName(service)
129-
if err != nil {
130-
return nil, fmt.Errorf("failed to get displayName property of service %s: %w", name, err)
131-
}
134+
displayName, err := cim.GetServiceDisplayName(service)
135+
if err != nil {
136+
return fmt.Errorf("failed to get displayName property of service %s: %w", name, err)
137+
}
132138

133-
state, err := cim.GetServiceState(service)
134-
if err != nil {
135-
return nil, fmt.Errorf("failed to get state property of service %s: %w", name, err)
136-
}
139+
state, err := cim.GetServiceState(service)
140+
if err != nil {
141+
return fmt.Errorf("failed to get state property of service %s: %w", name, err)
142+
}
137143

138-
startMode, err := cim.GetServiceStartMode(service)
139-
if err != nil {
140-
return nil, fmt.Errorf("failed to get startMode property of service %s: %w", name, err)
141-
}
144+
startMode, err := cim.GetServiceStartMode(service)
145+
if err != nil {
146+
return fmt.Errorf("failed to get startMode property of service %s: %w", name, err)
147+
}
142148

143-
return &ServiceInfo{
144-
DisplayName: displayName,
145-
StartType: serviceStartModeToStartType(startMode),
146-
Status: serviceState(state),
147-
}, nil
149+
serviceInfo = &ServiceInfo{
150+
DisplayName: displayName,
151+
StartType: serviceStartModeToStartType(startMode),
152+
Status: serviceState(state),
153+
}
154+
return nil
155+
})
156+
return serviceInfo, err
148157
}
149158

150159
func (impl APIImplementor) StartService(name string) error {
151-
startService := func(service cim.ServiceInterface) error {
152-
retVal, err := service.StartService()
153-
if err != nil || (retVal != startServiceErrorCodeAccepted && retVal != startServiceErrorCodeAlreadyRunning) {
154-
return fmt.Errorf("error starting service name %s. return value: %d, error: %v", name, retVal, err)
155-
}
156-
return nil
157-
}
158-
serviceRunningCheck := func() (bool, string, cim.ServiceInterface, error) {
159-
service, err := impl.serviceFactory.GetService(name)
160-
if err != nil {
161-
return false, "", nil, err
160+
return cim.WithCOMThread(func() error {
161+
startService := func(service cim.ServiceInterface) error {
162+
retVal, err := service.StartService()
163+
if err != nil || (retVal != startServiceErrorCodeAccepted && retVal != startServiceErrorCodeAlreadyRunning) {
164+
return fmt.Errorf("error starting service name %s. return value: %d, error: %v", name, retVal, err)
165+
}
166+
return nil
162167
}
168+
serviceRunningCheck := func() (bool, string, cim.ServiceInterface, error) {
169+
service, err := impl.serviceFactory.GetService(name)
170+
if err != nil {
171+
return false, "", nil, err
172+
}
163173

164-
state, err := cim.GetServiceState(service)
165-
if err != nil {
166-
return false, state, service, err
167-
}
174+
state, err := cim.GetServiceState(service)
175+
if err != nil {
176+
return false, state, service, err
177+
}
168178

169-
return state == serviceStateRunning, state, service, err
170-
}
179+
return state == serviceStateRunning, state, service, err
180+
}
171181

172-
state, err := impl.serviceManager.WaitUntilServiceState(startService, serviceRunningCheck, serviceStateCheckInternal, serviceStateCheckTimeout)
173-
if err != nil && !errors.Is(err, errTimedOut) {
174-
return err
175-
}
182+
state, err := impl.serviceManager.WaitUntilServiceState(startService, serviceRunningCheck, serviceStateCheckInternal, serviceStateCheckTimeout)
183+
if err != nil && !errors.Is(err, errTimedOut) {
184+
return err
185+
}
176186

177-
if state != serviceStateRunning {
178-
return fmt.Errorf("timed out waiting for service %s to become running", name)
179-
}
187+
if state != serviceStateRunning {
188+
return fmt.Errorf("timed out waiting for service %s to become running", name)
189+
}
180190

181-
return nil
191+
return nil
192+
})
182193
}
183194

184195
func (impl APIImplementor) stopSingleService(name string) (bool, error) {
@@ -221,24 +232,26 @@ func (impl APIImplementor) stopSingleService(name string) (bool, error) {
221232
}
222233

223234
func (impl APIImplementor) StopService(name string, force bool) error {
224-
dependentRunning, err := impl.stopSingleService(name)
225-
if err == nil || !dependentRunning || !force {
226-
return err
227-
}
228-
229-
serviceNames, err := impl.serviceManager.GetDependentsForService(name)
230-
if err != nil {
231-
return fmt.Errorf("error getting dependent services for service name %s", name)
232-
}
235+
return cim.WithCOMThread(func() error {
236+
dependentRunning, err := impl.stopSingleService(name)
237+
if err == nil || !dependentRunning || !force {
238+
return err
239+
}
233240

234-
for _, serviceName := range serviceNames {
235-
_, err = impl.stopSingleService(serviceName)
241+
serviceNames, err := impl.serviceManager.GetDependentsForService(name)
236242
if err != nil {
237-
return err
243+
return fmt.Errorf("error getting dependent services for service name %s", name)
244+
}
245+
246+
for _, serviceName := range serviceNames {
247+
_, err = impl.stopSingleService(serviceName)
248+
if err != nil {
249+
return err
250+
}
238251
}
239-
}
240252

241-
return nil
253+
return nil
254+
})
242255
}
243256

244257
type ServiceManagerImpl struct {

0 commit comments

Comments
 (0)