diff --git a/pyrefly/lib/alt/expr.rs b/pyrefly/lib/alt/expr.rs index 796bdcaed..f8ed14345 100644 --- a/pyrefly/lib/alt/expr.rs +++ b/pyrefly/lib/alt/expr.rs @@ -30,6 +30,7 @@ use ruff_python_ast::DictItem; use ruff_python_ast::Expr; use ruff_python_ast::ExprCall; use ruff_python_ast::ExprGenerator; +use ruff_python_ast::ExprList; use ruff_python_ast::ExprNumberLiteral; use ruff_python_ast::ExprSlice; use ruff_python_ast::ExprStarred; @@ -418,29 +419,20 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { } Expr::Tuple(x) => self.tuple_infer(x, hint, errors), Expr::List(x) => { - let elt_hint = hint.and_then(|ty| self.decompose_list(ty)); - if x.is_empty() { - let elem_ty = elt_hint.map_or_else( - || { - if !self.solver().infer_with_first_use { - self.error( - errors, - x.range(), - ErrorInfo::Kind(ErrorKind::ImplicitAny), - "This expression is implicitly inferred to be `list[Any]`. Please provide an explicit type annotation.".to_owned(), - ); - Type::any_implicit() - } else { - self.solver().fresh_contained(self.uniques).to_type() - } - }, - |hint| hint.to_type(), - ); - self.stdlib.list(elem_ty).to_type() - } else { - let elem_tys = self.elts_infer(&x.elts, elt_hint, errors); - self.stdlib.list(self.unions(elem_tys)).to_type() + if let Some(hint_ref) = hint.as_ref() + && let Type::Union(options) = hint_ref.ty() + { + for option in options { + let branch_hint = + self.decompose_list(HintRef::new(option, hint_ref.errors())); + let ty = self.list_with_hint(x, branch_hint, errors); + if self.is_subset_eq(&ty, option) { + return ty; + } + } } + let elt_hint = hint.and_then(|ty| self.decompose_list(ty)); + self.list_with_hint(x, elt_hint, errors) } Expr::Dict(x) => self.dict_infer(&x.items, hint, x.range, errors), Expr::Set(x) => { @@ -1751,6 +1743,36 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { }) } + fn list_with_hint( + &self, + x: &ExprList, + elt_hint: Option, + errors: &ErrorCollector, + ) -> Type { + if x.is_empty() { + let elem_ty = elt_hint.map_or_else( + || { + if !self.solver().infer_with_first_use { + self.error( + errors, + x.range(), + ErrorInfo::Kind(ErrorKind::ImplicitAny), + "This expression is implicitly inferred to be `list[Any]`. Please provide an explicit type annotation.".to_owned(), + ); + Type::any_implicit() + } else { + self.solver().fresh_contained(self.uniques).to_type() + } + }, + |hint| hint.to_type(), + ); + self.stdlib.list(elem_ty).to_type() + } else { + let elem_tys = self.elts_infer(&x.elts, elt_hint, errors); + self.stdlib.list(self.unions(elem_tys)).to_type() + } + } + fn intercept_typing_self_use(&self, x: &Expr) -> Option { match x { Expr::Name(..) | Expr::Attribute(..) => { diff --git a/pyrefly/lib/alt/unwrap.rs b/pyrefly/lib/alt/unwrap.rs index effc103d8..6a7db495a 100644 --- a/pyrefly/lib/alt/unwrap.rs +++ b/pyrefly/lib/alt/unwrap.rs @@ -114,6 +114,54 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { } } + fn collect_var_from_hint(&self, ty: &Type, make: &F) -> Option> + where + F: Fn(Var) -> Type, + { + match ty { + Type::Union(tys) => { + let mut collected = Vec::new(); + let mut matched = false; + for branch in tys { + if let Some(mut branch_res) = self.collect_var_from_hint(branch, make) { + matched = true; + collected.append(&mut branch_res); + } + } + if matched { Some(collected) } else { None } + } + _ => { + let var = self.fresh_var(); + let target = make(var); + if self.is_subset_eq(&target, ty) { + match self.resolve_var_opt(ty, var) { + Some(resolved) => Some(vec![resolved]), + None => Some(Vec::new()), + } + } else { + None + } + } + } + } + + fn hint_from_types<'b>( + &self, + mut types: Vec, + hint: &HintRef<'b, '_>, + ) -> Option> { + if types.is_empty() { + None + } else { + let ty = if types.len() == 1 { + types.pop().unwrap() + } else { + self.unions(types) + }; + Some(hint.map_ty(|_| ty)) + } + } + pub fn unwrap_mapping(&self, ty: &Type) -> Option<(Type, Type)> { let key = self.fresh_var(); let value = self.fresh_var(); @@ -219,45 +267,65 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { &self, hint: HintRef<'b, '_>, ) -> (Option>, Option>) { - let key = self.fresh_var(); - let value = self.fresh_var(); - let dict_type = self.stdlib.dict(key.to_type(), value.to_type()).to_type(); - if self.is_subset_eq(&dict_type, hint.ty()) { - let key = hint.map_ty_opt(|ty| self.resolve_var_opt(ty, key)); - let value = hint.map_ty_opt(|ty| self.resolve_var_opt(ty, value)); - (key, value) - } else { - (None, None) + let mut key_types = Vec::new(); + let mut value_types = Vec::new(); + let mut matched = false; + + // Helper to process a single target type and accumulate results. + let mut consider = |ty: &Type| { + let key = self.fresh_var(); + let value = self.fresh_var(); + let dict_type = self.stdlib.dict(key.to_type(), value.to_type()).to_type(); + if self.is_subset_eq(&dict_type, ty) { + matched = true; + if let Some(key_ty) = self.resolve_var_opt(ty, key) { + key_types.push(key_ty); + } + if let Some(value_ty) = self.resolve_var_opt(ty, value) { + value_types.push(value_ty); + } + } + }; + + match hint.ty() { + Type::Union(branches) => { + for branch in branches { + consider(branch); + } + } + ty => consider(ty), } + + if !matched { + return (None, None); + } + + let key = self.hint_from_types(key_types, &hint); + let value = self.hint_from_types(value_types, &hint); + (key, value) } pub fn decompose_set<'b>(&self, hint: HintRef<'b, '_>) -> Option> { - let elem = self.fresh_var(); - let set_type = self.stdlib.set(elem.to_type()).to_type(); - if self.is_subset_eq(&set_type, hint.ty()) { - hint.map_ty_opt(|ty| self.resolve_var_opt(ty, elem)) - } else { - None + let make = |var: Var| self.stdlib.set(var.to_type()).to_type(); + match self.collect_var_from_hint(hint.ty(), &make) { + Some(tys) => self.hint_from_types(tys, &hint), + None => None, } } pub fn decompose_list<'b>(&self, hint: HintRef<'b, '_>) -> Option> { - let elem = self.fresh_var(); - let list_type = self.stdlib.list(elem.to_type()).to_type(); - if self.is_subset_eq(&list_type, hint.ty()) { - hint.map_ty_opt(|ty| self.resolve_var_opt(ty, elem)) - } else { - None + let make = |var: Var| self.stdlib.list(var.to_type()).to_type(); + match self.collect_var_from_hint(hint.ty(), &make) { + Some(tys) => self.hint_from_types(tys, &hint), + None => None, } } pub fn decompose_tuple<'b>(&self, hint: HintRef<'b, '_>) -> Option> { - let elem = self.fresh_var(); - let tuple_type = self.stdlib.tuple(elem.to_type()).to_type(); - if self.is_subset_eq(&tuple_type, hint.ty()) { - hint.map_ty_opt(|ty| self.resolve_var_opt(ty, elem)) - } else { - None + let make = |var: Var| self.stdlib.tuple(var.to_type()).to_type(); + match self.collect_var_from_hint(hint.ty(), &make) { + Some(tys) => self.hint_from_types(tys, &hint), + None => None, } } diff --git a/pyrefly/lib/test/contextual.rs b/pyrefly/lib/test/contextual.rs index 63f8683d2..3e1adc46c 100644 --- a/pyrefly/lib/test/contextual.rs +++ b/pyrefly/lib/test/contextual.rs @@ -102,7 +102,6 @@ kwarg(xs=[B()], ys=[B()]) ); testcase!( - bug = "Both assignments should be allowed. When decomposing the contextual hint, we eagerly resolve vars to the 'first' branch of the union. Note: due to the union's sorted representation, the first branch is not necessarily the first in source order.", test_contextual_typing_against_unions, r#" class A: ... @@ -110,7 +109,7 @@ class B: ... class B2(B): ... class C: ... -x: list[A] | list[B] = [B2()] # E: `list[B2]` is not assignable to `list[A] | list[B]` +x: list[A] | list[B] = [B2()] y: list[B] | list[C] = [B2()] "#, ); @@ -266,7 +265,6 @@ x2: list[A] = True and [B()] ); testcase!( - bug = "x or y or ... fails due to union hints, see test_contextual_typing_against_unions", test_context_boolop_soft, r#" from typing import TypedDict, assert_type @@ -280,7 +278,7 @@ def test(x: list[A] | None, y: list[C] | None, z: TD | None) -> None: assert_type(x or [B()], list[A]) assert_type(x or [0], list[A] | list[int]) assert_type(x or y or [B()], list[A] | list[C]) - assert_type(x or y or [D()], list[A] | list[C]) # TODO # E: assert_type(list[A] | list[C] | list[D], list[A] | list[C]) failed + assert_type(x or y or [D()], list[A] | list[C]) assert_type(z or {"x": 0}, TD) assert_type(z or {"x": ""}, TD | dict[str, str]) "#,