@@ -12,6 +12,7 @@ use crate::{
12
12
database:: { Database , HirId , ScopeId , SymbolId } ,
13
13
hir:: { Hir , Op } ,
14
14
scope:: Scope ,
15
+ symbol:: { Function , Symbol } ,
15
16
value:: { GuardPath , Value } ,
16
17
ErrorKind ,
17
18
} ;
@@ -46,8 +47,8 @@ pub struct Compiler<'a> {
46
47
// The type definition stack is used for calculating types referenced in types.
47
48
type_definition_stack : Vec < TypeId > ,
48
49
49
- // The type guard stack is used for overriding types in certain contexts .
50
- type_guard_stack : Vec < HashMap < GuardPath , TypeId > > ,
50
+ // Overridden symbol types due to type guards .
51
+ type_overrides : Vec < HashMap < SymbolId , TypeId > > ,
51
52
52
53
// The generic type stack is used for overriding generic types that are being checked against.
53
54
generic_type_stack : Vec < HashMap < TypeId , TypeId > > ,
@@ -74,7 +75,7 @@ impl<'a> Compiler<'a> {
74
75
scope_stack : vec ! [ builtins. scope_id] ,
75
76
symbol_stack : Vec :: new ( ) ,
76
77
type_definition_stack : Vec :: new ( ) ,
77
- type_guard_stack : Vec :: new ( ) ,
78
+ type_overrides : Vec :: new ( ) ,
78
79
generic_type_stack : Vec :: new ( ) ,
79
80
allow_generic_inference_stack : vec ! [ false ] ,
80
81
is_callee : false ,
@@ -169,13 +170,50 @@ impl<'a> Compiler<'a> {
169
170
Value :: new ( self . builtins . unknown , self . ty . std ( ) . unknown )
170
171
}
171
172
172
- fn symbol_type ( & self , guard_path : & GuardPath ) -> Option < TypeId > {
173
- for guards in self . type_guard_stack . iter ( ) . rev ( ) {
174
- if let Some ( guard) = guards. get ( guard_path) {
175
- return Some ( * guard) ;
173
+ fn build_overrides ( & mut self , guards : HashMap < GuardPath , TypeId > ) -> HashMap < SymbolId , TypeId > {
174
+ type GuardItem = ( Vec < TypePath > , TypeId ) ;
175
+
176
+ let mut symbol_guards: HashMap < SymbolId , Vec < GuardItem > > = HashMap :: new ( ) ;
177
+
178
+ for ( guard_path, type_id) in guards {
179
+ symbol_guards
180
+ . entry ( guard_path. symbol_id )
181
+ . or_default ( )
182
+ . push ( ( guard_path. items , type_id) ) ;
183
+ }
184
+
185
+ let mut overrides = HashMap :: new ( ) ;
186
+
187
+ for ( symbol_id, mut items) in symbol_guards {
188
+ // Order by length.
189
+ items. sort_by_key ( |( items, _) | items. len ( ) ) ;
190
+
191
+ let mut type_id = self . symbol_type ( symbol_id) ;
192
+
193
+ for ( path_items, new_type_id) in items {
194
+ type_id = self . ty . replace ( type_id, new_type_id, & path_items) ;
195
+ }
196
+
197
+ overrides. insert ( symbol_id, type_id) ;
198
+ }
199
+
200
+ overrides
201
+ }
202
+
203
+ fn symbol_type ( & self , symbol_id : SymbolId ) -> TypeId {
204
+ for guards in self . type_overrides . iter ( ) . rev ( ) {
205
+ if let Some ( type_id) = guards. get ( & symbol_id) {
206
+ return * type_id;
176
207
}
177
208
}
178
- None
209
+
210
+ match self . db . symbol ( symbol_id) {
211
+ Symbol :: Unknown | Symbol :: Module ( ..) => unreachable ! ( ) ,
212
+ Symbol :: Function ( Function { type_id, .. } )
213
+ | Symbol :: InlineFunction ( Function { type_id, .. } )
214
+ | Symbol :: Parameter ( type_id) => * type_id,
215
+ Symbol :: Let ( value) | Symbol :: Const ( value) | Symbol :: InlineConst ( value) => value. type_id ,
216
+ }
179
217
}
180
218
181
219
fn scope ( & self ) -> & Scope {
0 commit comments