Skip to content

Commit e4bf4c7

Browse files
hyperpolymathclaude
andcommitted
feat(gnn): close the training feedback loop — StatisticsTracker → Julia
Turns ECHIDNA into an online learning system by wiring proof outcomes from the Rust StatisticsTracker into the Julia GNN ML server so that accumulated evidence influences future premise ranking. Three parts: 1. StatisticsTracker::export_records() — exports all (prover, domain) stats as StatsSummaryRecord (new serialisable type) so the background sync task can push snapshots without holding the write lock. 2. Background GNN training sync (server.rs) — tokio::spawn task wakes every 60 s, checks if >= 10 new outcomes have accumulated since the last push, then POSTs a snapshot to Julia's new /training/update endpoint. Fire-and-forget: errors are logged at debug level. StatisticsTracker is now shared between MetaController (Bayesian routing) and the sync task via Arc<RwLock<StatisticsTracker>>. 3. Julia /training/update endpoint (gnn_endpoint.jl) — accepts [{ prover, domain, success_rate, … }] records, merges them into PROVER_DOMAIN_WEIGHTS, and applies the per-domain confidence as a [0.5, 1.0] score multiplier in rank_with_gnn() when the /gnn/rank caller provides domain_hints. GnnRankRequest gains a domain_hints field (default empty, so all existing callers are unaffected). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent cf94b49 commit e4bf4c7

6 files changed

Lines changed: 251 additions & 15 deletions

File tree

src/julia/api/gnn_endpoint.jl

Lines changed: 112 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,14 @@ using LinearAlgebra
3131
const GNN_MODEL = Ref{Any}(nothing)
3232
const GNN_VOCAB = Ref{Any}(nothing)
3333

