Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 172 additions & 0 deletions pyrefly/lib/alt/class/class_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use ruff_python_ast::name::Name;
use ruff_text_size::TextRange;
use starlark_map::small_map::SmallMap;
use starlark_map::small_set::SmallSet;
use vec1::Vec1;
use vec1::vec1;

use crate::alt::answers::LookupAnswer;
Expand Down Expand Up @@ -1264,6 +1265,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
} else {
value_ty
};
let descriptor_value_ty = value_ty.clone();

// Types provided in annotations shadow inferred types
let ty = if let Some(ann) = annotation {
Expand Down Expand Up @@ -1329,6 +1331,17 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
ty
};

if descriptor.is_some() {
self.validate_descriptor_annotation(
class,
name,
annotation,
&descriptor_value_ty,
range,
errors,
);
}

// Pin any vars in the type: leaking a var in a class field is particularly
// likely to lead to data races where downstream uses can pin inconsistently.
//
Expand Down Expand Up @@ -1470,6 +1483,165 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
None
}

fn collect_descriptor_classes_from_type(&self, out: &mut SmallSet<ClassType>, ty: &Type) {
match ty {
Type::ClassType(cls) => {
out.insert(cls.clone());
}
Type::Union(types) => {
for ty in types {
self.collect_descriptor_classes_from_type(out, ty);
}
}
_ => {}
}
}

fn descriptor_info_from_class(&self, cls: ClassType, range: TextRange) -> Option<Descriptor> {
let getter = self
.get_class_member(cls.class_object(), &dunder::GET)
.is_some();
let setter = self
.get_class_member(cls.class_object(), &dunder::SET)
.is_some();
if getter || setter {
Some(Descriptor {
range,
cls,
getter,
setter,
})
} else {
None
}
}

fn validate_descriptor_annotation(
&self,
class: &Class,
name: &Name,
annotation: Option<&Annotation>,
value_ty: &Type,
range: TextRange,
errors: &ErrorCollector,
) {
let Some(annotation) = annotation else {
return;
};
let expected_ty = annotation.get_type().clone();
if expected_ty.is_any() || expected_ty.is_error() {
return;
}

let mut descriptor_classes = SmallSet::new();
self.collect_descriptor_classes_from_type(&mut descriptor_classes, value_ty);

if descriptor_classes.is_empty() {
return;
}

let class_type = self.as_class_type_unchecked(class);
let mut messages = Vec::new();
let ignore_errors = self.error_swallower();

for cls in descriptor_classes {
if let Type::ClassType(expected_cls) = &expected_ty
&& expected_cls.class_object() == cls.class_object()
{
continue;
}
let Some(desc_info) = self.descriptor_info_from_class(cls.clone(), range) else {
continue;
};

if desc_info.getter {
if let Some(getter_method) =
self.resolve_descriptor_getter(&desc_info, &ignore_errors)
{
let instance_ret = self.call_descriptor_getter(
getter_method.clone(),
DescriptorBase::Instance(class_type.clone()),
range,
&ignore_errors,
None,
);
if !self.is_subset_eq(&instance_ret, &expected_ty) {
let actual_display = self
.for_display(instance_ret.deterministic_printing())
.deterministic_printing();
let expected_display = self
.for_display(expected_ty.clone().deterministic_printing())
.deterministic_printing();
messages.push(format!(
"Descriptor `{}` returns `{}` when accessed on instances of `{}`, which is not assignable to `{}`",
name,
actual_display,
class.name(),
expected_display,
));
}

let class_ret = self.call_descriptor_getter(
getter_method,
DescriptorBase::ClassDef(class.dupe()),
range,
&ignore_errors,
None,
);
if !self.is_subset_eq(&class_ret, &expected_ty) {
let actual_display = self
.for_display(class_ret.deterministic_printing())
.deterministic_printing();
let expected_display = self
.for_display(expected_ty.clone().deterministic_printing())
.deterministic_printing();
messages.push(format!(
"Descriptor `{}` returns `{}` when accessed on class `{}`, which is not assignable to `{}`",
name,
actual_display,
class.name(),
expected_display,
));
}
}
}

if desc_info.setter {
if let Some(setter_method) =
self.resolve_descriptor_setter(&desc_info, &ignore_errors)
{
let setter_check_errors = self.error_collector();
self.call_descriptor_setter(
setter_method.clone(),
class_type.clone(),
CallArg::ty(&expected_ty, range),
range,
&setter_check_errors,
None,
);
if !setter_check_errors.is_empty() {
let setter_display = self
.for_display(setter_method.deterministic_printing())
.deterministic_printing();
let expected_display = self
.for_display(expected_ty.clone().deterministic_printing())
.deterministic_printing();
messages.push(format!(
"Descriptor `{}` setter `{}` does not accept annotated type `{}`",
name, setter_display, expected_display,
));
}
}
}
}

if !messages.is_empty() {
if let Ok(msgs) = Vec1::try_from_vec(messages) {
errors.add(range, ErrorInfo::Kind(ErrorKind::BadAssignment), msgs);
}
}
}

/// Return (type of first inherited field, first inherited annotation). May not be from the same class!
/// For example, in:
/// class A:
Expand Down
26 changes: 26 additions & 0 deletions pyrefly/lib/test/descriptors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,3 +319,29 @@ class A:
return self.d
"#,
);

testcase!(
test_descriptor_incompatible_get_return_annotation,
r#"
from typing import Literal
class A:
def __get__(self, obj, objtype) -> A | int: ...
class B(A):
def __get__(self, obj, objtype) -> Literal[1]: ...
class C:
a: A = B() # E: Descriptor `a` returns `Literal[1]` when accessed on instances of `C`, which is not assignable to `A`
"#,
);

testcase!(
test_descriptor_incompatible_set_annotation,
r#"
from typing import Any
class A:
def __set__(self, obj, value: Any) -> None: ...
class B(A):
def __set__(self, obj, value: str) -> None: ...
class C:
a: A = B() # E: Descriptor `a` setter `BoundMethod[B, (self: B, obj: Unknown, value: str) -> None]` does not accept annotated type `A`
"#,
);