1
- use common:: indexmap:: { IndexMap , IndexSet } ;
1
+ use common:: indexmap:: IndexSet ;
2
2
use hir:: hir_def:: { scope_graph:: ScopeId , IdentId , Trait } ;
3
- use itertools:: Itertools ;
4
3
use rustc_hash:: FxHashSet ;
5
4
use thin_vec:: ThinVec ;
6
5
7
6
use crate :: {
8
7
name_resolution:: { available_traits_in_scope, is_scope_visible_from} ,
9
8
ty:: {
9
+ binder:: Binder ,
10
10
canonical:: { Canonical , Canonicalized , Solution } ,
11
- fold:: TyFoldable ,
12
11
func_def:: FuncDef ,
13
12
method_table:: probe_method,
14
13
trait_def:: { impls_for_ty, TraitDef , TraitInstId , TraitMethod } ,
@@ -56,12 +55,14 @@ pub(crate) fn select_method_candidate<'db>(
56
55
method_name : IdentId < ' db > ,
57
56
scope : ScopeId < ' db > ,
58
57
assumptions : PredicateListId < ' db > ,
58
+ trait_ : Option < TraitDef < ' db > > ,
59
59
) -> Result < MethodCandidate < ' db > , MethodSelectionError < ' db > > {
60
60
if receiver. value . is_ty_var ( db) {
61
61
return Err ( MethodSelectionError :: ReceiverTypeMustBeKnown ) ;
62
62
}
63
63
64
- let candidates = assemble_method_candidates ( db, receiver, method_name, scope, assumptions) ;
64
+ let candidates =
65
+ assemble_method_candidates ( db, receiver, method_name, scope, assumptions, trait_) ;
65
66
66
67
let selector = MethodSelector {
67
68
db,
@@ -80,13 +81,15 @@ fn assemble_method_candidates<'db>(
80
81
method_name : IdentId < ' db > ,
81
82
scope : ScopeId < ' db > ,
82
83
assumptions : PredicateListId < ' db > ,
84
+ trait_ : Option < TraitDef < ' db > > ,
83
85
) -> AssembledCandidates < ' db > {
84
86
CandidateAssembler {
85
87
db,
86
88
receiver_ty,
87
89
method_name,
88
90
scope,
89
91
assumptions,
92
+ trait_,
90
93
candidates : AssembledCandidates :: default ( ) ,
91
94
}
92
95
. assemble ( )
@@ -102,12 +105,15 @@ struct CandidateAssembler<'db> {
102
105
scope : ScopeId < ' db > ,
103
106
/// The assumptions for the type bound in the current scope.
104
107
assumptions : PredicateListId < ' db > ,
108
+ trait_ : Option < TraitDef < ' db > > ,
105
109
candidates : AssembledCandidates < ' db > ,
106
110
}
107
111
108
112
impl < ' db > CandidateAssembler < ' db > {
109
113
fn assemble ( mut self ) -> AssembledCandidates < ' db > {
110
- self . assemble_inherent_method_candidates ( ) ;
114
+ if self . trait_ . is_none ( ) {
115
+ self . assemble_inherent_method_candidates ( ) ;
116
+ }
111
117
self . assemble_trait_method_candidates ( ) ;
112
118
self . candidates
113
119
}
@@ -125,45 +131,42 @@ impl<'db> CandidateAssembler<'db> {
125
131
126
132
fn assemble_trait_method_candidates ( & mut self ) {
127
133
let ingot = self . scope . ingot ( self . db ) ;
128
- let mut table = UnificationTable :: new ( self . db ) ;
129
- let extracted_receiver_ty = self . receiver_ty . extract_identity ( & mut table) ;
130
134
131
- for & implementor in impls_for_ty ( self . db , ingot, self . receiver_ty ) {
132
- let trait_def = implementor. skip_binder ( ) . trait_def ( self . db ) ;
133
- self . insert_trait_method_cand ( trait_def)
135
+ for & imp in impls_for_ty ( self . db , ingot, self . receiver_ty ) {
136
+ self . insert_trait_method_cand ( imp. skip_binder ( ) . trait_ ( self . db ) ) ;
134
137
}
135
138
139
+ let mut table = UnificationTable :: new ( self . db ) ;
140
+ let extracted_receiver_ty = self . receiver_ty . extract_identity ( & mut table) ;
141
+
136
142
for & pred in self . assumptions . list ( self . db ) {
137
143
let snapshot = table. snapshot ( ) ;
138
144
let self_ty = pred. self_ty ( self . db ) ;
139
145
let self_ty = table. instantiate_to_term ( self_ty) ;
140
146
141
147
if table. unify ( extracted_receiver_ty, self_ty) . is_ok ( ) {
142
- self . insert_trait_method_cand_with_inst ( pred . def ( self . db ) , pred) ;
148
+ self . insert_trait_method_cand ( pred) ;
143
149
for super_trait in pred. def ( self . db ) . super_traits ( self . db ) {
144
150
let super_trait = super_trait. instantiate ( self . db , pred. args ( self . db ) ) ;
145
- self . insert_trait_method_cand_with_inst ( super_trait . def ( self . db ) , super_trait) ;
151
+ self . insert_trait_method_cand ( super_trait) ;
146
152
}
147
153
}
148
154
149
155
table. rollback_to ( snapshot) ;
150
156
}
151
157
}
152
158
153
- fn insert_trait_method_cand ( & mut self , trait_def : TraitDef < ' db > ) {
154
- if let Some ( & trait_method) = trait_def. methods ( self . db ) . get ( & self . method_name ) {
155
- self . candidates . insert_trait ( trait_def, trait_method) ;
156
- }
159
+ fn allow_trait ( & self , trait_def : TraitDef < ' db > ) -> bool {
160
+ self . trait_ . map ( |t| t == trait_def) . unwrap_or ( true )
157
161
}
158
162
159
- fn insert_trait_method_cand_with_inst (
160
- & mut self ,
161
- trait_def : TraitDef < ' db > ,
162
- trait_inst : TraitInstId < ' db > ,
163
- ) {
163
+ fn insert_trait_method_cand ( & mut self , inst : TraitInstId < ' db > ) {
164
+ let trait_def = inst . def ( self . db ) ;
165
+ if ! self . allow_trait ( trait_def) {
166
+ return ;
167
+ }
164
168
if let Some ( & trait_method) = trait_def. methods ( self . db ) . get ( & self . method_name ) {
165
- self . candidates
166
- . insert_trait_with_inst ( trait_def, trait_method, trait_inst) ;
169
+ self . candidates . traits . insert ( ( inst, trait_method) ) ;
167
170
}
168
171
}
169
172
}
@@ -234,15 +237,15 @@ impl<'db> MethodSelector<'db> {
234
237
let traits = & self . candidates . traits ;
235
238
236
239
if traits. len ( ) == 1 {
237
- let ( def , method) = traits. iter ( ) . next ( ) . unwrap ( ) ;
238
- return Ok ( self . find_inst ( * def , * method) ) ;
240
+ let ( inst , method) = traits. iter ( ) . next ( ) . unwrap ( ) ;
241
+ return Ok ( self . check_inst ( * inst , * method) ) ;
239
242
}
240
243
241
244
let available_traits = self . available_traits ( ) ;
242
245
let visible_traits: Vec < _ > = traits
243
246
. iter ( )
244
247
. copied ( )
245
- . filter ( |cand | available_traits. contains ( & cand . 0 ) )
248
+ . filter ( |( inst , _method ) | available_traits. contains ( & inst . def ( self . db ) ) )
246
249
. collect ( ) ;
247
250
248
251
match visible_traits. len ( ) {
@@ -251,14 +254,17 @@ impl<'db> MethodSelector<'db> {
251
254
Err ( MethodSelectionError :: NotFound )
252
255
} else {
253
256
// Suggests trait imports.
254
- let traits = traits. iter ( ) . map ( |( def, _) | def. trait_ ( self . db ) ) . collect ( ) ;
257
+ let traits = traits
258
+ . iter ( )
259
+ . map ( |( inst, _) | inst. def ( self . db ) . trait_ ( self . db ) )
260
+ . collect ( ) ;
255
261
Err ( MethodSelectionError :: InvisibleTraitMethod ( traits) )
256
262
}
257
263
}
258
264
259
265
1 => {
260
266
let ( def, method) = visible_traits[ 0 ] ;
261
- Ok ( self . find_inst ( def, method) )
267
+ Ok ( self . check_inst ( def, method) )
262
268
}
263
269
264
270
_ => Err ( MethodSelectionError :: AmbiguousTraitMethod (
@@ -275,39 +281,14 @@ impl<'db> MethodSelector<'db> {
275
281
/// checks if the goal is satisfiable given the current assumptions.
276
282
/// Depending on the result, it either returns a confirmed trait method
277
283
/// candidate or one that needs further confirmation.
278
- ///
279
- /// # Arguments
280
- ///
281
- /// * `def` - The trait definition.
282
- /// * `method` - The trait method.
283
- ///
284
- /// # Returns
285
- ///
286
- /// A `Candidate` representing the found trait method instance.
287
- fn find_inst ( & self , def : TraitDef < ' db > , method : TraitMethod < ' db > ) -> MethodCandidate < ' db > {
284
+ fn check_inst ( & self , inst : TraitInstId < ' db > , method : TraitMethod < ' db > ) -> MethodCandidate < ' db > {
288
285
let mut table = UnificationTable :: new ( self . db ) ;
289
- let receiver = self . receiver . extract_identity ( & mut table) ;
290
- let receiver = table. instantiate_to_term ( receiver) ; // xxx remove?
291
-
292
- // Check if we have a stored trait instance with associated type bindings
293
- let cand = if let Some ( & stored_inst) = self . candidates . trait_instances . get ( & ( def, method) ) {
294
- // Use the stored instance which includes associated type bindings
295
- stored_inst
296
- } else {
297
- // Create a fresh instance without bindings
298
- let inst_args = def
299
- . params ( self . db )
300
- . iter ( )
301
- . map ( |ty| table. new_var_from_param ( * ty) )
302
- . collect_vec ( ) ;
303
- TraitInstId :: new ( self . db , def, inst_args, IndexMap :: new ( ) )
304
- } ;
305
-
306
- // Unify receiver and method self.
307
- method. instantiate_with_inst ( & mut table, receiver, cand) ;
286
+ // Seed the table with receiver's canonical variables so that subsequent
287
+ // canonicalization can safely probe them.
288
+ let _ = self . receiver . extract_identity ( & mut table) ;
308
289
309
- let cand = cand . fold_with ( & mut table ) ;
310
- let canonical_cand = Canonicalized :: new ( self . db , cand ) ;
290
+ let canonical_cand = Canonicalized :: new ( self . db , inst ) ;
291
+ let inst = table . instantiate_with_fresh_vars ( Binder :: bind ( inst ) ) ;
311
292
312
293
match is_goal_satisfiable (
313
294
self . db ,
@@ -318,13 +299,13 @@ impl<'db> MethodSelector<'db> {
318
299
GoalSatisfiability :: Satisfied ( solution) => {
319
300
// Map back the solution to the current context.
320
301
let solution = canonical_cand. extract_solution ( & mut table, * solution) ;
321
-
322
- // Unify candidate to solution .
323
- table. unify ( cand , solution) . unwrap ( ) ;
302
+ // Replace TyParams in the solved instance with fresh inference vars so
303
+ // downstream unification can bind them (e.g., T = u32) .
304
+ let solution = table. instantiate_with_fresh_vars ( Binder :: bind ( solution) ) ;
324
305
325
306
MethodCandidate :: TraitMethod ( TraitMethodCand :: new (
326
307
self . receiver
327
- . canonicalize_solution ( self . db , & mut table, cand ) ,
308
+ . canonicalize_solution ( self . db , & mut table, solution ) ,
328
309
method,
329
310
) )
330
311
}
@@ -334,7 +315,7 @@ impl<'db> MethodSelector<'db> {
334
315
| GoalSatisfiability :: UnSat ( _) => {
335
316
MethodCandidate :: NeedsConfirmation ( TraitMethodCand :: new (
336
317
self . receiver
337
- . canonicalize_solution ( self . db , & mut table, cand ) ,
318
+ . canonicalize_solution ( self . db , & mut table, inst ) ,
338
319
method,
339
320
) )
340
321
}
@@ -373,7 +354,7 @@ impl<'db> MethodSelector<'db> {
373
354
#[ derive( Debug , Clone , PartialEq , Eq , Hash , salsa:: Update ) ]
374
355
pub enum MethodSelectionError < ' db > {
375
356
AmbiguousInherentMethod ( ThinVec < FuncDef < ' db > > ) ,
376
- AmbiguousTraitMethod ( ThinVec < TraitDef < ' db > > ) ,
357
+ AmbiguousTraitMethod ( ThinVec < TraitInstId < ' db > > ) ,
377
358
NotFound ,
378
359
InvisibleInherentMethod ( FuncDef < ' db > ) ,
379
360
InvisibleTraitMethod ( ThinVec < Trait < ' db > > ) ,
@@ -383,27 +364,11 @@ pub enum MethodSelectionError<'db> {
383
364
#[ derive( Default ) ]
384
365
struct AssembledCandidates < ' db > {
385
366
inherent_methods : FxHashSet < FuncDef < ' db > > ,
386
- traits : IndexSet < ( TraitDef < ' db > , TraitMethod < ' db > ) > ,
387
- // Store trait instances with their associated type bindings
388
- trait_instances : IndexMap < ( TraitDef < ' db > , TraitMethod < ' db > ) , TraitInstId < ' db > > ,
367
+ traits : IndexSet < ( TraitInstId < ' db > , TraitMethod < ' db > ) > ,
389
368
}
390
369
391
370
impl < ' db > AssembledCandidates < ' db > {
392
371
fn insert_inherent_method ( & mut self , method : FuncDef < ' db > ) {
393
372
self . inherent_methods . insert ( method) ;
394
373
}
395
-
396
- fn insert_trait ( & mut self , def : TraitDef < ' db > , method : TraitMethod < ' db > ) {
397
- self . traits . insert ( ( def, method) ) ;
398
- }
399
-
400
- fn insert_trait_with_inst (
401
- & mut self ,
402
- def : TraitDef < ' db > ,
403
- method : TraitMethod < ' db > ,
404
- inst : TraitInstId < ' db > ,
405
- ) {
406
- self . traits . insert ( ( def, method) ) ;
407
- self . trait_instances . insert ( ( def, method) , inst) ;
408
- }
409
374
}
0 commit comments