Skip to content

Commit d01a2f5

Browse files
authored
Merge pull request #29 from Rigidity/type-guard-redesign
Type guard redesign
2 parents 7e6ad94 + a4bbcf2 commit d01a2f5

File tree

14 files changed

+140
-69
lines changed

14 files changed

+140
-69
lines changed

.github/workflows/rust.yml

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,8 @@ jobs:
1717
- name: Checkout
1818
uses: actions/checkout@v4
1919

20-
- name: Cargo binstall
21-
run: curl -L --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/cargo-bins/cargo-binstall/main/install-from-binstall-release.sh | bash
22-
23-
- name: Instal cargo-workspaces
24-
run: cargo binstall cargo-workspaces --locked -y
20+
- name: Install cargo-workspaces
21+
run: cargo install cargo-workspaces
2522

2623
- name: Run tests
2724
run: cargo test --all-features --workspace
@@ -31,7 +28,7 @@ jobs:
3128

3229
- name: Unused dependencies
3330
run: |
34-
cargo binstall cargo-machete --locked -y
31+
cargo install cargo-machete --locked
3532
cargo machete
3633
3734
- name: Fmt

crates/rue-compiler/src/compiler.rs

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use crate::{
1212
database::{Database, HirId, ScopeId, SymbolId},
1313
hir::{Hir, Op},
1414
scope::Scope,
15+
symbol::{Function, Symbol},
1516
value::{GuardPath, Value},
1617
ErrorKind,
1718
};
@@ -46,8 +47,8 @@ pub struct Compiler<'a> {
4647
// The type definition stack is used for calculating types referenced in types.
4748
type_definition_stack: Vec<TypeId>,
4849

49-
// The type guard stack is used for overriding types in certain contexts.
50-
type_guard_stack: Vec<HashMap<GuardPath, TypeId>>,
50+
// Overridden symbol types due to type guards.
51+
type_overrides: Vec<HashMap<SymbolId, TypeId>>,
5152

