Skip to content

Commit 235eb7a

Browse files
yangdanny97facebook-github-bot
authored andcommitted
support discriminated unions in narrowing
Summary: When we have a check like `x.y == 1` or `x["key"] = "value"` or `x[0] is None`, we previously only narrowed the appropriate facet for `x`. This diff makes us also narrow the type of `x`. This is useful for modeling a form of tagged unions in user code, where we have a union of classes/typeddicts/tuples that's differentiated by a facet that's typed as some literal. This is a feature that's supported for Pyright, and should be useful for Pydantic users as well. see: #650 closes #418 Reviewed By: stroxler Differential Revision: D80464317 fbshipit-source-id: ec583bbb910c53fecff4384b292cf8adc9ea93a9
1 parent dccec96 commit 235eb7a

File tree

2 files changed

+198
-1
lines changed

2 files changed

+198
-1
lines changed

pyrefly/lib/alt/narrow.rs

Lines changed: 124 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use ruff_python_ast::StringLiteralValue;
2020
use ruff_python_ast::name::Name;
2121
use ruff_text_size::Ranged;
2222
use ruff_text_size::TextRange;
23+
use vec1::Vec1;
2324

2425
use crate::alt::answers::LookupAnswer;
2526
use crate::alt::answers_solver::AnswersSolver;
@@ -240,6 +241,102 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
240241
})
241242
}
242243

