@@ -31,6 +31,14 @@ using LinearAlgebra
3131const GNN_MODEL = Ref {Any} (nothing )
3232const GNN_VOCAB = Ref {Any} (nothing )
3333
34+ # Per-(prover, domain) success-rate weights pushed from Rust via
35+ # POST /training/update. Format: weights[prover_name][domain] = success_rate.
36+ # Used to modulate premise scores in rank_with_gnn when domain_hints present.
37+ const PROVER_DOMAIN_WEIGHTS = Ref {Dict{String,Dict{String,Float64}}} (Dict ())
38+
39+ # Running total of training records received since server start.
40+ const TOTAL_TRAINING_RECORDS = Ref {Int} (0 )
41+
3442"""
3543 load_gnn_model(models_dir::String)
3644
@@ -94,23 +102,110 @@ function parse_proof_graph(body)
94102end
95103
96104"""
97- rank_with_gnn(g, node_features, goal_idx, premise_indices, config)
105+ rank_with_gnn(g, node_features, goal_idx, premise_indices, config, domain_hints )
98106
99107Run GNN message passing + cross-attention scoring on the parsed graph.
100108If a trained model is available, use it. Otherwise, use cosine similarity
101109between the goal and premise node features as a fallback.
110+
111+ When `domain_hints` is non-empty and `PROVER_DOMAIN_WEIGHTS` has been
112+ populated by prior `/training/update` calls, premise scores are modulated
113+ by the aggregate domain confidence from accumulated proof outcomes.
102114"""
103- function rank_with_gnn (g, node_features, goal_idx, premise_indices, config)
115+ function rank_with_gnn (g, node_features, goal_idx, premise_indices, config, domain_hints = String[] )
104116 model = GNN_MODEL[]
105117
106118 if model != = nothing
107- # Use trained model for ranking
108- # (Delegate to neural_solver.jl PremiseRanker when available)
109- return rank_with_trained_model (model, g, node_features, goal_idx, premise_indices)
119+ scores, indices = rank_with_trained_model (model, g, node_features, goal_idx, premise_indices)
120+ else
121+ # Fallback: cosine similarity between goal and premise features
122+ scores, indices = rank_with_cosine (node_features, goal_idx, premise_indices)
123+ end
124+
125+ # Apply accumulated training weights when the caller provided domain hints.
126+ # This is a no-op until /training/update has been called at least once
127+ # and the rank request includes non-empty domain_hints.
128+ if ! isempty (domain_hints) && ! isempty (PROVER_DOMAIN_WEIGHTS[])
129+ scores = apply_domain_weights (scores, domain_hints)
110130 end
111131
112- # Fallback: cosine similarity between goal and premise features
113- return rank_with_cosine (node_features, goal_idx, premise_indices)
132+ return (scores, indices)
133+ end
134+
135+ """
136+ apply_domain_weights(scores, domain_hints)
137+
138+ Scale premise scores by the mean success rate across all provers for the
139+ requested domain aspects. Uses a `[0.5, 1.0]` range so that even low-
140+ confidence domains retain half the base score rather than collapsing to zero.
141+
142+ When no training evidence exists for the requested domains, scores are
143+ returned unchanged.
144+ """
145+ function apply_domain_weights (scores:: Vector{Float32} , domain_hints:: Vector{String} )
146+ weights = PROVER_DOMAIN_WEIGHTS[]
147+ domain_rates = Float64[]
148+ for domain in domain_hints
149+ for (_, prover_weights) in weights
150+ if haskey (prover_weights, domain)
151+ push! (domain_rates, prover_weights[domain])
152+ end
153+ end
154+ end
155+ isempty (domain_rates) && return scores
156+ mean_confidence = Statistics. mean (domain_rates)
157+ scale = Float32 (0.5 + 0.5 * mean_confidence)
158+ return scores .* scale
159+ end
160+
161+ """
162+ handle_training_update(req::HTTP.Request)
163+
164+ POST /training/update — Receive proof-outcome statistics from the Rust server
165+ and update per-(prover, domain) success-rate weights used to modulate premise
166+ ranking scores.
167+
168+ Payload: `{ "records": [{ "prover", "domain", "attempts", "successes",
169+ "timeouts", "failures", "mean_time_ms", "success_rate" }] }`
170+
171+ The Rust `StatisticsTracker` is authoritative; Julia simply mirrors it so that
172+ the GNN ranking layer can incorporate proof-outcome evidence without a round-trip.
173+ """
174+ function handle_training_update (req:: HTTP.Request )
175+ try
176+ body = JSON3. read (String (req. body))
177+ records = get (body, :records , [])
178+
179+ weights = PROVER_DOMAIN_WEIGHTS[]
180+ n = 0
181+ for rec in records
182+ prover = string (get (rec, :prover , " Unknown" ))
183+ domain = string (get (rec, :domain , " unspecified" ))
184+ rate = Float64 (get (rec, :success_rate , 0.0 ))
185+ if ! haskey (weights, prover)
186+ weights[prover] = Dict {String,Float64} ()
187+ end
188+ weights[prover][domain] = rate
189+ n += 1
190+ end
191+ PROVER_DOMAIN_WEIGHTS[] = weights
192+ TOTAL_TRAINING_RECORDS[] += n
193+
194+ @info " Training update: $n records (total=$(TOTAL_TRAINING_RECORDS[]) )"
195+
196+ return HTTP. Response (200 , JSON3. write (Dict (
197+ " status" => " ok" ,
198+ " records_received" => n,
199+ " total_records" => TOTAL_TRAINING_RECORDS[],
200+ " weights_updated" => n > 0
201+ )))
202+ catch e
203+ @error " Training update failed" exception= (e, catch_backtrace ())
204+ return HTTP. Response (500 , JSON3. write (Dict (
205+ " status" => " error" ,
206+ " error" => string (e)
207+ )))
208+ end
114209end
115210
116211"""
@@ -190,11 +285,14 @@ function handle_gnn_rank(req::HTTP.Request)
190285 top_k = get (body, :top_k , 20 )
191286 min_score = get (body, :min_score , 0.05 )
192287 include_embeddings = get (body, :include_embeddings , false )
288+ domain_hints = String .(get (body, :domain_hints , String[]))
193289
194- # Run GNN ranking
290+ # Run GNN ranking; domain_hints allow weight-guided score modulation
291+ # when /training/update has populated PROVER_DOMAIN_WEIGHTS.
195292 (scores, indices) = rank_with_gnn (
196293 g, node_features, goal_idx, premise_indices,
197- get (body, :config , nothing )
294+ get (body, :config , nothing ),
295+ domain_hints
198296 )
199297
200298 # Sort by score (descending)
@@ -268,15 +366,17 @@ Call this from the main api_server.jl to enable GNN functionality.
268366"""
269367function register_gnn_routes! (existing_handler)
270368 @info " Registering GNN endpoints:"
271- @info " POST /gnn/rank — Rank premises via GNN"
272- @info " GET /gnn/health — GNN model status"
369+ @info " POST /gnn/rank — Rank premises via GNN"
370+ @info " GET /gnn/health — GNN model status"
371+ @info " POST /training/update — Receive proof-outcome stats from Rust"
273372
274- # Return a combined handler that dispatches to GNN routes first
275373 function combined_handler (req:: HTTP.Request )
276374 if req. target == " /gnn/rank" && req. method == " POST"
277375 return handle_gnn_rank (req)
278376 elseif req. target == " /gnn/health"
279377 return handle_gnn_health (req)
378+ elseif req. target == " /training/update" && req. method == " POST"
379+ return handle_training_update (req)
280380 else
281381 return existing_handler (req)
282382 end
0 commit comments