5253
// The generic type stack is used for overriding generic types that are being checked against.
5354
generic_type_stack: Vec<HashMap<TypeId, TypeId>>,
@@ -74,7 +75,7 @@ impl<'a> Compiler<'a> {
7475
scope_stack: vec![builtins.scope_id],
7576
symbol_stack: Vec::new(),
7677
type_definition_stack: Vec::new(),
77-
type_guard_stack: Vec::new(),
78+
type_overrides: Vec::new(),
7879
generic_type_stack: Vec::new(),
7980
allow_generic_inference_stack: vec![false],
8081
is_callee: false,
@@ -169,13 +170,50 @@ impl<'a> Compiler<'a> {
169170
Value::new(self.builtins.unknown, self.ty.std().unknown)
170171
}
171172

172-
fn symbol_type(&self, guard_path: &GuardPath) -> Option<TypeId> {
173-
for guards in self.type_guard_stack.iter().rev() {
174-
if let Some(guard) = guards.get(guard_path) {
175-
return Some(*guard);
173+
fn build_overrides(&mut self, guards: HashMap<GuardPath, TypeId>) -> HashMap<SymbolId, TypeId> {
174+
type GuardItem = (Vec<TypePath>, TypeId);
175+
176+
let mut symbol_guards: HashMap<SymbolId, Vec<GuardItem>> = HashMap::new();
177+
178+
for (guard_path, type_id) in guards {
179+
symbol_guards
180+
.entry(guard_path.symbol_id)
181+
.or_default()
182+
.push((guard_path.items, type_id));
183+
}
184+
185+
let mut overrides = HashMap::new();
186+
187+
for (symbol_id, mut items) in symbol_guards {
188+
// Order by length.
189+
items.sort_by_key(|(items, _)| items.len());
190+
191+
let mut type_id = self.symbol_type(symbol_id);
192+
193+
for (path_items, new_type_id) in items {
194+
type_id = self.ty.replace(type_id, new_type_id, &path_items);
195+
}
196+
197+
overrides.insert(symbol_id, type_id);
198+
}
199+
200+
overrides
201+
}
202+
203+
fn symbol_type(&self, symbol_id: SymbolId) -> TypeId {
204+
for guards in self.type_overrides.iter().rev() {
205+
if let Some(type_id) = guards.get(&symbol_id) {
206+
return *type_id;
176207
}
177208
}
178-
None
209+
210+
match self.db.symbol(symbol_id) {
211+
Symbol::Unknown | Symbol::Module(..) => unreachable!(),
212+
Symbol::Function(Function { type_id, .. })
213+
| Symbol::InlineFunction(Function { type_id, .. })
214+
| Symbol::Parameter(type_id) => *type_id,
215+
Symbol::Let(value) | Symbol::Const(value) | Symbol::InlineConst(value) => value.type_id,
216+
}
179217
}
180218

181219
fn scope(&self) -> &Scope {

crates/rue-compiler/src/compiler/block.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ impl Compiler<'_> {
5050

5151
// Push the type guards onto the stack.
5252
// This will be popped in reverse order later after all statements have been lowered.
53-
self.type_guard_stack.push(else_guards);
53+
let overrides = self.build_overrides(else_guards);
54+
self.type_overrides.push(overrides);
5455

5556
statements.push(Statement::If(condition_hir, then_hir));
5657
}
@@ -103,8 +104,8 @@ impl Compiler<'_> {
103104
// If the condition is false, we raise an error.
104105
// So we can assume that the condition is true from this point on.
105106
// This will be popped in reverse order later after all statements have been lowered.
106-
107-
self.type_guard_stack.push(condition.then_guards());
107+
let overrides = self.build_overrides(condition.then_guards());
108+
self.type_overrides.push(overrides);
108109

109110
let not_condition = self.db.alloc_hir(Hir::Op(Op::Not, condition.hir_id));
110111
let raise = self.db.alloc_hir(Hir::Raise(None));
@@ -126,7 +127,8 @@ impl Compiler<'_> {
126127
assume_stmt.syntax().text_range(),
127128
);
128129

129-
self.type_guard_stack.push(expr.then_guards());
130+
let overrides = self.build_overrides(expr.then_guards());
131+
self.type_overrides.push(overrides);
130132
statements.push(Statement::Assume);
131133
}
132134
}
@@ -158,7 +160,7 @@ impl Compiler<'_> {
158160
body = value;
159161
}
160162
Statement::If(condition, then_block) => {
161-
self.type_guard_stack.pop().unwrap();
163+
self.type_overrides.pop().unwrap();
162164

163165
body = Value::new(
164166
self.db
@@ -167,7 +169,7 @@ impl Compiler<'_> {
167169
);
168170
}
169171
Statement::Assume => {
170-
self.type_guard_stack.pop().unwrap();
172+
self.type_overrides.pop().unwrap();
171173
}
172174
}
173175
}

crates/rue-compiler/src/compiler/expr/binary_expr.rs

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ impl Compiler<'_> {
172172
let else_type = self.ty.difference(rhs.type_id, self.ty.std().nil);
173173
value
174174
.guards
175-
.insert(guard_path, Guard::new(then_type, else_type));
175+
.insert(guard_path, Guard::new(Some(then_type), Some(else_type)));
176176
}
177177
}
178178

@@ -182,7 +182,7 @@ impl Compiler<'_> {
182182
let else_type = self.ty.difference(lhs.type_id, self.ty.std().nil);
183183
value
184184
.guards
185-
.insert(guard_path, Guard::new(then_type, else_type));
185+
.insert(guard_path, Guard::new(Some(then_type), Some(else_type)));
186186
}
187187
}
188188

@@ -250,13 +250,14 @@ impl Compiler<'_> {
250250
}
251251

252252
fn op_and(&mut self, lhs: Value, rhs: Option<&Expr>, text_range: TextRange) -> Value {
253-
self.type_guard_stack.push(lhs.then_guards());
253+
let overrides = self.build_overrides(lhs.then_guards());
254+
self.type_overrides.push(overrides);
254255

255256
let rhs = rhs
256257
.map(|rhs| self.compile_expr(rhs, Some(self.ty.std().bool)))
257258
.unwrap_or_else(|| self.unknown());
258259

259-
self.type_guard_stack.pop().unwrap();
260+
self.type_overrides.pop().unwrap();
260261

261262
self.type_check(lhs.type_id, self.ty.std().bool, text_range);
262263
self.type_check(rhs.type_id, self.ty.std().bool, text_range);
@@ -267,19 +268,28 @@ impl Compiler<'_> {
267268
rhs.hir_id,
268269
self.ty.std().bool,
269270
);
270-
value.guards.extend(lhs.guards);
271-
value.guards.extend(rhs.guards);
271+
value.guards.extend(
272+
lhs.guards
273+
.into_iter()
274+
.map(|(path, guard)| (path, Guard::new(guard.then_type, None))),
275+
);
276+
value.guards.extend(
277+
rhs.guards
278+
.into_iter()
279+
.map(|(path, guard)| (path, Guard::new(guard.then_type, None))),
280+
);
272281
value
273282
}
274283