34+
# Per-(prover, domain) success-rate weights pushed from Rust via
35+
# POST /training/update. Format: weights[prover_name][domain] = success_rate.
36+
# Used to modulate premise scores in rank_with_gnn when domain_hints present.
37+
const PROVER_DOMAIN_WEIGHTS = Ref{Dict{String,Dict{String,Float64}}}(Dict())
38+
39+
# Running total of training records received since server start.
40+
const TOTAL_TRAINING_RECORDS = Ref{Int}(0)
41+
3442
"""
3543
load_gnn_model(models_dir::String)
3644
@@ -94,23 +102,110 @@ function parse_proof_graph(body)
94102
end
95103

96104
"""
97-
rank_with_gnn(g, node_features, goal_idx, premise_indices, config)
105+
rank_with_gnn(g, node_features, goal_idx, premise_indices, config, domain_hints)
98106
99107
Run GNN message passing + cross-attention scoring on the parsed graph.
100108
If a trained model is available, use it. Otherwise, use cosine similarity
101109
between the goal and premise node features as a fallback.
110+
111+
When `domain_hints` is non-empty and `PROVER_DOMAIN_WEIGHTS` has been
112+
populated by prior `/training/update` calls, premise scores are modulated
113+
by the aggregate domain confidence from accumulated proof outcomes.
102114
"""
103-
function rank_with_gnn(g, node_features, goal_idx, premise_indices, config)
115+
function rank_with_gnn(g, node_features, goal_idx, premise_indices, config, domain_hints=String[])
104116
model = GNN_MODEL[]
105117

106118
if model !== nothing
107-
# Use trained model for ranking
108-
# (Delegate to neural_solver.jl PremiseRanker when available)
109-
return rank_with_trained_model(model, g, node_features, goal_idx, premise_indices)
119+
scores, indices = rank_with_trained_model(model, g, node_features, goal_idx, premise_indices)
120+
else
121+
# Fallback: cosine similarity between goal and premise features
122+
scores, indices = rank_with_cosine(node_features, goal_idx, premise_indices)
123+
end
124+
125+
# Apply accumulated training weights when the caller provided domain hints.
126+
# This is a no-op until /training/update has been called at least once
127+
# and the rank request includes non-empty domain_hints.
128+
if !isempty(domain_hints) && !isempty(PROVER_DOMAIN_WEIGHTS[])
129+
scores = apply_domain_weights(scores, domain_hints)
110130
end
111131

112-
# Fallback: cosine similarity between goal and premise features
113-
return rank_with_cosine(node_features, goal_idx, premise_indices)
132+
return (scores, indices)
133+
end
134+
135+
"""
136+
apply_domain_weights(scores, domain_hints)
137+
138+
Scale premise scores by the mean success rate across all provers for the
139+
requested domain aspects. Uses a `[0.5, 1.0]` range so that even low-
140+
confidence domains retain half the base score rather than collapsing to zero.
141+
142+
When no training evidence exists for the requested domains, scores are
143+
returned unchanged.
144+
"""
145+
function apply_domain_weights(scores::Vector{Float32}, domain_hints::Vector{String})
146+
weights = PROVER_DOMAIN_WEIGHTS[]
147+
domain_rates = Float64[]
148+
for domain in domain_hints
149+
for (_, prover_weights) in weights
150+
if haskey(prover_weights, domain)
151+
push!(domain_rates, prover_weights[domain])
152+
end
153+
end
154+
end
155+
isempty(domain_rates) && return scores
156+
mean_confidence = Statistics.mean(domain_rates)
157+
scale = Float32(0.5 + 0.5 * mean_confidence)
158+
return scores .* scale
159+
end
160+
161+
"""
162+
handle_training_update(req::HTTP.Request)
163+
164+
POST /training/update — Receive proof-outcome statistics from the Rust server
165+
and update per-(prover, domain) success-rate weights used to modulate premise
166+
ranking scores.
167+
168+
Payload: `{ "records": [{ "prover", "domain", "attempts", "successes",
169+
"timeouts", "failures", "mean_time_ms", "success_rate" }] }`
170+
171+
The Rust `StatisticsTracker` is authoritative; Julia simply mirrors it so that
172+
the GNN ranking layer can incorporate proof-outcome evidence without a round-trip.
173+
"""
174+
function handle_training_update(req::HTTP.Request)
175+
try
176+
body = JSON3.read(String(req.body))
177+
records = get(body, :records, [])
178+
179+
weights = PROVER_DOMAIN_WEIGHTS[]
180+
n = 0
181+
for rec in records
182+
prover = string(get(rec, :prover, "Unknown"))
183+
domain = string(get(rec, :domain, "unspecified"))
184+
rate = Float64(get(rec, :success_rate, 0.0))
185+
if !haskey(weights, prover)
186+
weights[prover] = Dict{String,Float64}()
187+
end
188+
weights[prover][domain] = rate
189+
n += 1
190+
end
191+
PROVER_DOMAIN_WEIGHTS[] = weights
192+
TOTAL_TRAINING_RECORDS[] += n
193+
194+
@info "Training update: $n records (total=$(TOTAL_TRAINING_RECORDS[]))"
195+
196+
return HTTP.Response(200, JSON3.write(Dict(
197+
"status" => "ok",
198+
"records_received" => n,
199+
"total_records" => TOTAL_TRAINING_RECORDS[],
200+
"weights_updated" => n > 0
201+
)))
202+
catch e
203+
@error "Training update failed" exception=(e, catch_backtrace())
204+
return HTTP.Response(500, JSON3.write(Dict(
205+
"status" => "error",
206+
"error" => string(e)
207+
)))
208+
end
114209
end
115210

116211
"""
@@ -190,11 +285,14 @@ function handle_gnn_rank(req::HTTP.Request)
190285
top_k = get(body, :top_k, 20)
191286
min_score = get(body, :min_score, 0.05)
192287
include_embeddings = get(body, :include_embeddings, false)
288+
domain_hints = String.(get(body, :domain_hints, String[]))
193289

194-
# Run GNN ranking
290+
# Run GNN ranking; domain_hints allow weight-guided score modulation
291+
# when /training/update has populated PROVER_DOMAIN_WEIGHTS.
195292
(scores, indices) = rank_with_gnn(
196293
g, node_features, goal_idx, premise_indices,
197-
get(body, :config, nothing)
294+
get(body, :config, nothing),
295+
domain_hints
198296
)
199297

