Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pyrefly/lib/alt/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1576,6 +1576,12 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
Type::Literal(lit) => acc.push(AttributeBase1::ClassInstance(
lit.general_class_type(self.stdlib).clone(),
)),
Type::Type(box Type::Literal(lit)) => acc.push(AttributeBase1::ClassObject(
ClassBase::ClassType(lit.general_class_type(self.stdlib).clone()),
)),
Type::Type(box Type::LiteralString) => acc.push(AttributeBase1::ClassObject(
ClassBase::ClassType(self.stdlib.str().clone()),
)),
Type::TypeGuard(_) | Type::TypeIs(_) => {
acc.push(AttributeBase1::ClassInstance(self.stdlib.bool().clone()))
}
Expand Down
35 changes: 30 additions & 5 deletions pyrefly/lib/alt/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ use crate::types::class::ClassType;
use crate::types::keywords::KwCall;
use crate::types::keywords::TypeMap;
use crate::types::literal::Lit;
use crate::types::type_var::PreInferenceVariance;
use crate::types::type_var::Restriction;
use crate::types::typed_dict::TypedDict;
use crate::types::types::AnyStyle;
Expand Down Expand Up @@ -444,6 +445,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
);
self.solver()
.finish_class_targs(&mut ctor_targs, self.uniques);
self.promote_invariant_targs(&mut ctor_targs);
ret.subst_mut(&ctor_targs.substitution_map());
Some(ret)
}
Expand Down Expand Up @@ -512,9 +514,13 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
// According to the spec, the actual type (as opposed to the class under construction)
// should take priority. However, if the actual type comes from a type error or an implicit
// Any, using the class under construction is still more useful.
self.solver()
.finish_class_targs(cls.targs_mut(), self.uniques);
return ret.subst(&cls.targs().substitution_map());
{
let targs = cls.targs_mut();
self.solver().finish_class_targs(targs, self.uniques);
self.promote_invariant_targs(targs);
}
let substitution = cls.targs().substitution_map();
return ret.subst(&substitution);
}
(true, has_errors)
} else {
Expand Down Expand Up @@ -549,8 +555,11 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
// Not quite an overload, but close enough
self.record_overload_trace_from_type(range, init_method);
}
self.solver()
.finish_class_targs(cls.targs_mut(), self.uniques);
{
let targs = cls.targs_mut();
self.solver().finish_class_targs(targs, self.uniques);
self.promote_invariant_targs(targs);
}
if let Some(mut ret) = dunder_new_ret {
ret.subst_mut(&cls.targs().substitution_map());
ret
Expand All @@ -559,6 +568,14 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
}
}

fn promote_invariant_targs(&self, targs: &mut TArgs) {
targs.iter_paired_mut().for_each(|(param, targ)| {
if !matches!(param.variance, PreInferenceVariance::PCovariant) {
*targ = targ.clone().promote_literals(self.stdlib);
}
});
}

fn construct_typed_dict(
&self,
mut typed_dict: TypedDict,
Expand Down Expand Up @@ -597,6 +614,10 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
);
self.solver()
.finish_class_targs(typed_dict.targs_mut(), self.uniques);
typed_dict.targs_mut().as_mut().iter_mut().for_each(|targ| {
let promoted = targ.clone().promote_literals(self.stdlib);
*targ = promoted;
});
Type::TypedDict(typed_dict)
}