275284
fn op_or(&mut self, lhs: &Value, rhs: Option<&Expr>, text_range: TextRange) -> Value {
276-
self.type_guard_stack.push(lhs.then_guards());
285+
let overrides = self.build_overrides(lhs.else_guards());
286+
self.type_overrides.push(overrides);
277287

278288
let rhs = rhs
279289
.map(|rhs| self.compile_expr(rhs, Some(self.ty.std().bool)))
280290
.unwrap_or_else(|| self.unknown());
281291

282-
self.type_guard_stack.pop().unwrap();
292+
self.type_overrides.pop().unwrap();
283293

284294
self.type_check(lhs.type_id, self.ty.std().bool, text_range);
285295
self.type_check(rhs.type_id, self.ty.std().bool, text_range);

crates/rue-compiler/src/compiler/expr/field_access_expr.rs

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ impl Compiler<'_> {
2222
return self.unknown();
2323
};
2424

25-
let mut new_value = match self.ty.get(old_value.type_id).clone() {
26-
Type::Unknown => return self.unknown(),
25+
match self.ty.get(old_value.type_id).clone() {
26+
Type::Unknown => self.unknown(),
2727
Type::Struct(ty) => {
2828
let Some(value) = self.compile_struct_field_access(old_value, &ty, &name) else {
2929
return self.unknown();
@@ -55,17 +55,9 @@ impl Compiler<'_> {
5555
),
5656
name.text_range(),
5757
);
58-
return self.unknown();
59-
}
60-
};
61-
62-
if let Some(guard_path) = new_value.guard_path.as_ref() {
63-
if let Some(type_override) = self.symbol_type(guard_path) {
64-
new_value.type_id = type_override;
58+
self.unknown()
6559
}
6660
}
67-
68-
new_value
6961
}
7062

