Skip to content

Commit 1a5f08f

Browse files
committed
Combine expression and pattern adjustments map in InferenceResult
1 parent dc52db6 commit 1a5f08f

File tree

13 files changed

+88
-110
lines changed

13 files changed

+88
-110
lines changed

crates/hir-ty/src/diagnostics/expr.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -279,9 +279,9 @@ impl ExprValidator {
279279
fn is_known_valid_scrutinee(&self, scrutinee_expr: ExprId, db: &dyn HirDatabase) -> bool {
280280
if self
281281
.infer
282-
.expr_adjustments
283-
.get(&scrutinee_expr)
284-
.is_some_and(|adjusts| adjusts.iter().any(|a| matches!(a.kind, Adjust::Deref(..))))
282+
.expr_adjustments(scrutinee_expr)
283+
.iter()
284+
.any(|a| matches!(a.kind, Adjust::Deref(..)))
285285
{
286286
return false;
287287
}

crates/hir-ty/src/diagnostics/match_check.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -113,13 +113,9 @@ impl<'a> PatCtxt<'a> {
113113
// Pattern adjustment is part of RFC 2005-match-ergonomics.
114114
// More info https://github.com/rust-lang/rust/issues/42640#issuecomment-313535089
115115
let unadjusted_pat = self.lower_pattern_unadjusted(pat);
116-
self.infer.pat_adjustments.get(&pat).map(|it| &**it).unwrap_or_default().iter().rev().fold(
117-
unadjusted_pat,
118-
|subpattern, ref_ty| Pat {
119-
ty: ref_ty.clone(),
120-
kind: Box::new(PatKind::Deref { subpattern }),
121-
},
122-
)
116+
self.infer.pat_adjustments(pat).iter().rev().fold(unadjusted_pat, |subpattern, ref_ty| {
117+
Pat { ty: ref_ty.target.clone(), kind: Box::new(PatKind::Deref { subpattern }) }
118+
})
123119
}
124120

125121
fn lower_pattern_unadjusted(&mut self, pat: PatId) -> Pat {

crates/hir-ty/src/infer.rs

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -460,9 +460,6 @@ pub struct InferenceResult {
460460
// `TyKind::Error`.
461461
// Which will then mark this field.
462462
pub(crate) has_errors: bool,
463-
/// Interned common types to return references to.
464-
/// Stores the types which were implicitly dereferenced in pattern binding modes.
465-
pub pat_adjustments: FxHashMap<PatId, Vec<Ty>>,
466463
/// Stores the binding mode (`ref` in `let ref x = 2`) of bindings.
467464
///
468465
/// This one is tied to the `PatId` instead of `BindingId`, because in some rare cases, a binding in an
@@ -477,7 +474,8 @@ pub struct InferenceResult {
477474
/// ```
478475
/// the first `rest` has implicit `ref` binding mode, but the second `rest` binding mode is `move`.
479476
pub binding_modes: ArenaMap<PatId, BindingMode>,
480-
pub expr_adjustments: FxHashMap<ExprId, Box<[Adjustment]>>,
477+
/// For patterns, this stores the types which were implicitly dereferenced in pattern binding modes.
478+
adjustments: FxHashMap<ExprOrPatId, Box<[Adjustment]>>,
481479
pub(crate) closure_info: FxHashMap<ClosureId, (Vec<CapturedItem>, FnTrait)>,
482480
// FIXME: remove this field
483481
pub mutated_bindings_in_closure: FxHashSet<BindingId>,
@@ -488,6 +486,12 @@ impl InferenceResult {
488486
pub fn method_resolution(&self, expr: ExprId) -> Option<(FunctionId, Substitution)> {
489487
self.method_resolutions.get(&expr).cloned()
490488
}
489+
pub fn pat_adjustments(&self, pat: PatId) -> &[Adjustment] {
490+
self.adjustments.get(&ExprOrPatId::PatId(pat)).map_or(&[], |v| v.as_ref())
491+
}
492+
pub fn expr_adjustments(&self, expr: ExprId) -> &[Adjustment] {
493+
self.adjustments.get(&ExprOrPatId::ExprId(expr)).map_or(&[], |v| v.as_ref())
494+
}
491495
pub fn field_resolution(&self, expr: ExprId) -> Option<Either<FieldId, TupleFieldId>> {
492496
self.field_resolutions.get(&expr).copied()
493497
}
@@ -751,9 +755,8 @@ impl<'db> InferenceContext<'db> {
751755
type_of_for_iterator,
752756
type_mismatches,
753757
has_errors,
754-
pat_adjustments,
755758
binding_modes: _,
756-
expr_adjustments,
759+
adjustments,
757760
// Types in `closure_info` have already been `resolve_completely()`'d during
758761
// `InferenceContext::infer_closures()` (in `HirPlace::ty()` specifically), so no need
759762
// to resolve them here.
@@ -769,7 +772,7 @@ impl<'db> InferenceContext<'db> {
769772
// Even though coercion casts provide type hints, we check casts after fallback for
770773
// backwards compatibility. This makes fallback a stronger type hint than a cast coercion.
771774
let mut apply_adjustments = |expr, adj: Vec<_>| {
772-
expr_adjustments.insert(expr, adj.into_boxed_slice());
775+
adjustments.insert(ExprOrPatId::ExprId(expr), adj.into_boxed_slice());
773776
};
774777
let mut set_coercion_cast = |expr| {
775778
coercion_casts.insert(expr);
@@ -868,16 +871,11 @@ impl<'db> InferenceContext<'db> {
868871
*has_errors || subst.type_parameters(Interner).any(|ty| ty.contains_unknown());
869872
}
870873
assoc_resolutions.shrink_to_fit();
871-
for adjustment in expr_adjustments.values_mut().flatten() {
874+
for adjustment in adjustments.values_mut().flatten() {
872875
adjustment.target = table.resolve_completely(adjustment.target.clone());
873876
*has_errors = *has_errors || adjustment.target.contains_unknown();
874877
}
875-
expr_adjustments.shrink_to_fit();
876-
for adjustment in pat_adjustments.values_mut().flatten() {
877-
*adjustment = table.resolve_completely(adjustment.clone());
878-
*has_errors = *has_errors || adjustment.contains_unknown();
879-
}
880-
pat_adjustments.shrink_to_fit();
878+
adjustments.shrink_to_fit();
881879
result.tuple_field_access_types = tuple_field_accesses_rev
882880
.into_iter()
883881
.enumerate()
@@ -1260,7 +1258,7 @@ impl<'db> InferenceContext<'db> {
12601258
if adjustments.is_empty() {
12611259
return;
12621260
}
1263-
match self.result.expr_adjustments.entry(expr) {
1261+
match self.result.adjustments.entry(expr.into()) {
12641262
std::collections::hash_map::Entry::Occupied(mut entry) => {
12651263
match (&mut entry.get_mut()[..], &adjustments[..]) {
12661264
(

crates/hir-ty/src/infer/closure.rs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -905,8 +905,7 @@ impl CapturedItemWithoutTy {
905905
impl InferenceContext<'_> {
906906
fn place_of_expr(&mut self, tgt_expr: ExprId) -> Option<HirPlace> {
907907
let r = self.place_of_expr_without_adjust(tgt_expr)?;
908-
let adjustments =
909-
self.result.expr_adjustments.get(&tgt_expr).map(|it| &**it).unwrap_or_default();
908+
let adjustments = self.result.expr_adjustments(tgt_expr);
910909
apply_adjusts_to_place(&mut self.current_capture_span_stack, r, adjustments)
911910
}
912911

@@ -1087,12 +1086,12 @@ impl InferenceContext<'_> {
10871086
}
10881087

10891088
fn walk_expr(&mut self, tgt_expr: ExprId) {
1090-
if let Some(it) = self.result.expr_adjustments.get_mut(&tgt_expr) {
1089+
if let Some(it) = self.result.adjustments.get_mut(&tgt_expr.into()) {
10911090
// FIXME: this take is completely unneeded, and just is here to make borrow checker
10921091
// happy. Remove it if you can.
10931092
let x_taken = mem::take(it);
10941093
self.walk_expr_with_adjust(tgt_expr, &x_taken);
1095-
*self.result.expr_adjustments.get_mut(&tgt_expr).unwrap() = x_taken;
1094+
*self.result.adjustments.get_mut(&tgt_expr.into()).unwrap() = x_taken;
10961095
} else {
10971096
self.walk_expr_without_adjust(tgt_expr);
10981097
}
@@ -1389,7 +1388,7 @@ impl InferenceContext<'_> {
13891388
},
13901389
},
13911390
}
1392-
if self.result.pat_adjustments.get(&p).is_some_and(|it| !it.is_empty()) {
1391+
if !self.result.pat_adjustments(p).is_empty() {
13931392
for_mut = BorrowKind::Mut { kind: MutBorrowKind::ClosureCapture };
13941393
}
13951394
self.body.walk_pats_shallow(p, |p| self.walk_pat_inner(p, update_result, for_mut));
@@ -1401,10 +1400,8 @@ impl InferenceContext<'_> {
14011400

14021401
fn expr_ty_after_adjustments(&self, e: ExprId) -> Ty {
14031402
let mut ty = None;
1404-
if let Some(it) = self.result.expr_adjustments.get(&e) {
1405-
if let Some(it) = it.last() {
1406-
ty = Some(it.target.clone());
1407-
}
1403+
if let Some(it) = self.result.expr_adjustments(e).last() {
1404+
ty = Some(it.target.clone());
14081405
}
14091406
ty.unwrap_or_else(|| self.expr_ty(e))
14101407
}
@@ -1515,8 +1512,7 @@ impl InferenceContext<'_> {
15151512
}
15161513

15171514
fn consume_with_pat(&mut self, mut place: HirPlace, tgt_pat: PatId) {
1518-
let adjustments_count =
1519-
self.result.pat_adjustments.get(&tgt_pat).map(|it| it.len()).unwrap_or_default();
1515+
let adjustments_count = self.result.pat_adjustments(tgt_pat).len();
15201516
place.projections.extend((0..adjustments_count).map(|_| ProjectionElem::Deref));
15211517
self.current_capture_span_stack
15221518
.extend((0..adjustments_count).map(|_| MirSpan::PatId(tgt_pat)));
@@ -1736,8 +1732,12 @@ impl InferenceContext<'_> {
17361732

17371733
for (derefed_callee, callee_ty, params, expr) in exprs {
17381734
if let &Expr::Call { callee, .. } = &self.body[expr] {
1739-
let mut adjustments =
1740-
self.result.expr_adjustments.remove(&callee).unwrap_or_default().into_vec();
1735+
let mut adjustments = self
1736+
.result
1737+
.adjustments
1738+
.remove(&callee.into())
1739+
.unwrap_or_default()
1740+
.into_vec();
17411741
self.write_fn_trait_method_resolution(
17421742
kind,
17431743
&derefed_callee,
@@ -1746,7 +1746,7 @@ impl InferenceContext<'_> {
17461746
&params,
17471747
expr,
17481748
);
1749-
self.result.expr_adjustments.insert(callee, adjustments.into_boxed_slice());
1749+
self.result.adjustments.insert(callee.into(), adjustments.into_boxed_slice());
17501750
}
17511751
}
17521752
}

crates/hir-ty/src/infer/coerce.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ impl CoerceMany {
160160
// - [Comment from rustc](https://github.com/rust-lang/rust/blob/5ff18d0eaefd1bd9ab8ec33dab2404a44e7631ed/compiler/rustc_hir_typeck/src/coercion.rs#L1334-L1335)
161161
// First try to coerce the new expression to the type of the previous ones,
162162
// but only if the new expression has no coercion already applied to it.
163-
if expr.is_none_or(|expr| !ctx.result.expr_adjustments.contains_key(&expr)) {
163+
if expr.is_none_or(|expr| !ctx.result.adjustments.contains_key(&expr.into())) {
164164
if let Ok(res) = ctx.coerce(expr, &expr_ty, &self.merged_ty(), CoerceNever::Yes) {
165165
self.final_ty = Some(res);
166166
if let Some(expr) = expr {

crates/hir-ty/src/infer/expr.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -250,12 +250,10 @@ impl InferenceContext<'_> {
250250
// While we don't allow *arbitrary* coercions here, we *do* allow
251251
// coercions from `!` to `expected`.
252252
if ty.is_never() {
253-
if let Some(adjustments) = self.result.expr_adjustments.get(&expr) {
254-
return if let [Adjustment { kind: Adjust::NeverToAny, target }] = &**adjustments {
255-
target.clone()
256-
} else {
257-
self.err_ty()
258-
};
253+
match self.result.expr_adjustments(expr) {
254+
[Adjustment { kind: Adjust::NeverToAny, target }] => return target.clone(),
255+
[_, ..] => return self.err_ty(),
256+
[] => (),
259257
}
260258

261259
if let Some(target) = expected.only_has_type(&mut self.table) {

crates/hir-ty/src/infer/mutability.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ impl InferenceContext<'_> {
2424
}
2525

2626
fn infer_mut_expr(&mut self, tgt_expr: ExprId, mut mutability: Mutability) {
27-
if let Some(adjustments) = self.result.expr_adjustments.get_mut(&tgt_expr) {
27+
if let Some(adjustments) = self.result.adjustments.get_mut(&tgt_expr.into()) {
2828
for adj in adjustments.iter_mut().rev() {
2929
match &mut adj.kind {
3030
Adjust::NeverToAny | Adjust::Deref(None) | Adjust::Pointer(_) => (),
@@ -138,8 +138,8 @@ impl InferenceContext<'_> {
138138
let mut base_ty = None;
139139
let base_adjustments = self
140140
.result
141-
.expr_adjustments
142-
.get_mut(&base)
141+
.adjustments
142+
.get_mut(&base.into())
143143
.and_then(|it| it.last_mut());
144144
if let Some(Adjustment {
145145
kind: Adjust::Borrow(AutoBorrow::Ref(_, mutability)),

crates/hir-ty/src/infer/pat.rs

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ use hir_expand::name::Name;
1111
use stdx::TupleExt;
1212

1313
use crate::{
14-
DeclContext, DeclOrigin, InferenceDiagnostic, Interner, Mutability, Scalar, Substitution, Ty,
15-
TyBuilder, TyExt, TyKind,
14+
Adjust, Adjustment, DeclContext, DeclOrigin, InferenceDiagnostic, Interner, Mutability, Scalar,
15+
Substitution, Ty, TyBuilder, TyExt, TyKind,
1616
consteval::{self, try_const_usize, usize_const},
1717
infer::{
1818
BindingMode, ERROR_TY, Expectation, InferenceContext, TypeMismatch, coerce::CoerceNever,
@@ -250,7 +250,8 @@ impl InferenceContext<'_> {
250250
} else if self.is_non_ref_pat(self.body, pat) {
251251
let mut pat_adjustments = Vec::new();
252252
while let Some((inner, _lifetime, mutability)) = expected.as_reference() {
253-
pat_adjustments.push(expected.clone());
253+
pat_adjustments
254+
.push(Adjustment { kind: Adjust::Deref(None), target: expected.clone() });
254255
expected = self.resolve_ty_shallow(inner);
255256
default_bm = match default_bm {
256257
BindingMode::Move => BindingMode::Ref(mutability),
@@ -260,8 +261,7 @@ impl InferenceContext<'_> {
260261
}
261262

262263
if !pat_adjustments.is_empty() {
263-
pat_adjustments.shrink_to_fit();
264-
self.result.pat_adjustments.insert(pat, pat_adjustments);
264+
self.result.adjustments.insert(pat.into(), pat_adjustments.into_boxed_slice());
265265
}
266266
}
267267

@@ -306,12 +306,13 @@ impl InferenceContext<'_> {
306306
match self.table.coerce(&expected, &ty_inserted_vars, CoerceNever::Yes) {
307307
Ok((adjustments, coerced_ty)) => {
308308
if !adjustments.is_empty() {
309-
self.result
310-
.pat_adjustments
311-
.entry(pat)
312-
.or_default()
313-
.extend(adjustments.into_iter().map(|adjust| adjust.target));
309+
let adjustments = match self.result.adjustments.remove(&pat.into()) {
310+
Some(prev) => prev.into_iter().chain(adjustments).collect(),
311+
None => adjustments.into_boxed_slice(),
312+
};
313+
self.result.adjustments.insert(pat.into(), adjustments);
314314
}
315+
315316
self.write_pat_ty(pat, coerced_ty);
316317
return self.pat_ty_after_adjustment(pat);
317318
}
@@ -419,10 +420,10 @@ impl InferenceContext<'_> {
419420

420421
fn pat_ty_after_adjustment(&self, pat: PatId) -> Ty {
421422
self.result
422-
.pat_adjustments
423-
.get(&pat)
423+
.adjustments
424+
.get(&pat.into())
424425
.and_then(|it| it.first())
425-
.unwrap_or(&self.result.type_of_pat[pat])
426+
.map_or_else(|| &self.result.type_of_pat[pat], |adj| &adj.target)
426427
.clone()
427428
}
428429

crates/hir-ty/src/mir/lower.rs

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ impl<'ctx> MirLowerCtx<'ctx> {
399399
place: Place,
400400
prev_block: BasicBlockId,
401401
) -> Result<Option<BasicBlockId>> {
402-
if let Some(adjustments) = self.infer.expr_adjustments.get(&expr_id) {
402+
if let adjustments @ [_, ..] = self.infer.expr_adjustments(expr_id) {
403403
return self.lower_expr_to_place_with_adjust(expr_id, place, prev_block, adjustments);
404404
}
405405
self.lower_expr_to_place_without_adjust(expr_id, place, prev_block)
@@ -1054,13 +1054,10 @@ impl<'ctx> MirLowerCtx<'ctx> {
10541054
}
10551055
if let hir_def::hir::BinaryOp::Assignment { op: Some(op) } = op {
10561056
// last adjustment is `&mut` which we don't want it.
1057-
let adjusts = self
1058-
.infer
1059-
.expr_adjustments
1060-
.get(lhs)
1061-
.and_then(|it| it.split_last())
1062-
.map(|it| it.1)
1063-
.ok_or(MirLowerError::TypeError("adjustment of binary op was missing"))?;
1057+
let adjusts =
1058+
self.infer.expr_adjustments(*lhs).split_last().map(|it| it.1).ok_or(
1059+
MirLowerError::TypeError("adjustment of binary op was missing"),
1060+
)?;
10641061
let Some((lhs_place, current)) =
10651062
self.lower_expr_as_place_with_adjust(current, *lhs, false, adjusts)?
10661063
else {
@@ -1597,10 +1594,8 @@ impl<'ctx> MirLowerCtx<'ctx> {
15971594

15981595
fn expr_ty_after_adjustments(&self, e: ExprId) -> Ty {
15991596
let mut ty = None;
1600-
if let Some(it) = self.infer.expr_adjustments.get(&e) {
1601-
if let Some(it) = it.last() {
1602-
ty = Some(it.target.clone());
1603-
}
1597+
if let Some(it) = self.infer.expr_adjustments(e).last() {
1598+
ty = Some(it.target.clone());
16041599
}
16051600
ty.unwrap_or_else(|| self.expr_ty_without_adjust(e))
16061601
}
@@ -1674,7 +1669,7 @@ impl<'ctx> MirLowerCtx<'ctx> {
16741669
}
16751670

16761671
fn has_adjustments(&self, expr_id: ExprId) -> bool {
1677-
!self.infer.expr_adjustments.get(&expr_id).map(|it| it.is_empty()).unwrap_or(true)
1672+
!self.infer.expr_adjustments(expr_id).is_empty()
16781673
}
16791674

16801675
fn merge_blocks(

crates/hir-ty/src/mir/lower/as_place.rs

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,11 @@ impl MirLowerCtx<'_> {
115115
expr_id: ExprId,
116116
upgrade_rvalue: bool,
117117
) -> Result<Option<(Place, BasicBlockId)>> {
118-
match self.infer.expr_adjustments.get(&expr_id) {
119-
Some(a) => self.lower_expr_as_place_with_adjust(current, expr_id, upgrade_rvalue, a),
120-
None => self.lower_expr_as_place_without_adjust(current, expr_id, upgrade_rvalue),
118+
match self.infer.expr_adjustments(expr_id) {
119+
a @ [_, ..] => {
120+
self.lower_expr_as_place_with_adjust(current, expr_id, upgrade_rvalue, a)
121+
}
122+
[] => self.lower_expr_as_place_without_adjust(current, expr_id, upgrade_rvalue),
121123
}
122124
}
123125

@@ -254,13 +256,8 @@ impl MirLowerCtx<'_> {
254256
index_fn,
255257
);
256258
}
257-
let adjusts = self
258-
.infer
259-
.expr_adjustments
260-
.get(base)
261-
.and_then(|it| it.split_last())
262-
.map(|it| it.1)
263-
.unwrap_or(&[]);
259+
let adjusts =
260+
self.infer.expr_adjustments(*base).split_last().map(|it| it.1).unwrap_or(&[]);
264261
let Some((mut p_base, current)) =
265262
self.lower_expr_as_place_with_adjust(current, *base, true, adjusts)?
266263
else {

crates/hir-ty/src/mir/lower/pattern_matching.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ impl MirLowerCtx<'_> {
119119
pattern: PatId,
120120
mode: MatchingMode,
121121
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
122-
let cnt = self.infer.pat_adjustments.get(&pattern).map(|x| x.len()).unwrap_or_default();
122+
let cnt = self.infer.pat_adjustments(pattern).len();
123123
cond_place.projection = self.result.projection_store.intern(
124124
cond_place
125125
.projection

0 commit comments

Comments
 (0)