Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 44 additions & 22 deletions pyrefly/lib/alt/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use ruff_python_ast::DictItem;
use ruff_python_ast::Expr;
use ruff_python_ast::ExprCall;
use ruff_python_ast::ExprGenerator;
use ruff_python_ast::ExprList;
use ruff_python_ast::ExprNumberLiteral;
use ruff_python_ast::ExprSlice;
use ruff_python_ast::ExprStarred;
Expand Down Expand Up @@ -418,29 +419,20 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
}
Expr::Tuple(x) => self.tuple_infer(x, hint, errors),
Expr::List(x) => {
let elt_hint = hint.and_then(|ty| self.decompose_list(ty));
if x.is_empty() {
let elem_ty = elt_hint.map_or_else(
|| {
if !self.solver().infer_with_first_use {
self.error(
errors,
x.range(),
ErrorInfo::Kind(ErrorKind::ImplicitAny),
"This expression is implicitly inferred to be `list[Any]`. Please provide an explicit type annotation.".to_owned(),
);
Type::any_implicit()
} else {
self.solver().fresh_contained(self.uniques).to_type()
}
},
|hint| hint.to_type(),
);
self.stdlib.list(elem_ty).to_type()
} else {
let elem_tys = self.elts_infer(&x.elts, elt_hint, errors);
self.stdlib.list(self.unions(elem_tys)).to_type()
if let Some(hint_ref) = hint.as_ref()
&& let Type::Union(options) = hint_ref.ty()
{
for option in options {
let branch_hint =
self.decompose_list(HintRef::new(option, hint_ref.errors()));
let ty = self.list_with_hint(x, branch_hint, errors);
if self.is_subset_eq(&ty, option) {
return ty;
}
}
}
let elt_hint = hint.and_then(|ty| self.decompose_list(ty));
self.list_with_hint(x, elt_hint, errors)
}
Expr::Dict(x) => self.dict_infer(&x.items, hint, x.range, errors),
Expr::Set(x) => {
Expand Down Expand Up @@ -1751,6 +1743,36 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
})
}

fn list_with_hint(
&self,
x: &ExprList,
elt_hint: Option<Hint>,
errors: &ErrorCollector,
) -> Type {
if x.is_empty() {
let elem_ty = elt_hint.map_or_else(
|| {
if !self.solver().infer_with_first_use {
self.error(
errors,
x.range(),
ErrorInfo::Kind(ErrorKind::ImplicitAny),
"This expression is implicitly inferred to be `list[Any]`. Please provide an explicit type annotation.".to_owned(),
);
Type::any_implicit()
} else {
self.solver().fresh_contained(self.uniques).to_type()
}
},
|hint| hint.to_type(),
);
self.stdlib.list(elem_ty).to_type()
} else {
let elem_tys = self.elts_infer(&x.elts, elt_hint, errors);
self.stdlib.list(self.unions(elem_tys)).to_type()
}
}

fn intercept_typing_self_use(&self, x: &Expr) -> Option<TypeInfo> {
match x {
Expr::Name(..) | Expr::Attribute(..) => {
Expand Down
122 changes: 95 additions & 27 deletions pyrefly/lib/alt/unwrap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,54 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
}
}

fn collect_var_from_hint<F>(&self, ty: &Type, make: &F) -> Option<Vec<Type>>
where
F: Fn(Var) -> Type,
{
match ty {
Type::Union(tys) => {
let mut collected = Vec::new();
let mut matched = false;
for branch in tys {
if let Some(mut branch_res) = self.collect_var_from_hint(branch, make) {
matched = true;
collected.append(&mut branch_res);
}
}
if matched { Some(collected) } else { None }
}
_ => {
let var = self.fresh_var();
let target = make(var);
if self.is_subset_eq(&target, ty) {
match self.resolve_var_opt(ty, var) {
Some(resolved) => Some(vec![resolved]),
None => Some(Vec::new()),
}
} else {
None
}
}
}
}

fn hint_from_types<'b>(
&self,
mut types: Vec<Type>,
hint: &HintRef<'b, '_>,
) -> Option<Hint<'b>> {
if types.is_empty() {
None
} else {
let ty = if types.len() == 1 {
types.pop().unwrap()
} else {
self.unions(types)
};
Some(hint.map_ty(|_| ty))
}
}

