@@ -20,6 +20,7 @@ use ruff_python_ast::StringLiteralValue;
2020use ruff_python_ast:: name:: Name ;
2121use ruff_text_size:: Ranged ;
2222use ruff_text_size:: TextRange ;
23+ use vec1:: Vec1 ;
2324
2425use crate :: alt:: answers:: LookupAnswer ;
2526use crate :: alt:: answers_solver:: AnswersSolver ;
@@ -240,6 +241,102 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
240241 } )
241242 }
242243
244+ // Try to narrow a type based on the type of its facet.
245+ // For example, if we have a `x.y == 0` check and `x` is some union,
246+ // we can eliminate cases from the union where `x.y` is some other
247+ // literal.
248+ pub fn atomic_narrow_for_facet (
249+ & self ,
250+ base : & Type ,
251+ facet : & FacetKind ,
252+ op : & AtomicNarrowOp ,
253+ range : TextRange ,
254+ errors : & ErrorCollector ,
255+ ) -> Option < Type > {
256+ match op {
257+ AtomicNarrowOp :: Is ( v) => {
258+ let right = self . expr_infer ( v, errors) ;
259+ Some ( self . distribute_over_union ( base, |t| {
260+ let base_info = TypeInfo :: of_ty ( t. clone ( ) ) ;
261+ let facet_ty = self . get_facet_chain_type (
262+ & base_info,
263+ & FacetChain :: new ( Vec1 :: new ( facet. clone ( ) ) ) ,
264+ range,
265+ ) ;
266+ match right {
267+ Type :: None | Type :: Literal ( Lit :: Bool ( _) ) | Type :: Literal ( Lit :: Enum ( _) ) => {
268+ if self . is_subset_eq ( & right, & facet_ty) {
269+ t. clone ( )
270+ } else {
271+ Type :: never ( )
272+ }
273+ }
274+ _ => t. clone ( ) ,
275+ }
276+ } ) )
277+ }
278+ AtomicNarrowOp :: IsNot ( v) => {
279+ let right = self . expr_infer ( v, errors) ;
280+ Some ( self . distribute_over_union ( base, |t| {
281+ let base_info = TypeInfo :: of_ty ( t. clone ( ) ) ;
282+ let facet_ty = self . get_facet_chain_type (
283+ & base_info,
284+ & FacetChain :: new ( Vec1 :: new ( facet. clone ( ) ) ) ,
285+ range,
286+ ) ;
287+ match ( & facet_ty, & right) {
288+ (
289+ Type :: None | Type :: Literal ( Lit :: Bool ( _) ) | Type :: Literal ( Lit :: Enum ( _) ) ,
290+ Type :: None | Type :: Literal ( Lit :: Bool ( _) ) | Type :: Literal ( Lit :: Enum ( _) ) ,
291+ ) if right == facet_ty => Type :: never ( ) ,
292+ _ => t. clone ( ) ,
293+ }
294+ } ) )
295+ }
296+ AtomicNarrowOp :: Eq ( v) => {
297+ let right = self . expr_infer ( v, errors) ;
298+ Some ( self . distribute_over_union ( base, |t| {
299+ let base_info = TypeInfo :: of_ty ( t. clone ( ) ) ;
300+ let facet_ty = self . get_facet_chain_type (
301+ & base_info,
302+ & FacetChain :: new ( Vec1 :: new ( facet. clone ( ) ) ) ,
303+ range,
304+ ) ;
305+ match right {
306+ Type :: None | Type :: Literal ( _) => {
307+ if self . is_subset_eq ( & right, & facet_ty) {
308+ t. clone ( )
309+ } else {
310+ Type :: never ( )
311+ }
312+ }
313+ _ => t. clone ( ) ,
314+ }
315+ } ) )
316+ }
317+ AtomicNarrowOp :: NotEq ( v) => {
318+ let right = self . expr_infer ( v, errors) ;
319+ Some ( self . distribute_over_union ( base, |t| {
320+ let base_info = TypeInfo :: of_ty ( t. clone ( ) ) ;
321+ let facet_ty = self . get_facet_chain_type (
322+ & base_info,
323+ & FacetChain :: new ( Vec1 :: new ( facet. clone ( ) ) ) ,
324+ range,
325+ ) ;
326+ match ( & facet_ty, & right) {
327+ ( Type :: None | Type :: Literal ( _) , Type :: None | Type :: Literal ( _) )
328+ if right == facet_ty =>
329+ {
330+ Type :: never ( )
331+ }
332+ _ => t. clone ( ) ,
333+ }
334+ } ) )
335+ }
336+ _ => None ,
337+ }
338+ }
339+
243340 pub fn atomic_narrow (
244341 & self ,
245342 ty : & Type ,
@@ -709,7 +806,33 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
709806 range,
710807 errors,
711808 ) ;
712- type_info. with_narrow ( facet_chain. facets ( ) , ty)
809+ let mut narrowed = type_info. with_narrow ( facet_chain. facets ( ) , ty) ;
810+ // For certain types of narrows, we can also narrow the parent of the current subject
811+ if let Some ( ( last, prefix) ) = facet_chain. facets ( ) . split_last ( ) {
812+ match Vec1 :: try_from ( prefix) {
813+ Ok ( prefix_facets) => {
814+ let prefix_chain = FacetChain :: new ( prefix_facets) ;
815+ let base_ty =
816+ self . get_facet_chain_type ( type_info, & prefix_chain, range) ;
817+ if let Some ( narrowed_ty) =
818+ self . atomic_narrow_for_facet ( & base_ty, last, op, range, errors)
819+ && narrowed_ty != base_ty
820+ {
821+ narrowed = narrowed. with_narrow ( prefix_chain. facets ( ) , narrowed_ty) ;
822+ }
823+ }
824+ _ => {
825+ let base_ty = type_info. ty ( ) ;
826+ if let Some ( narrowed_ty) =
827+ self . atomic_narrow_for_facet ( base_ty, last, op, range, errors)
828+ && narrowed_ty != * base_ty
829+ {
830+ narrowed = narrowed. clone ( ) . with_ty ( narrowed_ty) ;
831+ }
832+ }
833+ } ;
834+ }
835+ narrowed
713836 }
714837 NarrowOp :: And ( ops) => {
715838 let mut ops_iter = ops. iter ( ) ;
0 commit comments