Expand Down Expand Up @@ -796,6 +817,10 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
style.propagate()
}
};
let res = match res {
Type::Union(members) => self.unions(members),
other => other,
};
if let Some(func_metadata) = kw_metadata {
let mut kws = TypeMap::new();
for kw in keywords {
Expand Down
47 changes: 38 additions & 9 deletions pyrefly/lib/solver/solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1196,7 +1196,22 @@ impl<'a, Ans: LookupAnswer> Subset<'a, Ans> {
let t1 = t1.clone();
drop(v1_ref);
drop(variables);
self.is_subset_eq(&t1, t2)
match self.is_subset_eq(&t1, t2) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This section looks unsafe to me, I don't see how it could be okay to change an Answer after it's already been decided.

Do you remember why this was needed? It may be a hint that propagating literals isn't going to work well at least at this time.

Ok(()) => Ok(()),
Err(err) => {
let t1_promoted =
t1.clone().promote_literals(self.type_order.stdlib());
if t1_promoted != t1 {
self.solver
.variables
.lock()
.update(*v1, Variable::Answer(t1_promoted.clone()));
self.is_subset_eq(&t1_promoted, t2)
} else {
Err(err)
}
}
}
}
Variable::Quantified(q) => {
let name = q.name.clone();
Expand Down Expand Up @@ -1238,24 +1253,38 @@ impl<'a, Ans: LookupAnswer> Subset<'a, Ans> {
let t2 = t2.clone();
drop(v2_ref);
drop(variables);
self.is_subset_eq(t1, &t2)
match self.is_subset_eq(t1, &t2) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(same thing here - I don't think we would ever want to change an answer after it has been decided already)

Ok(()) => Ok(()),
Err(err) => {
let t2_promoted =
t2.clone().promote_literals(self.type_order.stdlib());
if t2_promoted != t2 {
self.solver
.variables
.lock()
.update(*v2, Variable::Answer(t2_promoted.clone()));
self.is_subset_eq(t1, &t2_promoted)
} else {
Err(err)
}
}
}
}
Variable::Quantified(q) => {
let t1_p = t1.clone().promote_literals(self.type_order.stdlib());
let name = q.name.clone();
let bound = q.restriction().as_type(self.type_order.stdlib());
drop(v2_ref);
variables.update(*v2, Variable::Answer(t1_p.clone()));
variables.update(*v2, Variable::Answer(t1.clone()));
drop(variables);
if let Err(err_p) = self.is_subset_eq(&t1_p, &bound) {
// If the promoted type fails, try again with the original type, in case the bound itself is literal.
// This could be more optimized, but errors are rare, so this code path should not be hot.
if self.is_subset_eq(t1, &bound).is_err() {
// Fall back to the promoted type if the literal version violates the bound.
self.solver
.variables
.lock()
.update(*v2, Variable::Answer(t1.clone()));
if self.is_subset_eq(t1, &bound).is_err() {
// If the original type is also an error, use the promoted type.
.update(*v2, Variable::Answer(t1_p.clone()));
if let Err(err_p) = self.is_subset_eq(&t1_p, &bound) {
// If the promoted type also violates the bound, record the error.
self.solver
.variables
.lock()
Expand Down
8 changes: 4 additions & 4 deletions pyrefly/lib/test/attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1365,11 +1365,11 @@ assert_type(A.f(A[int]()), int)
testcase!(
test_access_generic_method_using_class_param_on_class,
r#"
from typing import assert_type, reveal_type, Any
from typing import Literal, assert_type, reveal_type, Any
class A[T]:
def f[S](self, x: S) -> tuple[S, T]: ...
reveal_type(A.f) # E: revealed type: [T, S](self: A[T], x: S) -> tuple[S, T]
assert_type(A.f(A[int](), ""), tuple[str, int]) # E: assert_type(tuple[str, Any], tuple[str, int])
assert_type(A.f(A[int](), ""), tuple[str, int]) # E: assert_type(tuple[Literal[''], Any], tuple[str, int])
"#,
);

Expand All @@ -1391,7 +1391,7 @@ assert_type(A.f(A[int]()), int)
testcase!(
test_access_overloaded_staticmethod_using_class_param_on_class,
r#"
from typing import assert_type, reveal_type, overload, Any
from typing import Literal, assert_type, reveal_type, overload, Any
class A[T]:
@overload
@staticmethod
Expand All @@ -1403,7 +1403,7 @@ class A[T]:
def f(x = None) -> Any: ...
reveal_type(A.f) # E: revealed type: Overload[(x: None = ...) -> None, [T](x: T) -> T]
assert_type(A.f(), None)
assert_type(A.f(0), int)
assert_type(A.f(0), Literal[0])
"#,
);

Expand Down
40 changes: 20 additions & 20 deletions pyrefly/lib/test/callable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -456,19 +456,19 @@ test(1, 2, 3, *[4]) # OK
testcase!(
test_splat_unpacked_args,
r#"
from typing import assert_type
from typing import Literal, assert_type

def test1(*args: *tuple[int, int, int]): ...
test1(*(1, 2, 3)) # OK
test1(*(1, 2)) # E: Unpacked argument `tuple[Literal[1], Literal[2]]` is not assignable to parameter `*args` with type `tuple[int, int, int]` in function `test1`
test1(*(1, 2, 3, 4)) # E: Unpacked argument `tuple[Literal[1], Literal[2], Literal[3], Literal[4]]` is not assignable to parameter `*args` with type `tuple[int, int, int]` in function `test1`

def test2[*T](*args: *tuple[int, *T, int]) -> tuple[*T]: ...
assert_type(test2(*(1, 2, 3)), tuple[int])
assert_type(test2(*(1, 2, 3)), tuple[Literal[2]])
assert_type(test2(*(1, 2)), tuple[()])
assert_type(test2(*(1, 2, 3, 4)), tuple[int, int])
assert_type(test2(1, 2, *(3, 4), 5), tuple[int, int, int])
assert_type(test2(1, *(2, 3), *("4", 5)), tuple[int, int, str])
assert_type(test2(*(1, 2, 3, 4)), tuple[Literal[2], Literal[3]])
assert_type(test2(1, 2, *(3, 4), 5), tuple[Literal[2], Literal[3], Literal[4]])
assert_type(test2(1, *(2, 3), *("4", 5)), tuple[Literal[2], Literal[3], Literal['4']])
assert_type(test2(1, *[2, 3], 4), tuple[int, ...])
test2(1, *(2, 3), *(4, "5")) # E: Unpacked argument `tuple[Literal[1], Literal[2], Literal[3], Literal[4], Literal['5']]` is not assignable to parameter `*args` with type `tuple[int, *@_, int]` in function `test2`
"#,
Expand Down Expand Up @@ -933,13 +933,13 @@ c3: Callable[[C], C] = f # OK
testcase!(
test_return_generic_callable,
r#"
from typing import assert_type, Callable
from typing import Literal, assert_type, Callable
def f[T]() -> Callable[[T], T]:
return lambda x: x

g = f()
assert_type(g(0), int)
assert_type(g(""), str)
assert_type(g(0), Literal[0])
assert_type(g(""), Literal[''])

@f()
def h(x: int) -> int:
Expand All @@ -951,51 +951,51 @@ assert_type(h(0), int)
testcase!(
test_generic_callable_union,
r#"
from typing import assert_type, Callable
from typing import Literal, assert_type, Callable
def f[T]() -> Callable[[T], T] | Callable[[T], list[T]]: ...
g = f()
assert_type(g(0), int | list[int])
assert_type(g(""), str | list[str])
assert_type(g(0), Literal[0] | list[Literal[0]])
assert_type(g(""), Literal[''] | list[Literal['']])
"#,
);

testcase!(
test_callable_returns_callable_returns_callable,
r#"
from typing import assert_type, Callable
from typing import Literal, assert_type, Callable

def f[T]() -> Callable[[], Callable[[T], T]]:
def f():
return lambda x: x
return f

g = f()()
assert_type(g(0), int)
assert_type(g(""), str)
assert_type(g(0), Literal[0])
assert_type(g(""), Literal[''])

"#,
);

testcase!(
test_return_substituted_callable,
r#"
from typing import assert_type, Callable
from typing import Literal, assert_type, Callable
def f[T](x: T) -> Callable[[T], T]: ...
g = f(0)
assert_type(g(0), int)
assert_type(g(""), int) # E: `Literal['']` is not assignable to parameter with type `int`
assert_type(g(0), Literal[0])
assert_type(g(""), Literal[0]) # E: `Literal['']` is not assignable to parameter with type `Literal[0]`
"#,
);

testcase!(
test_generic_callable_or_none,
r#"
from typing import assert_type, Callable
from typing import Literal, assert_type, Callable
def f[T]() -> Callable[[T], T] | None: ...
g = f()
if g:
assert_type(g(0), int)
assert_type(g(""), str)
assert_type(g(0), Literal[0])
assert_type(g(""), Literal[''])
"#,
);

Expand Down
12 changes: 6 additions & 6 deletions pyrefly/lib/test/calls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,30 @@ use crate::testcase;
testcase!(
test_generic_call_happy_case,
r#"
from typing import Never
from typing import Never, Literal
def force_error(x: Never) -> None: ...
def f[S, T](x: S, y: T) -> tuple[S, T]: ...
force_error(f(1, "foo")) # E: Argument `tuple[int, str]` is not assignable to parameter `x`
force_error(f(1, "foo")) # E: Argument `tuple[Literal[1], Literal['foo']]` is not assignable to parameter `x`
"#,
);

testcase!(
test_generic_call_fails_to_solve_output_var_simple,
r#"
from typing import Never
from typing import Never, Literal
def force_error(x: Never) -> None: ...
def f[S, T](x: S) -> tuple[S, T]: ...
force_error(f(1)) # E: Argument `tuple[int, @_]` is not assignable to parameter `x`
force_error(f(1)) # E: Argument `tuple[Literal[1], @_]` is not assignable to parameter `x`
"#,
);

testcase!(
test_generic_call_fails_to_solve_output_var_union_case,
r#"
from typing import Never
from typing import Never, Literal
def force_error(x: Never) -> None: ...
def f[S, T](x: S, y: list[T] | None) -> tuple[S, T]: ...
force_error(f(1, None)) # E: Argument `tuple[int, @_]` is not assignable to parameter `x`
force_error(f(1, None)) # E: Argument `tuple[Literal[1], @_]` is not assignable to parameter `x`
"#,
);

Expand Down
8 changes: 4 additions & 4 deletions pyrefly/lib/test/constructors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ class Meta(type):
class C[T](metaclass=Meta):
def __init__(self, x: T):
pass
assert_type(C(0), C[int]) # Correct, because metaclass call does not instantiate T=str
assert_type(C(0), C[int]) # Correct, preserves the argument type passed to C
"#,
);

Expand Down Expand Up @@ -477,10 +477,10 @@ assert_type(C2(), C2[int])
testcase!(
test_specialize_in_new,
r#"
from typing import assert_type
from typing import Literal, assert_type
class C[T]:
def __new__[T2](cls, x: T2) -> C[T2]: ...
assert_type(C(0), C[int])
assert_type(C(0), C[Literal[0]])
"#,
);

Expand Down Expand Up @@ -514,7 +514,7 @@ assert_type(C(0, "foo"), C[str])
testcase!(
test_new_and_init_generic,
r#"
from typing import Self,assert_type
from typing import Self, assert_type

class Class2[T]:
def __new__(cls, *args, **kwargs) -> Self: ...
Expand Down
6 changes: 3 additions & 3 deletions pyrefly/lib/test/contextual.rs
Original file line number Diff line number Diff line change
Expand Up @@ -513,15 +513,15 @@ xs[0] = [B()]
testcase!(
test_generic_get_literal,
r#"
from typing import assert_type, TypeVar, Literal
from typing import TypeVar, assert_type

class Foo[T]:
def __init__(self, x: T) -> None: ...
def get(self) -> T: ...

# Should propagate the context to the argument 42
x: Foo[Literal[42]] = Foo(42)
assert_type(x.get(), Literal[42])
x: Foo[int] = Foo(42)
assert_type(x.get(), int)
"#,
);

Expand Down
Loading