Skip to content

Commit b37e59e

Browse files
committed
Ensure COM threading apartment in Service APIs
1 parent 3cf0220 commit b37e59e

File tree

1 file changed

+72
-59
lines changed

1 file changed

+72
-59
lines changed

pkg/os/system/api.go

Lines changed: 72 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -107,45 +107,54 @@ func New() APIImplementor {
107107
}
108108

109109
func (APIImplementor) GetBIOSSerialNumber() (string, error) {
110-
bios, err := cim.QueryBIOSElement(cim.BIOSSelectorList)
111-
if err != nil {
112-
return "", fmt.Errorf("failed to get BIOS element: %w", err)
113-
}
110+
var sn string
111+
err := cim.WithCOMThread(func() error {
112+
bios, err := cim.QueryBIOSElement(cim.BIOSSelectorList)
113+
if err != nil {
114+
return fmt.Errorf("failed to get BIOS element: %w", err)
115+
}
114116

115-
sn, err := cim.GetBIOSSerialNumber(bios)
116-
if err != nil {
117-
return "", fmt.Errorf("failed to get BIOS serial number property: %w", err)
118-
}
117+
sn, err = cim.GetBIOSSerialNumber(bios)
118+
if err != nil {
119+
return fmt.Errorf("failed to get BIOS serial number property: %w", err)
120+
}
119121

120-
return sn, nil
122+
return nil
123+
})
124+
return sn, err
121125
}
122126

123127
func (impl APIImplementor) GetService(name string) (*ServiceInfo, error) {
124-
service, err := impl.serviceFactory.GetService(name)
125-
if err != nil {
126-
return nil, fmt.Errorf("failed to get service %s. error: %w", name, err)
127-
}
128+
var serviceInfo *ServiceInfo
129+
err := cim.WithCOMThread(func() error {
130+
service, err := impl.serviceFactory.GetService(name)
131+
if err != nil {
132+
return fmt.Errorf("failed to get service %s. error: %w", name, err)
133+
}
128134

129-
displayName, err := cim.GetServiceDisplayName(service)
130-
if err != nil {
131-
return nil, fmt.Errorf("failed to get displayName property of service %s: %w", name, err)
132-
}
135+
displayName, err := cim.GetServiceDisplayName(service)
136+
if err != nil {
137+
return fmt.Errorf("failed to get displayName property of service %s: %w", name, err)
138+
}
133139

134-
state, err := cim.GetServiceState(service)
135-
if err != nil {
136-
return nil, fmt.Errorf("failed to get state property of service %s: %w", name, err)
137-
}
140+
state, err := cim.GetServiceState(service)
141+
if err != nil {
142+
return fmt.Errorf("failed to get state property of service %s: %w", name, err)
143+
}
138144

139-
startMode, err := cim.GetServiceStartMode(service)
140-
if err != nil {
141-
return nil, fmt.Errorf("failed to get startMode property of service %s: %w", name, err)
142-
}
145+
startMode, err := cim.GetServiceStartMode(service)
146+
if err != nil {
147+
return fmt.Errorf("failed to get startMode property of service %s: %w", name, err)
148+
}
143149

144-
return &ServiceInfo{
145-
DisplayName: displayName,
146-
StartType: serviceStartModeToStartType(startMode),
147-
Status: serviceState(state),
148-
}, nil
150+
serviceInfo = &ServiceInfo{
151+
DisplayName: displayName,
152+
StartType: serviceStartModeToStartType(startMode),
153+
Status: serviceState(state),
154+
}
155+
return nil
156+
})
157+
return serviceInfo, err
149158
}
150159

151160
func (impl APIImplementor) StartService(name string) error {
@@ -171,21 +180,23 @@ func (impl APIImplementor) StartService(name string) error {
171180
return state == serviceStateRunning, newState, err
172181
}
173182

174-
service, err := impl.serviceFactory.GetService(name)
175-
if err != nil {
176-
return fmt.Errorf("failed to get service %s. error: %w", name, err)
177-
}
183+
return cim.WithCOMThread(func() error {
184+
service, err := impl.serviceFactory.GetService(name)
185+
if err != nil {
186+
return fmt.Errorf("failed to get service %s. error: %w", name, err)
187+
}
178188

179-
state, err := impl.serviceManager.WaitUntilServiceState(service, startService, serviceRunningCheck, serviceStateCheckInternal, serviceStateCheckTimeout)
180-
if err != nil && !errors.Is(err, errTimedOut) {
181-
return fmt.Errorf("failed to wait for service %s state change. error: %w", name, err)
182-
}
189+
state, err := impl.serviceManager.WaitUntilServiceState(service, startService, serviceRunningCheck, serviceStateCheckInternal, serviceStateCheckTimeout)
190+
if err != nil && !errors.Is(err, errTimedOut) {
191+
return fmt.Errorf("failed to wait for service %s state change. error: %w", name, err)
192+
}
183193

184-
if state != serviceStateRunning {
185-
return fmt.Errorf("timed out waiting for service %s to become running", name)
186-
}
194+
if state != serviceStateRunning {
195+
return fmt.Errorf("timed out waiting for service %s to become running", name)
196+
}
187197

188-
return nil
198+
return nil
199+
})
189200
}
190201

191202
func (impl APIImplementor) stopSingleService(name string) (bool, error) {
@@ -234,27 +245,29 @@ func (impl APIImplementor) stopSingleService(name string) (bool, error) {
234245
}
235246

236247
func (impl APIImplementor) StopService(name string, force bool) error {
237-
dependentRunning, err := impl.stopSingleService(name)
238-
if err == nil {
239-
return nil
240-
}
241-
if !dependentRunning || !force {
242-
return fmt.Errorf("failed to stop service %s. error: %w", name, err)
243-
}
244-
245-
serviceNames, err := impl.serviceManager.GetDependentsForService(name)
246-
if err != nil {
247-
return fmt.Errorf("error getting dependent services for service name %s", name)
248-
}
248+
return cim.WithCOMThread(func() error {
249+
dependentRunning, err := impl.stopSingleService(name)
250+
if err == nil {
251+
return nil
252+
}
253+
if !dependentRunning || !force {
254+
return fmt.Errorf("failed to stop service %s. error: %w", name, err)
255+
}
249256

250-
for _, serviceName := range serviceNames {
251-
_, err = impl.stopSingleService(serviceName)
257+
serviceNames, err := impl.serviceManager.GetDependentsForService(name)
252258
if err != nil {
253-
return fmt.Errorf("failed to stop service %s. error: %w", name, err)
259+
return fmt.Errorf("error getting dependent services for service name %s", name)
260+
}
261+
262+
for _, serviceName := range serviceNames {
263+
_, err = impl.stopSingleService(serviceName)
264+
if err != nil {
265+
return fmt.Errorf("failed to stop service %s. error: %w", name, err)
266+
}
254267
}
255-
}
256268

257-
return nil
269+
return nil
270+
})
258271
}
259272

260273
type ServiceManagerImpl struct {

0 commit comments

Comments
 (0)