Skip to content

Commit f2f9e17

Browse files
authored
feat: set up interruptible generation in Studio (#1319)
GEN-1145 *WARNING* This PR relies on [openapi-generation PR#2367](speakeasy-api/openapi-generation#2367) being merged first and should be merged alongside [speakeasy-registry PR#4134](speakeasy-api/speakeasy-registry#4134) This PR sets up the ability to: - Cancel generation before it is complete - Receive progress udpates from the generator (new steps, file status changes, main README content) - Update the Studio UI based on received updates --- **1. Cancelability** To control cancelability, a `CancellableGeneration` field was added to the `Workflow` struct to hold the *cancellable context* and the *cancel function*. When this field is non-nil, `g.GenerateWithCancel` will be used instead of `g.Generate` (c.f. open-generation PR#2367) --- **2. Streamability** To control streamability, a `StreamableGeneneration` field was added to the `Workflow` struct to hold the `OnProgressUpdate` callback function that will get called everytime the generator sends an update: - if a new step has started then `progressUpdate.Step` will be non-nil - if a file status has changed then the `progressUpdate.File` will non-nil - a `progressUpdate.File.IsMainReadme` boolean is also set to indicate that the file being modified is the Readme. If it is `true`, then the `progressUpdate.File.Content` will also be set --- **3. Studio events** At the moment, the Studio UI is being notified when: - the `genSDK` and `compileSDK` steps started - the main `README` content got received
1 parent f78b016 commit f2f9e17

37 files changed

+1015
-630
lines changed

internal/log/logListener.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@ var (
1212
MsgWarn MsgType = "warn"
1313
MsgError MsgType = "error"
1414
MsgGithub MsgType = "github"
15+
MsgStudio MsgType = "studio"
1516
)

internal/run/source.go

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,26 @@ import (
2727
"github.com/speakeasy-api/speakeasy/internal/workflowTracking"
2828
)
2929

30+
type SourceResultCallback func(sourceRes *SourceResult, sourceStep SourceStepID) error
31+
32+
type SourceStepID string
33+
34+
const (
35+
// CLI steps
36+
SourceStepFetch SourceStepID = "Fetching spec"
37+
SourceStepOverlay SourceStepID = "Overlaying"
38+
SourceStepTransform SourceStepID = "Transforming"
39+
SourceStepLint SourceStepID = "Linting"
40+
SourceStepUpload SourceStepID = "Uploading spec"
41+
// Generator steps
42+
SourceStepStart SourceStepID = "Started"
43+
SourceStepGenerate SourceStepID = "Generating SDK"
44+
SourceStepCompile SourceStepID = "Compiling SDK"
45+
SourceStepComplete SourceStepID = "Completed"
46+
SourceStepCancel SourceStepID = "Cancelling"
47+
SourceStepExit SourceStepID = "Exiting"
48+
)
49+
3050
type SourceResult struct {
3151
Source string
3252
// The merged OAS spec that was input to the source contents as a string
@@ -55,7 +75,7 @@ func (e *LintingError) Error() string {
5575
return fmt.Sprintf("linting failed: %s - %s", e.Document, errString)
5676
}
5777

58-
func (w *Workflow) RunSource(ctx context.Context, parentStep *workflowTracking.WorkflowStep, sourceID, targetID string, targetLanguage string) (string, *SourceResult, error) {
78+
func (w *Workflow) RunSource(ctx context.Context, parentStep *workflowTracking.WorkflowStep, sourceID, targetID, targetLanguage string) (string, *SourceResult, error) {
5979
rootStep := parentStep.NewSubstep(fmt.Sprintf("Source: %s", sourceID))
6080
source := w.workflow.Sources[sourceID]
6181
sourceRes := &SourceResult{
@@ -64,9 +84,9 @@ func (w *Workflow) RunSource(ctx context.Context, parentStep *workflowTracking.W
6484
}
6585
defer func() {
6686
w.SourceResults[sourceID] = sourceRes
67-
w.OnSourceResult(sourceRes, "")
87+
w.OnSourceResult(sourceRes, SourceStepComplete)
6888
}()
69-
w.OnSourceResult(sourceRes, "Fetching spec")
89+
w.OnSourceResult(sourceRes, SourceStepFetch)
7090

7191
rulesetToUse := "speakeasy-generation"
7292
if source.Ruleset != nil {
@@ -122,29 +142,29 @@ func (w *Workflow) RunSource(ctx context.Context, parentStep *workflowTracking.W
122142
}
123143

124144
if len(source.Overlays) > 0 && !w.FrozenWorkflowLock {
125-
w.OnSourceResult(sourceRes, "Overlaying")
145+
w.OnSourceResult(sourceRes, SourceStepOverlay)
126146
currentDocument, err = NewOverlay(rootStep, source).Do(ctx, currentDocument)
127147
if err != nil {
128148
return "", nil, err
129149
}
130150
}
131151

132152
if len(source.Transformations) > 0 && !w.FrozenWorkflowLock {
133-
w.OnSourceResult(sourceRes, "Transforming")
153+
w.OnSourceResult(sourceRes, SourceStepTransform)
134154
currentDocument, err = NewTransform(rootStep, source).Do(ctx, currentDocument)
135155
if err != nil {
136156
return "", nil, err
137157
}
138158
}
139159

140160
if err := writeToOutputLocation(ctx, currentDocument, outputLocation); err != nil {
141-
return "", nil, fmt.Errorf("failed to write to output location: %w", err)
161+
return "", nil, fmt.Errorf("failed to write to output location: %w %s?", err, outputLocation)
142162
}
143163
currentDocument = outputLocation
144164
sourceRes.OutputPath = currentDocument
145165

146166
if !w.SkipLinting {
147-
w.OnSourceResult(sourceRes, "Linting")
167+
w.OnSourceResult(sourceRes, SourceStepLint)
148168
sourceRes.LintResult, err = w.validateDocument(ctx, rootStep, sourceID, currentDocument, rulesetToUse, w.ProjectDir, targetLanguage)
149169
if err != nil {
150170
return "", sourceRes, &LintingError{Err: err, Document: currentDocument}
@@ -159,7 +179,7 @@ func (w *Workflow) RunSource(ctx context.Context, parentStep *workflowTracking.W
159179
}
160180
step.Succeed()
161181

162-
w.OnSourceResult(sourceRes, "Uploading spec")
182+
w.OnSourceResult(sourceRes, SourceStepUpload)
163183

164184
if !w.SkipSnapshot {
165185
err = w.snapshotSource(ctx, rootStep, sourceID, source, currentDocument)

internal/run/target.go

Lines changed: 56 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -147,35 +147,57 @@ func (w *Workflow) runTarget(ctx context.Context, target string) (*SourceResult,
147147
}
148148
}
149149

150-
genStep := rootStep.NewSubstep(fmt.Sprintf("Generating %s SDK", utils.CapitalizeFirst(t.Target)))
151-
152150
logListener := make(chan log.Msg)
153151
logger := log.From(ctx).WithListener(logListener)
154152
ctx = log.With(ctx, logger)
153+
154+
if w.StreamableGeneration != nil && w.Debug {
155+
w.StreamableGeneration.LogListener = logListener
156+
}
157+
158+
if w.CancellableGeneration != nil {
159+
cancelCtx, cancelFunc := context.WithCancel(ctx)
160+
w.CancellableGeneration.CancellationMutex.Lock()
161+
w.CancellableGeneration.CancellableContext = cancelCtx
162+
w.CancellableGeneration.CancelGeneration = cancelFunc
163+
w.CancellableGeneration.CancellationMutex.Unlock()
164+
165+
defer func() {
166+
w.CancellableGeneration.CancellationMutex.Lock()
167+
w.CancellableGeneration.CancelGeneration = nil
168+
w.CancellableGeneration.CancellableContext = nil
169+
w.CancellableGeneration.CancellationMutex.Unlock()
170+
cancelFunc() // Ensure context is cleaned up
171+
}()
172+
}
173+
174+
genStep := rootStep.NewSubstep(fmt.Sprintf("Generating %s SDK", utils.CapitalizeFirst(t.Target)))
155175
go genStep.ListenForSubsteps(logListener)
156176

157177
generationAccess, err := sdkgen.Generate(
158178
ctx,
159179
sdkgen.GenerateOptions{
160-
CustomerID: config.GetCustomerID(),
161-
WorkspaceID: config.GetWorkspaceID(),
162-
Language: t.Target,
163-
SchemaPath: sourcePath,
164-
Header: "",
165-
Token: "",
166-
OutDir: outDir,
167-
CLIVersion: events.GetSpeakeasyVersionFromContext(ctx),
168-
InstallationURL: w.InstallationURLs[target],
169-
Debug: w.Debug,
170-
AutoYes: true,
171-
Published: published,
172-
OutputTests: false,
173-
Repo: w.Repo,
174-
RepoSubDir: w.RepoSubDirs[target],
175-
Verbose: w.Verbose,
176-
Compile: w.ShouldCompile,
177-
TargetName: target,
178-
SkipVersioning: w.SkipVersioning,
180+
CustomerID: config.GetCustomerID(),
181+
WorkspaceID: config.GetWorkspaceID(),
182+
Language: t.Target,
183+
SchemaPath: sourcePath,
184+
Header: "",
185+
Token: "",
186+
OutDir: outDir,
187+
CLIVersion: events.GetSpeakeasyVersionFromContext(ctx),
188+
InstallationURL: w.InstallationURLs[target],
189+
Debug: w.Debug,
190+
AutoYes: true,
191+
Published: published,
192+
OutputTests: false,
193+
Repo: w.Repo,
194+
RepoSubDir: w.RepoSubDirs[target],
195+
Verbose: w.Verbose,
196+
Compile: w.ShouldCompile,
197+
TargetName: target,
198+
SkipVersioning: w.SkipVersioning,
199+
CancellableGeneration: w.CancellableGeneration,
200+
StreamableGeneration: w.StreamableGeneration,
179201
},
180202
)
181203

@@ -388,3 +410,16 @@ func (w *Workflow) printTargetSuccessMessage(ctx context.Context) {
388410
msg := fmt.Sprintf("%s\n%s\n", styles.Success.Render(heading), strings.Join(additionalLines, "\n"))
389411
log.From(ctx).Println(msg)
390412
}
413+
414+
func (w *Workflow) CancelGeneration() error {
415+
if w.CancellableGeneration != nil {
416+
w.CancellableGeneration.CancellationMutex.Lock()
417+
defer w.CancellableGeneration.CancellationMutex.Unlock()
418+
if w.CancellableGeneration.CancelGeneration != nil {
419+
w.CancellableGeneration.CancelGeneration()
420+
return nil
421+
}
422+
}
423+
424+
return fmt.Errorf("Generation is not cancellable")
425+
}

internal/run/workflow.go

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ package run
33
import (
44
"context"
55
"fmt"
6+
"sync"
67
"time"
78

89
"github.com/speakeasy-api/speakeasy/registry"
910

11+
"github.com/speakeasy-api/openapi-generation/v2/pkg/generate"
1012
"github.com/speakeasy-api/sdk-gen-config/workflow"
1113
"github.com/speakeasy-api/speakeasy-core/events"
1214
"github.com/speakeasy-api/speakeasy/internal/log"
@@ -57,10 +59,14 @@ type Workflow struct {
5759
computedChanges map[string]bool
5860
SourceResults map[string]*SourceResult
5961
TargetResults map[string]*TargetResult
60-
OnSourceResult func(*SourceResult, string)
62+
OnSourceResult SourceResultCallback
6163
Duration time.Duration
6264
criticalWarns []string
6365
Error error
66+
67+
// Studio
68+
CancellableGeneration *sdkgen.CancellableGeneration
69+
StreamableGeneration *sdkgen.StreamableGeneration
6470
}
6571

6672
type Opt func(w *Workflow)
@@ -100,7 +106,7 @@ func NewWorkflow(
100106
ForceGeneration: false,
101107
SourceResults: make(map[string]*SourceResult),
102108
TargetResults: make(map[string]*TargetResult),
103-
OnSourceResult: func(*SourceResult, string) {},
109+
OnSourceResult: func(*SourceResult, SourceStepID) error { return nil },
104110
computedChanges: make(map[string]bool),
105111
lockfile: lockfile,
106112
lockfileOld: lockfileOld,
@@ -262,6 +268,42 @@ func WithRegistryTags(registryTags []string) Opt {
262268
}
263269
}
264270

271+
func WithSourceUpdates(onSourceResult SourceResultCallback) Opt {
272+
if onSourceResult != nil {
273+
return func(w *Workflow) {
274+
w.OnSourceResult = onSourceResult
275+
}
276+
}
277+
278+
return func(w *Workflow) {
279+
w.OnSourceResult = func(sourceRes *SourceResult, sourceStep SourceStepID) error { return nil }
280+
}
281+
}
282+
283+
func WithCancellableGeneration(cancellable bool) Opt {
284+
return func(w *Workflow) {
285+
if cancellable {
286+
w.CancellableGeneration = &sdkgen.CancellableGeneration{
287+
CancellationMutex: sync.Mutex{},
288+
// CancelCtx and CancelFunc fields depend on the runTarget context
289+
// and will be set right before generation starts.
290+
}
291+
} else {
292+
w.CancellableGeneration = nil
293+
}
294+
}
295+
}
296+
297+
func WithStreamableGeneration(onProgressUpdate func(generate.ProgressUpdate), genSteps, fileStatus bool) Opt {
298+
return func(w *Workflow) {
299+
w.StreamableGeneration = &sdkgen.StreamableGeneration{
300+
OnProgressUpdate: onProgressUpdate,
301+
GenSteps: genSteps,
302+
FileStatus: fileStatus,
303+
}
304+
}
305+
}
306+
265307
func (w *Workflow) CountDiagnostics() int {
266308
count := 0
267309
for _, sourceResult := range w.SourceResults {

internal/sdkgen/sdkgen.go

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"os"
88
"path/filepath"
99
"strings"
10+
"sync"
1011

1112
"github.com/charmbracelet/lipgloss"
1213
"github.com/speakeasy-api/speakeasy-core/auth"
@@ -33,6 +34,19 @@ type GenerationAccess struct {
3334
Level *shared.Level
3435
}
3536

37+
type CancellableGeneration struct {
38+
CancellationMutex sync.Mutex // protects both CancellableContext and CancelGeneration (exposed by w.CancelGeneration())
39+
CancellableContext context.Context // the context that can be cancelled to stop generation
40+
CancelGeneration context.CancelFunc // the function to call to cancel generation
41+
}
42+
43+
type StreamableGeneration struct {
44+
OnProgressUpdate func(generate.ProgressUpdate) // the callback function called on each progress update
45+
GenSteps bool // whether to receive an update before each generation step starts
46+
FileStatus bool // whether to receive updates on each file status change
47+
LogListener chan log.Msg // the channel to listen for log messages (Debug only)
48+
}
49+
3650
type GenerateOptions struct {
3751
CustomerID string
3852
WorkspaceID string
@@ -53,6 +67,9 @@ type GenerateOptions struct {
5367
Compile bool
5468
TargetName string
5569
SkipVersioning bool
70+
71+
CancellableGeneration *CancellableGeneration
72+
StreamableGeneration *StreamableGeneration
5673
}
5774

5875
func Generate(ctx context.Context, opts GenerateOptions) (*GenerationAccess, error) {
@@ -146,6 +163,18 @@ func Generate(ctx context.Context, opts GenerateOptions) (*GenerationAccess, err
146163
generatorOpts = append(generatorOpts, generate.WithSkipVersioning(opts.SkipVersioning))
147164
}
148165

166+
if opts.StreamableGeneration != nil {
167+
generatorOpts = append(
168+
generatorOpts,
169+
generate.WithProgressUpdates(
170+
opts.TargetName,
171+
opts.StreamableGeneration.OnProgressUpdate,
172+
opts.StreamableGeneration.GenSteps,
173+
opts.StreamableGeneration.FileStatus,
174+
),
175+
)
176+
}
177+
149178
g, err := generate.New(generatorOpts...)
150179
if err != nil {
151180
return &GenerationAccess{
@@ -157,13 +186,28 @@ func Generate(ctx context.Context, opts GenerateOptions) (*GenerationAccess, err
157186

158187
err = events.Telemetry(ctx, shared.InteractionTypeTargetGenerate, func(ctx context.Context, event *shared.CliEvent) error {
159188
event.GenerateTargetName = &opts.TargetName
160-
if errs := g.Generate(ctx, schema, opts.SchemaPath, opts.Language, opts.OutDir, isRemote, opts.Compile); len(errs) > 0 {
189+
190+
var errs []error
191+
if opts.CancellableGeneration != nil && opts.CancellableGeneration.CancellableContext != nil {
192+
cancelCtx := opts.CancellableGeneration.CancellableContext
193+
194+
var cancelled bool
195+
cancelled, errs = g.GenerateWithCancel(cancelCtx, schema, opts.SchemaPath, opts.Language, opts.OutDir, isRemote, opts.Compile)
196+
if cancelled {
197+
return fmt.Errorf("Generation was aborted for %s ✖", opts.Language)
198+
}
199+
} else {
200+
errs = g.Generate(ctx, schema, opts.SchemaPath, opts.Language, opts.OutDir, isRemote, opts.Compile)
201+
}
202+
203+
if len(errs) > 0 {
161204
for _, err := range errs {
162205
logger.Error("", zap.Error(err))
163206
}
164207

165208
return fmt.Errorf("failed to generate SDKs for %s ✖", opts.Language)
166209
}
210+
167211
return nil
168212
})
169213
if err != nil {

0 commit comments

Comments
 (0)