200298
# Sort by score (descending)
@@ -268,15 +366,17 @@ Call this from the main api_server.jl to enable GNN functionality.
268366
"""
269367
function register_gnn_routes!(existing_handler)
270368
@info "Registering GNN endpoints:"
271-
@info " POST /gnn/rank — Rank premises via GNN"
272-
@info " GET /gnn/health — GNN model status"
369+
@info " POST /gnn/rank — Rank premises via GNN"
370+
@info " GET /gnn/health — GNN model status"
371+
@info " POST /training/update — Receive proof-outcome stats from Rust"
273372

274-
# Return a combined handler that dispatches to GNN routes first
275373
function combined_handler(req::HTTP.Request)
276374
if req.target == "/gnn/rank" && req.method == "POST"
277375
return handle_gnn_rank(req)
278376
elseif req.target == "/gnn/health"
279377
return handle_gnn_health(req)
378+
elseif req.target == "/training/update" && req.method == "POST"
379+
return handle_training_update(req)
280380
else
281381
return existing_handler(req)
282382
end

src/rust/gnn/client.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@ struct GnnRankRequest {
7272
include_embeddings: bool,
7373
/// GNN configuration hints for the server
7474
config: GnnServerHints,
75+
/// Optional domain-aspect tags for the goal (e.g. `["arithmetic.factorisation"]`).
76+
/// When non-empty, Julia uses accumulated training weights for these domains
77+
/// to modulate premise scores. Empty for backwards-compatible callers.
78+
#[serde(default)]
79+
domain_hints: Vec<String>,
7580
}
7681

7782
/// Serialised proof graph for the Julia server.
@@ -355,6 +360,7 @@ impl GnnClient {
355360
num_gnn_layers: self.config.num_gnn_layers,
356361
use_attention: self.config.use_attention,
357362
},
363+
domain_hints: vec![],
358364
}
359365
}
360366

src/rust/server.rs

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use echidna::agent::meta_controller::{MetaController, Plan};
1818
use echidna::agent::AgenticGoal;
1919
use echidna::core::{Goal, ProofState, Tactic, TacticResult, Term};
2020
use echidna::dispatch::ProverDispatcher;
21+
use echidna::verification::StatisticsTracker;
2122
use echidna::{ProverBackend, ProverConfig, ProverKind};
2223
use reqwest::Client;
2324
use serde::{Deserialize, Serialize};
@@ -28,7 +29,7 @@ use std::sync::Arc;
2829
use std::time::Duration;
2930
use tokio::sync::{Mutex, RwLock};
3031
use tower_http::cors::CorsLayer;
31-
use tracing::{info, instrument};
32+
use tracing::{debug, info, instrument};
3233

3334
/// Application state shared across handlers
3435
#[derive(Clone)]
@@ -44,6 +45,9 @@ struct AppState {
4445
/// state accumulates across the server's lifetime. All handlers that
4546
/// perform goal-aware dispatch share this single instance.
4647
meta_controller: Arc<MetaController>,
48+
/// Shared proof-outcome statistics used by MetaController for Bayesian
49+
/// routing and exported to Julia for GNN online learning.
50+
stats: Arc<RwLock<StatisticsTracker>>,
4751
}
4852

4953
/// A proof session
@@ -81,15 +85,59 @@ pub async fn start_server(port: u16, host: String, enable_cors: bool) -> Result<
8185
},
8286
}
8387

84-
let meta_controller = Arc::new(MetaController::new());
88+
// Shared stats tracker — MetaController uses it for Bayesian routing;
89+
// the background GNN sync task pushes snapshots to Julia for online learning.
90+
let stats = Arc::new(RwLock::new(StatisticsTracker::new()));
91+
let meta_controller = Arc::new(MetaController::new().with_stats(stats.clone()));
8592

8693
let state = AppState {
8794
sessions: Arc::new(RwLock::new(HashMap::new())),
8895
ml_client,
8996
ml_api_url,
9097
meta_controller,
98+
stats,
9199
};
92100

101+
// GNN training sync — background task that pushes accumulated proof-outcome
102+
// stats to Julia's /training/update endpoint every 60 seconds.
103+
// Only fires when >= 10 new outcomes have accumulated since the last push,
104+
// so idle servers incur no traffic. Errors are logged at debug level and
105+
// do not affect the main server loop (fire-and-forget).
106+
{
107+
let sync_stats = state.stats.clone();
108+
let sync_client = state.ml_client.clone();
109+
let sync_url = state.ml_api_url.clone();
110+
tokio::spawn(async move {
111+
let mut last_push_count: u64 = 0;
112+
let mut interval = tokio::time::interval(Duration::from_secs(60));
113+
interval.tick().await; // consume the immediate first tick; wait 60s before first push
114+
loop {
115+
interval.tick().await;
116+
let (total, records) = {
117+
let guard = sync_stats.read().await;
118+
(guard.total_attempts(), guard.export_records())
119+
};
120+
if total < last_push_count + 10 || records.is_empty() {
121+
continue;
122+
}
123+
last_push_count = total;
124+
let url = format!("{}/training/update", sync_url);
125+
let payload = json!({ "records": records });
126+
match sync_client.post(&url).json(&payload).send().await {
127+
Ok(resp) if resp.status().is_success() => {
128+
info!("GNN training sync: {} records pushed to Julia", records.len());
129+
},
130+
Ok(resp) => {
131+
debug!("GNN training sync: Julia returned {}", resp.status());
132+
},
133+
Err(e) => {
134+
debug!("GNN training sync: Julia unavailable ({})", e);
135+
},
136+
}
137+
}
138+
});
139+
}
140+
93141
// Build router
94142
let mut app = Router::new()
95143
// Groove discovery endpoint — returns capability manifest for service mesh.

src/rust/verification/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ pub use portfolio::{PortfolioConfig, PortfolioResult, PortfolioSolver};
3030
pub use proof::{
3131
theorem_identity, Proof, ProofStateRecord, ProofVersion, TacticApplication, TacticStatus,
3232
};
33-
pub use statistics::StatisticsTracker;
33+
pub use statistics::{StatsSummaryRecord, StatisticsTracker};

src/rust/verification/statistics.rs

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,27 @@ impl ProverDomainStats {
143143
}
144144
}
145145

146+
/// Serialisable snapshot of one (prover, domain) stats pair.
147+
///
148+
/// Exported by `StatisticsTracker::export_records` and pushed to the Julia
149+
/// ML server by the background training-sync task so that online weight
150+
/// updates incorporate accumulated proof evidence.
151+
#[derive(Debug, Clone, Serialize, Deserialize)]
152+
pub struct StatsSummaryRecord {
153+
/// Prover name — Debug-format of `ProverKind` (e.g. `"Z3"`, `"Lean"`).
154+
pub prover: String,
155+
/// Domain tag (e.g. `"arithmetic.factorisation"`).
156+
pub domain: String,
157+
pub attempts: u64,
158+
pub successes: u64,
159+
pub timeouts: u64,
160+
pub failures: u64,
161+
/// Mean proof time over successful attempts (ms). `None` if no successes.
162+
pub mean_time_ms: Option<f64>,
163+
/// Fraction of successful attempts `[0.0, 1.0]`.
164+
pub success_rate: f64,
165+
}
166+
146167
/// Tracks statistics across all provers and domains
147168
#[derive(Debug, Clone, Serialize, Deserialize)]
148169
pub struct StatisticsTracker {
@@ -322,6 +343,34 @@ impl StatisticsTracker {
322343
pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
323344
serde_json::from_str(json)
324345
}
346+
347+
/// Export every (prover, domain) entry as a flat `Vec<StatsSummaryRecord>`.
348+
///
349+
/// Used by the background GNN training-sync task to push accumulated proof
350+
/// evidence to the Julia ML server (`POST /training/update`). The server
351+
/// uses these records to update per-(prover, domain) success-rate weights
352+
/// that modulate premise ranking scores.
353+
pub fn export_records(&self) -> Vec<StatsSummaryRecord> {
354+
self.stats
355+
.iter()
356+
.map(|(key, stats)| {
357+
// Key format produced by `make_key`: "ProverDebug::domain"
358+
let mut parts = key.splitn(2, "::");
359+
let prover = parts.next().unwrap_or("Unknown").to_string();
360+
let domain = parts.next().unwrap_or("unspecified").to_string();
361+
StatsSummaryRecord {
362+
prover,
363+
domain,
364+
attempts: stats.attempts,
365+
successes: stats.successes,
366+
timeouts: stats.timeouts,
367+
failures: stats.failures,
368+
mean_time_ms: stats.mean_time_ms(),
369+
success_rate: stats.success_rate(),
370+
}
371+
})
372+
.collect()
373+
}
325374
}
326375

327376
#[cfg(test)]
@@ -446,6 +495,35 @@ mod tests {
446495
assert_eq!(upper, 1.0);
447496
}
448497

498+
#[test]
499+
fn test_export_records() {
500+
let mut tracker = StatisticsTracker::new();
501+
tracker.record_success(ProverKind::Z3, "arithmetic", 100);
502+
tracker.record_success(ProverKind::Z3, "arithmetic", 200);
503+
tracker.record_failure(ProverKind::Lean, "topology");
504+
505+
let records = tracker.export_records();
506+
assert_eq!(records.len(), 2);
507+
508+
let z3 = records
509+
.iter()
510+
.find(|r| r.prover == "Z3" && r.domain == "arithmetic")
511+
.expect("Z3::arithmetic record present");
512+
assert_eq!(z3.successes, 2);
513+
assert_eq!(z3.attempts, 2);
514+
assert!((z3.success_rate - 1.0).abs() < 1e-10);
515+
assert!(z3.mean_time_ms.is_some());
516+
517+
let lean = records
518+
.iter()
519+
.find(|r| r.prover == "Lean" && r.domain == "topology")
520+
.expect("Lean::topology record present");
521+
assert_eq!(lean.failures, 1);
522+
assert_eq!(lean.successes, 0);
523+
assert_eq!(lean.success_rate, 0.0);
524+
assert!(lean.mean_time_ms.is_none());
525+
}
526+
449527
#[test]
450528
fn test_serialization_roundtrip() {
451529
let mut tracker = StatisticsTracker::new();

0 commit comments

Comments
 (0)