7163
fn compile_pair_field_access(
@@ -113,7 +105,7 @@ impl Compiler<'_> {
113105
) -> Option<Value> {
114106
let fields =
115107
deconstruct_items(self.ty, ty.type_id, ty.field_names.len(), ty.nil_terminated)
116-
.expect("invalid struct type");
108+
.unwrap();
117109

118110
let Some(index) = ty.field_names.get_index_of(name.text()) else {
119111
self.db
@@ -157,7 +149,7 @@ impl Compiler<'_> {
157149
.as_ref()
158150
.map(|field_names| {
159151
deconstruct_items(self.ty, type_id, field_names.len(), ty.nil_terminated)
160-
.expect("invalid struct type")
152+
.unwrap()
161153
})
162154
.unwrap_or_default()
163155
} else {

crates/rue-compiler/src/compiler/expr/guard_expr.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ impl Compiler<'_> {
6262

6363
if let Some(guard_path) = expr.guard_path {
6464
let difference = self.ty.difference(expr.type_id, rhs);
65-
value.guards.insert(guard_path, Guard::new(rhs, difference));
65+
value
66+
.guards
67+
.insert(guard_path, Guard::new(Some(rhs), Some(difference)));
6668
}
6769

6870
value

crates/rue-compiler/src/compiler/expr/if_expr.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,21 @@ impl Compiler<'_> {
1010
.map(|condition| self.compile_expr(&condition, Some(self.ty.std().bool)));
1111

1212
if let Some(condition) = condition.as_ref() {
13-
self.type_guard_stack.push(condition.then_guards());
13+
let overrides = self.build_overrides(condition.then_guards());
14+
self.type_overrides.push(overrides);
1415
}
1516

1617
let then_block = if_expr
1718
.then_block()
1819
.map(|then_block| self.compile_block_expr(&then_block, expected_type));
1920

2021
if condition.is_some() {
21-
self.type_guard_stack.pop().unwrap();
22+
self.type_overrides.pop().unwrap();
2223
}
2324

2425
if let Some(condition) = condition.as_ref() {
25-
self.type_guard_stack.push(condition.else_guards());
26+
let overrides = self.build_overrides(condition.else_guards());
27+
self.type_overrides.push(overrides);
2628
}
2729

2830
let expected_type =
@@ -33,7 +35,7 @@ impl Compiler<'_> {
3335
.map(|else_block| self.compile_block_expr(&else_block, expected_type));
3436

3537
if condition.is_some() {
36-
self.type_guard_stack.pop().unwrap();
38+
self.type_overrides.pop().unwrap();
3739
}
3840

3941
if let Some(condition_type) = condition.as_ref().map(|condition| condition.type_id) {

crates/rue-compiler/src/compiler/expr/initializer_expr.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ impl Compiler<'_> {
1313
.path()
1414
.map(|path| self.compile_path_type(&path.items(), path.syntax().text_range()));
1515

16-
match ty.map(|ty| self.ty.get(ty)).cloned() {
16+
match ty.map(|ty| self.ty.get_unaliased(ty)).cloned() {
1717
Some(Type::Struct(struct_type)) => {
1818
let fields = deconstruct_items(
1919
self.ty,
@@ -85,7 +85,7 @@ impl Compiler<'_> {
8585
self.unknown()
8686
}
8787
}
88-
Some(_) => {
88+
Some(..) => {
8989
self.db.error(
9090
ErrorKind::UninitializableType(self.type_name(ty.unwrap())),
9191
initializer.path().unwrap().syntax().text_range(),

crates/rue-compiler/src/compiler/expr/path_expr.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use crate::{
88
Compiler,
99
},
1010
hir::Hir,
11-
symbol::{Function, Symbol},
11+
symbol::Symbol,
1212
value::{GuardPath, Value},
1313
ErrorKind,
1414
};
@@ -75,18 +75,16 @@ impl Compiler<'_> {
7575
return self.unknown();
7676
}
7777

78-
let type_override = self.symbol_type(&GuardPath::new(symbol_id));
78+
let type_id = self.symbol_type(symbol_id);
7979
let reference = self.db.alloc_hir(Hir::Reference(symbol_id, text_range));
8080

8181
let mut value = match self.db.symbol(symbol_id).clone() {
8282
Symbol::Unknown | Symbol::Module(..) => unreachable!(),
83-
Symbol::Function(Function { type_id, .. })
84-
| Symbol::InlineFunction(Function { type_id, .. })
85-
| Symbol::Parameter(type_id) => Value::new(reference, type_override.unwrap_or(type_id)),
83+
Symbol::Function(..) | Symbol::InlineFunction(..) | Symbol::Parameter(..) => {
84+
Value::new(reference, type_id)
85+
}
8686
Symbol::Let(mut value) | Symbol::Const(mut value) | Symbol::InlineConst(mut value) => {
87-
if let Some(type_id) = type_override {
88-
value.type_id = type_id;
89-
}
87+
value.type_id = type_id;
9088
value.hir_id = reference;
9189
value
9290
}

crates/rue-compiler/src/compiler/stmt/if_stmt.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,16 @@ impl Compiler<'_> {
3535
let scope_id = self.db.alloc_scope(Scope::default());
3636

3737
// We can apply any type guards from the condition.
38-
self.type_guard_stack.push(condition.then_guards());
38+
let overrides = self.build_overrides(condition.then_guards());
39+
self.type_overrides.push(overrides);
3940

4041
// Compile the then block.
4142
self.scope_stack.push(scope_id);
4243
let summary = self.compile_block(&then_block, expected_type);
4344
self.scope_stack.pop().unwrap();
4445

4546
// Pop the type guards, since we've left the scope.
46-
self.type_guard_stack.pop().unwrap();
47+
self.type_overrides.pop().unwrap();
4748

4849
// If there's an implicit return, we want to raise an error.
4950
// This could technically work but makes the intent of the code unclear.

0 commit comments

Comments
 (0)