pub fn unwrap_mapping(&self, ty: &Type) -> Option<(Type, Type)> {
let key = self.fresh_var();
let value = self.fresh_var();
Expand Down Expand Up @@ -219,45 +267,65 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
&self,
hint: HintRef<'b, '_>,
) -> (Option<Hint<'b>>, Option<Hint<'b>>) {
let key = self.fresh_var();
let value = self.fresh_var();
let dict_type = self.stdlib.dict(key.to_type(), value.to_type()).to_type();
if self.is_subset_eq(&dict_type, hint.ty()) {
let key = hint.map_ty_opt(|ty| self.resolve_var_opt(ty, key));
let value = hint.map_ty_opt(|ty| self.resolve_var_opt(ty, value));
(key, value)
} else {
(None, None)
let mut key_types = Vec::new();
let mut value_types = Vec::new();
let mut matched = false;

// Helper to process a single target type and accumulate results.
let mut consider = |ty: &Type| {
let key = self.fresh_var();
let value = self.fresh_var();
let dict_type = self.stdlib.dict(key.to_type(), value.to_type()).to_type();
if self.is_subset_eq(&dict_type, ty) {
matched = true;
if let Some(key_ty) = self.resolve_var_opt(ty, key) {
key_types.push(key_ty);
}
if let Some(value_ty) = self.resolve_var_opt(ty, value) {
value_types.push(value_ty);
}
}
};

match hint.ty() {
Type::Union(branches) => {
for branch in branches {
consider(branch);
}
}
ty => consider(ty),
}

if !matched {
return (None, None);
}

let key = self.hint_from_types(key_types, &hint);
let value = self.hint_from_types(value_types, &hint);
(key, value)
}

pub fn decompose_set<'b>(&self, hint: HintRef<'b, '_>) -> Option<Hint<'b>> {
let elem = self.fresh_var();
let set_type = self.stdlib.set(elem.to_type()).to_type();
if self.is_subset_eq(&set_type, hint.ty()) {
hint.map_ty_opt(|ty| self.resolve_var_opt(ty, elem))
} else {
None
let make = |var: Var| self.stdlib.set(var.to_type()).to_type();
match self.collect_var_from_hint(hint.ty(), &make) {
Some(tys) => self.hint_from_types(tys, &hint),
None => None,
}
}

pub fn decompose_list<'b>(&self, hint: HintRef<'b, '_>) -> Option<Hint<'b>> {
let elem = self.fresh_var();
let list_type = self.stdlib.list(elem.to_type()).to_type();
if self.is_subset_eq(&list_type, hint.ty()) {
hint.map_ty_opt(|ty| self.resolve_var_opt(ty, elem))
} else {
None
let make = |var: Var| self.stdlib.list(var.to_type()).to_type();
match self.collect_var_from_hint(hint.ty(), &make) {
Some(tys) => self.hint_from_types(tys, &hint),
None => None,
}
}

pub fn decompose_tuple<'b>(&self, hint: HintRef<'b, '_>) -> Option<Hint<'b>> {
let elem = self.fresh_var();
let tuple_type = self.stdlib.tuple(elem.to_type()).to_type();
if self.is_subset_eq(&tuple_type, hint.ty()) {
hint.map_ty_opt(|ty| self.resolve_var_opt(ty, elem))
} else {
None
let make = |var: Var| self.stdlib.tuple(var.to_type()).to_type();
match self.collect_var_from_hint(hint.ty(), &make) {
Some(tys) => self.hint_from_types(tys, &hint),
None => None,
}
}

Expand Down
6 changes: 2 additions & 4 deletions pyrefly/lib/test/contextual.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,15 +102,14 @@ kwarg(xs=[B()], ys=[B()])
);

testcase!(
bug = "Both assignments should be allowed. When decomposing the contextual hint, we eagerly resolve vars to the 'first' branch of the union. Note: due to the union's sorted representation, the first branch is not necessarily the first in source order.",
test_contextual_typing_against_unions,
r#"
class A: ...
class B: ...
class B2(B): ...
class C: ...

x: list[A] | list[B] = [B2()] # E: `list[B2]` is not assignable to `list[A] | list[B]`
x: list[A] | list[B] = [B2()]
y: list[B] | list[C] = [B2()]
"#,
);
Expand Down Expand Up @@ -266,7 +265,6 @@ x2: list[A] = True and [B()]
);

testcase!(
bug = "x or y or ... fails due to union hints, see test_contextual_typing_against_unions",
test_context_boolop_soft,
r#"
from typing import TypedDict, assert_type
Expand All @@ -280,7 +278,7 @@ def test(x: list[A] | None, y: list[C] | None, z: TD | None) -> None:
assert_type(x or [B()], list[A])
assert_type(x or [0], list[A] | list[int])
assert_type(x or y or [B()], list[A] | list[C])
assert_type(x or y or [D()], list[A] | list[C]) # TODO # E: assert_type(list[A] | list[C] | list[D], list[A] | list[C]) failed
assert_type(x or y or [D()], list[A] | list[C])
assert_type(z or {"x": 0}, TD)
assert_type(z or {"x": ""}, TD | dict[str, str])
"#,
Expand Down