244+
// Try to narrow a type based on the type of its facet.
245+
// For example, if we have a `x.y == 0` check and `x` is some union,
246+
// we can eliminate cases from the union where `x.y` is some other
247+
// literal.
248+
pub fn atomic_narrow_for_facet(
249+
&self,
250+
base: &Type,
251+
facet: &FacetKind,
252+
op: &AtomicNarrowOp,
253+
range: TextRange,
254+
errors: &ErrorCollector,
255+
) -> Option<Type> {
256+
match op {
257+
AtomicNarrowOp::Is(v) => {
258+
let right = self.expr_infer(v, errors);
259+
Some(self.distribute_over_union(base, |t| {
260+
let base_info = TypeInfo::of_ty(t.clone());
261+
let facet_ty = self.get_facet_chain_type(
262+
&base_info,
263+
&FacetChain::new(Vec1::new(facet.clone())),
264+
range,
265+
);
266+
match right {
267+
Type::None | Type::Literal(Lit::Bool(_)) | Type::Literal(Lit::Enum(_)) => {
268+
if self.is_subset_eq(&right, &facet_ty) {
269+
t.clone()
270+
} else {
271+
Type::never()
272+
}
273+
}
274+
_ => t.clone(),
275+
}
276+
}))
277+
}
278+
AtomicNarrowOp::IsNot(v) => {
279+
let right = self.expr_infer(v, errors);
280+
Some(self.distribute_over_union(base, |t| {
281+
let base_info = TypeInfo::of_ty(t.clone());
282+
let facet_ty = self.get_facet_chain_type(
283+
&base_info,
284+
&FacetChain::new(Vec1::new(facet.clone())),
285+
range,
286+
);
287+
match (&facet_ty, &right) {
288+
(
289+
Type::None | Type::Literal(Lit::Bool(_)) | Type::Literal(Lit::Enum(_)),
290+
Type::None | Type::Literal(Lit::Bool(_)) | Type::Literal(Lit::Enum(_)),
291+
) if right == facet_ty => Type::never(),
292+
_ => t.clone(),
293+
}
294+
}))
295+
}
296+
AtomicNarrowOp::Eq(v) => {
297+
let right = self.expr_infer(v, errors);
298+
Some(self.distribute_over_union(base, |t| {
299+
let base_info = TypeInfo::of_ty(t.clone());
300+
let facet_ty = self.get_facet_chain_type(
301+
&base_info,
302+
&FacetChain::new(Vec1::new(facet.clone())),
303+
range,
304+
);
305+
match right {
306+
Type::None | Type::Literal(_) => {
307+
if self.is_subset_eq(&right, &facet_ty) {
308+
t.clone()
309+
} else {
310+
Type::never()
311+
}
312+
}
313+
_ => t.clone(),
314+
}
315+
}))
316+
}
317+
AtomicNarrowOp::NotEq(v) => {
318+
let right = self.expr_infer(v, errors);
319+
Some(self.distribute_over_union(base, |t| {
320+
let base_info = TypeInfo::of_ty(t.clone());
321+
let facet_ty = self.get_facet_chain_type(
322+
&base_info,
323+
&FacetChain::new(Vec1::new(facet.clone())),
324+
range,
325+
);
326+
match (&facet_ty, &right) {
327+
(Type::None | Type::Literal(_), Type::None | Type::Literal(_))
328+
if right == facet_ty =>
329+
{
330+
Type::never()
331+
}
332+
_ => t.clone(),
333+
}
334+
}))
335+
}
336+
_ => None,
337+
}
338+
}
339+
243340
pub fn atomic_narrow(
244341
&self,
245342
ty: &Type,
@@ -709,7 +806,33 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
709806
range,
710807
errors,
711808
);
712-
type_info.with_narrow(facet_chain.facets(), ty)
809+
let mut narrowed = type_info.with_narrow(facet_chain.facets(), ty);
810+
// For certain types of narrows, we can also narrow the parent of the current subject
811+
if let Some((last, prefix)) = facet_chain.facets().split_last() {
812+
match Vec1::try_from(prefix) {
813+
Ok(prefix_facets) => {
814+
let prefix_chain = FacetChain::new(prefix_facets);
815+
let base_ty =
816+
self.get_facet_chain_type(type_info, &prefix_chain, range);
817+
if let Some(narrowed_ty) =
818+
self.atomic_narrow_for_facet(&base_ty, last, op, range, errors)
819+
&& narrowed_ty != base_ty
820+
{
821+
narrowed = narrowed.with_narrow(prefix_chain.facets(), narrowed_ty);
822+
}
823+
}
824+
_ => {
825+
let base_ty = type_info.ty();
826+
if let Some(narrowed_ty) =
827+
self.atomic_narrow_for_facet(base_ty, last, op, range, errors)
828+
&& narrowed_ty != *base_ty
829+
{
830+
narrowed = narrowed.clone().with_ty(narrowed_ty);
831+
}
832+
}
833+
};
834+
}
835+
narrowed
713836
}
714837
NarrowOp::And(ops) => {
715838
let mut ops_iter = ops.iter();

pyrefly/lib/test/narrow.rs

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1643,3 +1643,77 @@ class Test:
16431643
pass
16441644
"#,
16451645
);
1646+
1647+
testcase!(
1648+
test_discriminated_union_key,
1649+
r#"
1650+
from typing import TypedDict, assert_type, Literal
1651+
1652+
class UserDict(TypedDict):
1653+
kind: Literal["user"]
1654+
is_admin: Literal[False]
1655+
1656+
class AdminDict(TypedDict):
1657+
kind: Literal["admin"]
1658+
is_admin: Literal[True]
1659+
1660+
def test(x: UserDict | AdminDict):
1661+
if x["kind"] == "user":
1662+
assert_type(x, UserDict)
1663+
else:
1664+
assert_type(x, AdminDict)
1665+
if x["is_admin"] is True:
1666+
assert_type(x, AdminDict)
1667+
else:
1668+
assert_type(x, UserDict)
1669+
"#,
1670+
);
1671+
1672+
testcase!(
1673+
test_discriminated_union_attr,
1674+
r#"
1675+
from typing import assert_type, Literal
1676+
1677+
class User:
1678+
kind: Literal["user"]
1679+
is_admin: Literal[False]
1680+
1681+
class Admin:
1682+
kind: Literal["admin"]
1683+
is_admin: Literal[True]
1684+
1685+
def test(x: User | Admin):
1686+
if x.kind == "user":
1687+
assert_type(x, User)
1688+
elif x.kind == "admin":
1689+
assert_type(x, Admin)
1690+
1691+
if x.is_admin is True:
1692+
assert_type(x, Admin)
1693+
else:
1694+
assert_type(x, User)
1695+
"#,
1696+
);
1697+
1698+
testcase!(
1699+
test_discriminated_union_index,
1700+
r#"
1701+
from typing import assert_type, Literal
1702+
1703+
def test(x: tuple[Literal["user"], Literal[None]] | tuple[Literal["admin"], int]):
1704+
if x[0] == "user":
1705+
assert_type(x, tuple[Literal["user"], Literal[None]])
1706+
else:
1707+
assert_type(x, tuple[Literal["admin"], int])
1708+
1709+
if x[1] is None:
1710+
assert_type(x, tuple[Literal["user"], Literal[None]])
1711+
else:
1712+
assert_type(x, tuple[Literal["admin"], int])
1713+
1714+
if x[1] is not None:
1715+
assert_type(x, tuple[Literal["admin"], int])
1716+
else:
1717+
assert_type(x, tuple[Literal["user"], Literal[None]])
1718+
"#,
1719+
);

0 commit comments

Comments
 (0)