-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathcmd_generate.go
More file actions
399 lines (356 loc) · 15 KB
/
cmd_generate.go
File metadata and controls
399 lines (356 loc) · 15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
// ===========================================================================
// TEXT GENERATION CLI - Interactive and Batch Inference
// ===========================================================================
//
// This file implements the command-line interface for text generation using
// trained GPT models. It demonstrates the full inference pipeline: loading
// models, tokenizing text, generating tokens, and decoding back to text.
//
// INTENTION:
// Provide two modes of interaction:
// 1. Single-shot generation: Quick text completion for scripts/automation
// 2. Interactive REPL: Experimentation with sampling parameters in real-time
//
// WHY TWO MODES?
// - Single-shot is ideal for production use, scripting, benchmarking
// - Interactive mode is crucial for development: you need to experiment with
// temperature, top-k, top-p to find settings that work for your model/task
// - Interactive mode lets you iterate quickly without reloading the model
// (model loading can take seconds, experimentation requires dozens of tries)
//
// KEY CONCEPTS:
//
// 1. SAMPLING PARAMETERS (Control output randomness/quality)
// - Temperature: Controls randomness of predictions
// * 0.0 = greedy (always pick most likely token) - deterministic but boring
// * 0.7 = focused (good for factual/code generation)
// * 1.0 = neutral (match training distribution)
// * 1.5+ = creative (diverse but potentially incoherent)
// - Top-k: Only sample from k most likely tokens
// * Prevents sampling very unlikely tokens (reduces nonsense)
// * 40 is a good default (the value used for top-k sampling in the GPT-2 paper)
// - Top-p (nucleus): Sample from smallest set of tokens with cumulative probability p
// * More adaptive than top-k (adjusts set size based on confidence)
// * 0.9 is a good default (keeps quality high while allowing diversity)
//
// 2. BACKEND SELECTION (See backend.go for details)
// Why configurable? Because performance varies drastically by hardware:
// - M4 Max: Accelerate framework gives 10-20x speedup over naive
// - Linux ARM: SVE or OpenBLAS are best options
// - NVIDIA: CUDA backend required for GPU acceleration
// The "auto" backend will eventually auto-detect, but manual selection
// lets you benchmark different approaches on the same hardware.
//
// 3. MODEL/TOKENIZER LOADING
// Why separate files?
// - Models are large (MBs-GBs), tokenizers are small (KBs)
// - Different tokenizers can be used with same model architecture
// - Allows sharing tokenizers across models (save disk space)
//
// USAGE EXAMPLES:
//
// Single prompt:
// go run . generate -model=model.bin -tokenizer=tok.bin -prompt="Hello"
//
// Interactive experimentation:
// go run . generate -model=model.bin -tokenizer=tok.bin -interactive
// > Hello world
// > /temp 1.2 # Make it more creative
// > /tokens 50 # Generate longer text
// > Hello world # Try again with new settings
//
// Production use with specific backend:
// go run . generate -model=model.bin -tokenizer=tok.bin \
// -prompt="func sum(a, b int)" -backend=accelerate -temperature=0.3
//
// ===========================================================================
package main
import (
	"bufio"
	"flag"
	"fmt"
	"math"
	"os"
	"strconv"
	"strings"
)
// RunGenerateCommand implements the "generate" subcommand of the CLI.
//
// It performs the following steps:
//  1. Parses the command-line flags in args.
//  2. Validates them - including that either -prompt or -interactive was
//     given - before doing any expensive work.
//  3. Loads the model and the tokenizer.
//  4. Applies the requested compute backend, warning (not failing) when the
//     backend cannot be set up.
//  5. Dispatches to the interactive REPL or to single-shot generation.
//
// The model is loaded once and then reused for every generation of the
// session, which is why interactive mode avoids the per-run loading cost.
//
// It returns an error when a required flag is missing, when the model or the
// tokenizer cannot be loaded, or when the selected mode returns an error.
// Flag syntax errors terminate the process instead, because the FlagSet is
// created with flag.ExitOnError.
func RunGenerateCommand(args []string) error {
	fs := flag.NewFlagSet("generate", flag.ExitOnError)

	// Model and tokenizer paths (both required).
	modelPath := fs.String("model", "", "Path to saved model file (required)")
	tokenizerPath := fs.String("tokenizer", "", "Path to saved tokenizer file (required)")

	// Generation mode and length:
	//   - prompt:      single-shot generation (scripting, automation, benchmarking)
	//   - interactive: REPL for experimenting with prompts and sampling settings
	prompt := fs.String("prompt", "", "Text prompt for generation")
	interactive := fs.Bool("interactive", false, "Interactive mode (REPL)")
	maxTokens := fs.Int("max-tokens", 100, "Maximum number of tokens to generate")

	// Sampling parameters (control output randomness/quality).
	temperature := fs.Float64("temperature", 0.8, "Temperature for sampling (0=greedy, higher=more random)")
	topK := fs.Int("top-k", 40, "Top-k sampling (0=disabled)")
	topP := fs.Float64("top-p", 0.9, "Top-p (nucleus) sampling (0=disabled)")

	// Backend selection. "auto" and "naive" need no setup and are skipped below.
	backend := fs.String("backend", "auto", "Compute backend: auto, naive, accelerate, metal, cuda, sve, openblas")

	if err := fs.Parse(args); err != nil {
		return err
	}

	// Validate everything that can be validated from the flags alone BEFORE
	// loading anything: model loading is the slow part, and the user should not
	// have to wait for it only to be told that the invocation was incomplete.
	if *modelPath == "" {
		return fmt.Errorf("--model is required")
	}
	if *tokenizerPath == "" {
		return fmt.Errorf("--tokenizer is required")
	}
	if !*interactive && *prompt == "" {
		return fmt.Errorf("either --prompt or --interactive is required")
	}

	// Load the model once; it is reused for every generation afterwards.
	fmt.Printf("Loading model from %s...\n", *modelPath)
	model, err := LoadGPT(*modelPath)
	if err != nil {
		return fmt.Errorf("failed to load model: %w", err)
	}
	fmt.Printf("✓ Model loaded (vocab=%d, dim=%d, layers=%d)\n",
		model.config.VocabSize, model.config.EmbedDim, model.config.NumLayers)

	// Apply an explicitly requested backend. A setup failure is deliberately
	// downgraded to a warning so that generation still works when the requested
	// acceleration is unavailable.
	if *backend != "auto" && *backend != "naive" {
		if err := setupBackend(model, *backend); err != nil {
			fmt.Fprintf(os.Stderr, "Warning: Failed to setup backend '%s': %v\n", *backend, err)
			fmt.Fprintf(os.Stderr, "Falling back to naive implementation\n")
		}
	}

	// Load the tokenizer, which converts between text and tokens and is needed
	// for both the prompt and the generated output.
	fmt.Printf("Loading tokenizer from %s...\n", *tokenizerPath)
	tokenizer := NewTokenizer()
	if err := tokenizer.Load(*tokenizerPath); err != nil {
		return fmt.Errorf("failed to load tokenizer: %w", err)
	}
	fmt.Printf("✓ Tokenizer loaded (vocab size=%d)\n", tokenizer.VocabSize())

	// Initial sampling configuration. In interactive mode the same struct is
	// handed to the REPL, whose slash commands may modify it.
	samplingConfig := &SampleConfig{
		Temperature: *temperature,
		TopK:        *topK,
		TopP:        *topP,
	}

	fmt.Println()
	fmt.Printf("Generation settings:\n")
	fmt.Printf(" Max tokens: %d\n", *maxTokens)
	fmt.Printf(" Temperature: %.2f\n", samplingConfig.Temperature)
	fmt.Printf(" Top-k: %d\n", samplingConfig.TopK)
	fmt.Printf(" Top-p: %.2f\n", samplingConfig.TopP)
	fmt.Println()

	if *interactive {
		return runInteractive(model, tokenizer, *maxTokens, samplingConfig)
	}
	// Single prompt mode; the prompt was verified to be non-empty above.
	return generateText(model, tokenizer, *prompt, *maxTokens, samplingConfig)
}
// generateText runs one complete generation for promptText and prints the
// result to standard output.
//
// The pipeline has three stages:
//  1. Encode:   tokenizer.Encode turns the prompt text into tokens.
//  2. Generate: model.GenerateWithSampling produces the output sequence from
//     the prompt tokens, limited by maxTokens and steered by config.
//  3. Decode:   tokenizer.Decode turns the output sequence back into text.
//
// The encoded prompt tokens are echoed on purpose: when the output looks wrong
// they help to tell a tokenization problem from a generation problem.
//
// It returns an error only when the prompt encodes to zero tokens; in that
// case the model is never invoked.
func generateText(model *GPT, tokenizer *Tokenizer, promptText string, maxTokens int, config *SampleConfig) error {
	fmt.Printf("Prompt: %s\n", promptText)
	fmt.Println()

	// Stage 1: text -> tokens. An empty encoding leaves nothing to condition on.
	inputIDs := tokenizer.Encode(promptText)
	inputLen := len(inputIDs)
	if inputLen < 1 {
		return fmt.Errorf("prompt encoding resulted in zero tokens")
	}
	fmt.Printf("Encoded to %d tokens: %v\n", inputLen, inputIDs)
	fmt.Println()

	// Stage 2: sample the continuation with the current sampling settings.
	fmt.Println("Generating...")
	outputIDs := model.GenerateWithSampling(inputIDs, maxTokens, config)

	// Stage 3: tokens -> text.
	// NOTE(review): the reported "total" presumably counts prompt plus newly
	// generated tokens - confirm against GenerateWithSampling.
	outputText := tokenizer.Decode(outputIDs)

	fmt.Println()
	fmt.Println("=== Generated Text ===")
	fmt.Println(outputText)
	fmt.Println()
	fmt.Printf("Generated %d total tokens\n", len(outputIDs))
	return nil
}
// runInteractive runs the read-eval-print loop of interactive mode.
//
// Every line read from standard input is trimmed and then handled as one of:
//   - empty line: ignored
//   - "quit" or "exit": prints a farewell and returns nil
//   - a line starting with "/": passed to handleCommand, which may change
//     config or the local maxTokens value
//   - anything else: used as a prompt for generateText
//
// Errors from handleCommand and generateText are reported on standard error
// and do not end the session. Because config is a pointer and maxTokens is
// passed to handleCommand by address, changed settings are picked up by the
// next generation.
//
// The loop also ends when standard input is exhausted. The function then
// returns nil, or a wrapped error when the scanner reports a read failure.
func runInteractive(model *GPT, tokenizer *Tokenizer, maxTokens int, config *SampleConfig) error {
	banner := []string{
		"=== Interactive Mode ===",
		"Enter prompts to generate text. Type 'quit' or 'exit' to stop.",
		"Commands:",
		" /temp <value> Set temperature (e.g., /temp 0.8)",
		" /topk <value> Set top-k (e.g., /topk 40)",
		" /topp <value> Set top-p (e.g., /topp 0.9)",
		" /tokens <value> Set max tokens (e.g., /tokens 50)",
		" /config Show current settings",
	}
	for _, text := range banner {
		fmt.Println(text)
	}
	fmt.Println()

	input := bufio.NewScanner(os.Stdin)

	// nextLine shows the prompt marker and reports whether another line is available.
	nextLine := func() bool {
		fmt.Print("> ")
		return input.Scan()
	}

	for nextLine() {
		switch entry := strings.TrimSpace(input.Text()); {
		case entry == "":
			// Nothing typed; show the prompt marker again.

		case entry == "quit" || entry == "exit":
			fmt.Println("Goodbye!")
			return nil

		case strings.HasPrefix(entry, "/"):
			// Slash command: adjusts settings without producing any text.
			if err := handleCommand(entry, config, &maxTokens); err != nil {
				fmt.Fprintf(os.Stderr, "Error: %v\n", err)
			}

		default:
			// Regular prompt: generate with whatever settings are current.
			if err := generateText(model, tokenizer, entry, maxTokens, config); err != nil {
				fmt.Fprintf(os.Stderr, "Error: %v\n", err)
			}
			fmt.Println()
		}
	}

	// Scan returned false: either end of input or a read failure.
	if err := input.Err(); err != nil {
		return fmt.Errorf("error reading input: %w", err)
	}
	return nil
}
// handleCommand parses and executes one interactive-mode slash command.
//
// cmd is the full input line (for example "/temp 0.7"). Supported commands:
//
//	/temp <value>    set config.Temperature (finite, >= 0)
//	/topk <value>    set config.TopK        (integer, >= 0)
//	/topp <value>    set config.TopP        (finite, between 0 and 1)
//	/tokens <value>  set *maxTokens         (integer, >= 0)
//	/config          print the current settings
//
// config and maxTokens are modified in place so that the new values are seen
// by the caller's subsequent generations.
//
// Values are parsed strictly: the whole argument must be a valid number, so
// input such as "1.5abc" or "3.7" for an integer setting is rejected instead
// of being silently truncated. On any error nothing is modified.
//
// It returns nil for a blank command line, and an error for a missing
// argument, an invalid or out-of-range value, or an unknown command.
func handleCommand(cmd string, config *SampleConfig, maxTokens *int) error {
	parts := strings.Fields(cmd)
	if len(parts) == 0 {
		return nil
	}
	name, args := parts[0], parts[1:]

	switch name {
	case "/temp":
		if len(args) == 0 {
			return fmt.Errorf("usage: /temp <value>")
		}
		val, err := parseSettingFloat(args[0], "temperature")
		if err != nil {
			return err
		}
		config.Temperature = val
		fmt.Printf("Temperature set to %.2f\n", val)

	case "/topk":
		if len(args) == 0 {
			return fmt.Errorf("usage: /topk <value>")
		}
		val, err := parseSettingInt(args[0], "top-k")
		if err != nil {
			return err
		}
		config.TopK = val
		fmt.Printf("Top-k set to %d\n", val)

	case "/topp":
		if len(args) == 0 {
			return fmt.Errorf("usage: /topp <value>")
		}
		val, err := parseSettingFloat(args[0], "top-p")
		if err != nil {
			return err
		}
		// Top-p is a cumulative probability, so values above 1 are meaningless.
		if val > 1 {
			return fmt.Errorf("invalid top-p value: %q must be between 0 and 1", args[0])
		}
		config.TopP = val
		fmt.Printf("Top-p set to %.2f\n", val)

	case "/tokens":
		if len(args) == 0 {
			return fmt.Errorf("usage: /tokens <value>")
		}
		val, err := parseSettingInt(args[0], "max tokens")
		if err != nil {
			return err
		}
		*maxTokens = val
		fmt.Printf("Max tokens set to %d\n", val)

	case "/config":
		fmt.Printf("Current settings:\n")
		fmt.Printf(" Temperature: %.2f\n", config.Temperature)
		fmt.Printf(" Top-k: %d\n", config.TopK)
		fmt.Printf(" Top-p: %.2f\n", config.TopP)
		fmt.Printf(" Max tokens: %d\n", *maxTokens)

	default:
		return fmt.Errorf("unknown command: %s", name)
	}
	return nil
}

