Skip to content

Commit ef50b9d

Browse files
committed
Ensure COM threading apartment in Service APIs
1 parent a05b58c commit ef50b9d

File tree

1 file changed

+69
-56
lines changed

1 file changed

+69
-56
lines changed

pkg/os/system/api.go

Lines changed: 69 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -107,45 +107,54 @@ func New() APIImplementor {
107107
}
108108

109109
func (APIImplementor) GetBIOSSerialNumber() (string, error) {
110-
bios, err := cim.QueryBIOSElement(cim.BIOSSelectorList)
111-
if err != nil {
112-
return "", fmt.Errorf("failed to get BIOS element: %w", err)
113-
}
110+
var sn string
111+
err := cim.WithCOMThread(func() error {
112+
bios, err := cim.QueryBIOSElement(cim.BIOSSelectorList)
113+
if err != nil {
114+
return fmt.Errorf("failed to get BIOS element: %w", err)
115+
}
114116

115-
sn, err := cim.GetBIOSSerialNumber(bios)
116-
if err != nil {
117-
return "", fmt.Errorf("failed to get BIOS serial number property: %w", err)
118-
}
117+
sn, err = cim.GetBIOSSerialNumber(bios)
118+
if err != nil {
119+
return fmt.Errorf("failed to get BIOS serial number property: %w", err)
120+
}
119121

120-
return sn, nil
122+
return nil
123+
})
124+
return sn, err
121125
}
122126

123127
func (impl APIImplementor) GetService(name string) (*ServiceInfo, error) {
124-
service, err := impl.serviceFactory.GetService(name)
125-
if err != nil {
126-
return nil, fmt.Errorf("failed to get service %s: %w", name, err)
127-
}
128+
var serviceInfo *ServiceInfo
129+
err := cim.WithCOMThread(func() error {
130+
service, err := impl.serviceFactory.GetService(name)
131+
if err != nil {
132+
return fmt.Errorf("failed to get service %s: %w", name, err)
133+
}
128134

129-
displayName, err := cim.GetServiceDisplayName(service)
130-
if err != nil {
131-
return nil, fmt.Errorf("failed to get displayName property of service %s: %w", name, err)
132-
}
135+
displayName, err := cim.GetServiceDisplayName(service)
136+
if err != nil {
137+
return fmt.Errorf("failed to get displayName property of service %s: %w", name, err)
138+
}
133139

134-
state, err := cim.GetServiceState(service)
135-
if err != nil {
136-
return nil, fmt.Errorf("failed to get state property of service %s: %w", name, err)
137-
}
140+
state, err := cim.GetServiceState(service)
141+
if err != nil {
142+
return fmt.Errorf("failed to get state property of service %s: %w", name, err)
143+
}
138144

139-
startMode, err := cim.GetServiceStartMode(service)
140-
if err != nil {
141-
return nil, fmt.Errorf("failed to get startMode property of service %s: %w", name, err)
142-
}
145+
startMode, err := cim.GetServiceStartMode(service)
146+
if err != nil {
147+
return fmt.Errorf("failed to get startMode property of service %s: %w", name, err)
148+
}
143149

144-
return &ServiceInfo{
145-
DisplayName: displayName,
146-
StartType: serviceStartModeToStartType(startMode),
147-
Status: serviceState(state),
148-
}, nil
150+
serviceInfo = &ServiceInfo{
151+
DisplayName: displayName,
152+
StartType: serviceStartModeToStartType(startMode),
153+
Status: serviceState(state),
154+
}
155+
return nil
156+
})
157+
return serviceInfo, err
149158
}
150159

151160
func (impl APIImplementor) StartService(name string) error {
@@ -171,21 +180,23 @@ func (impl APIImplementor) StartService(name string) error {
171180
return state == serviceStateRunning, newState, err
172181
}
173182

174-
service, err := impl.serviceFactory.GetService(name)
175-
if err != nil {
176-
return err
177-
}
183+
return cim.WithCOMThread(func() error {
184+
service, err := impl.serviceFactory.GetService(name)
185+
if err != nil {
186+
return err
187+
}
178188

179-
state, err := impl.serviceManager.WaitUntilServiceState(service, startService, serviceRunningCheck, serviceStateCheckInternal, serviceStateCheckTimeout)
180-
if err != nil && !errors.Is(err, errTimedOut) {
181-
return err
182-
}
189+
state, err := impl.serviceManager.WaitUntilServiceState(service, startService, serviceRunningCheck, serviceStateCheckInternal, serviceStateCheckTimeout)
190+
if err != nil && !errors.Is(err, errTimedOut) {
191+
return err
192+
}
183193

184-
if state != serviceStateRunning {
185-
return fmt.Errorf("timed out waiting for service %s to become running", name)
186-
}
194+
if state != serviceStateRunning {
195+
return fmt.Errorf("timed out waiting for service %s to become running", name)
196+
}
187197

188-
return nil
198+
return nil
199+
})
189200
}
190201

191202
func (impl APIImplementor) stopSingleService(name string) (bool, error) {
@@ -234,24 +245,26 @@ func (impl APIImplementor) stopSingleService(name string) (bool, error) {
234245
}
235246

236247
func (impl APIImplementor) StopService(name string, force bool) error {
237-
dependentRunning, err := impl.stopSingleService(name)
238-
if err == nil || !dependentRunning || !force {
239-
return err
240-
}
241-
242-
serviceNames, err := impl.serviceManager.GetDependentsForService(name)
243-
if err != nil {
244-
return fmt.Errorf("error getting dependent services for service name %s", name)
245-
}
248+
return cim.WithCOMThread(func() error {
249+
dependentRunning, err := impl.stopSingleService(name)
250+
if err == nil || !dependentRunning || !force {
251+
return err
252+
}
246253

247-
for _, serviceName := range serviceNames {
248-
_, err = impl.stopSingleService(serviceName)
254+
serviceNames, err := impl.serviceManager.GetDependentsForService(name)
249255
if err != nil {
250-
return err
256+
return fmt.Errorf("error getting dependent services for service name %s", name)
257+
}
258+
259+
for _, serviceName := range serviceNames {
260+
_, err = impl.stopSingleService(serviceName)
261+
if err != nil {
262+
return err
263+
}
251264
}
252-
}
253265

254-
return nil
266+
return nil
267+
})
255268
}
256269

257270
type ServiceManagerImpl struct {

0 commit comments

Comments
 (0)