// parseSettingFloat strictly parses raw as a finite, non-negative float64.
// label names the setting in the returned error message.
func parseSettingFloat(raw, label string) (float64, error) {
	val, err := strconv.ParseFloat(raw, 64)
	if err != nil {
		return 0, fmt.Errorf("invalid %s value: %w", label, err)
	}
	// ParseFloat accepts the spellings "NaN" and "Inf"; neither is a usable setting.
	if math.IsNaN(val) || math.IsInf(val, 0) || val < 0 {
		return 0, fmt.Errorf("invalid %s value: %q must be a finite, non-negative number", label, raw)
	}
	return val, nil
}

// parseSettingInt strictly parses raw as a non-negative decimal integer.
// label names the setting in the returned error message.
func parseSettingInt(raw, label string) (int, error) {
	val, err := strconv.Atoi(raw)
	if err != nil {
		return 0, fmt.Errorf("invalid %s value: %w", label, err)
	}
	if val < 0 {
		return 0, fmt.Errorf("invalid %s value: %d must not be negative", label, val)
	}
	return val, nil
}
// setupBackend checks the requested backend name for this pure Go build.
//
// Only CPU execution is available here: an empty name or "cpu" is accepted
// and confirmed with a message on standard output; every other name yields
// an error. The model argument is not used and is kept so that the signature
// stays compatible with callers.
//
// For production inference at scale, consider:
//   - Quantization (int8/int4) for smaller models
//   - Batch processing for higher throughput
//   - External frameworks with GPU support (ONNX Runtime, TensorFlow Lite)
func setupBackend(model *GPT, backendName string) error {
	switch backendName {
	case "", "cpu":
		// Nothing to configure; just confirm the choice.
		fmt.Printf("✓ Using pure Go CPU backend\n")
		return nil
	}
	return fmt.Errorf("backend %q not available - only 'cpu' supported in pure Go build", backendName)
}