Skip to content

Commit 9e335a0

Browse files
committed
Track positional and keyword arguments
1 parent 6f911c4 commit 9e335a0

File tree

2 files changed

+195
-20
lines changed

2 files changed

+195
-20
lines changed

pyrefly/lib/state/lsp.rs

Lines changed: 102 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,13 @@ use crate::types::callable::Params;
8686
use crate::types::module::ModuleType;
8787
use crate::types::types::Type;
8888

89+
#[derive(Debug, Clone)]
90+
enum ActiveArgument {
91+
Positional(usize),
92+
Keyword(Name),
93+
Next(usize),
94+
}
95+
8996
fn default_true() -> bool {
9097
true
9198
}
@@ -844,25 +851,47 @@ impl<'a> Transaction<'a> {
844851
fn visit_finding_signature_range(
845852
x: &Expr,
846853
find: TextSize,
847-
res: &mut Option<(TextRange, TextRange, usize)>,
854+
res: &mut Option<(TextRange, TextRange, ActiveArgument)>,
848855
) {
849856
if let Expr::Call(call) = x
850857
&& call.arguments.range.contains_inclusive(find)
851858
{
859+
// Check positional arguments
852860
for (i, arg) in call.arguments.args.as_ref().iter().enumerate() {
853861
if arg.range().contains_inclusive(find) {
854862
Self::visit_finding_signature_range(arg, find, res);
855863
if res.is_some() {
856864
return;
857865
}
858-
*res = Some((call.func.range(), call.arguments.range, i));
866+
*res = Some((
867+
call.func.range(),
868+
call.arguments.range,
869+
ActiveArgument::Positional(i),
870+
));
871+
return;
872+
}
873+
}
874+
// Check keyword arguments
875+
let positional_count = call.arguments.args.len();
876+
for (j, kw) in call.arguments.keywords.iter().enumerate() {
877+
if kw.range.contains_inclusive(find) {
878+
Self::visit_finding_signature_range(&kw.value, find, res);
879+
if res.is_some() {
880+
return;
881+
}
882+
let active_argument = match kw.arg.as_ref() {
883+
Some(identifier) => ActiveArgument::Keyword(identifier.id.clone()),
884+
None => ActiveArgument::Positional(positional_count + j),
885+
};
886+
*res = Some((call.func.range(), call.arguments.range, active_argument));
887+
return;
859888
}
860889
}
861890
if res.is_none() {
862891
*res = Some((
863892
call.func.range(),
864893
call.arguments.range,
865-
call.arguments.len(),
894+
ActiveArgument::Next(call.arguments.len()),
866895
));
867896
}
868897
} else {
@@ -875,11 +904,17 @@ impl<'a> Transaction<'a> {
875904
&self,
876905
handle: &Handle,
877906
position: TextSize,
878-
) -> Option<(Vec<Type>, usize, usize)> {
907+
) -> Option<(Vec<Type>, usize, ActiveArgument)> {
879908
let mod_module = self.get_ast(handle)?;
880909
let mut res = None;
881910
mod_module.visit(&mut |x| Self::visit_finding_signature_range(x, position, &mut res));
882-
let (callee_range, call_args_range, arg_index) = res?;
911+
let (callee_range, call_args_range, mut active_argument) = res?;
912+
if let ActiveArgument::Next(index) = &mut active_argument
913+
&& let Some(next_index) =
914+
self.count_argument_separators_before(handle, call_args_range, position)
915+
{
916+
*index = next_index;
917+
}
883918
let answers = self.get_answers(handle)?;
884919
if let Some((overloads, chosen_overload_index)) =
885920
answers.get_all_overload_trace(call_args_range)
@@ -888,12 +923,12 @@ impl<'a> Transaction<'a> {
888923
Some((
889924
callables,
890925
chosen_overload_index.unwrap_or_default(),
891-
arg_index,
926+
active_argument,
892927
))
893928
} else {
894929
answers
895930
.get_type_trace(callee_range)
896-
.map(|t| (vec![t], 0, arg_index))
931+
.map(|t| (vec![t], 0, active_argument))
897932
}
898933
}
899934

@@ -903,27 +938,33 @@ impl<'a> Transaction<'a> {
903938
position: TextSize,
904939
) -> Option<SignatureHelp> {
905940
self.get_callables_from_call(handle, position).map(
906-
|(callables, chosen_overload_index, arg_index)| SignatureHelp {
907-
signatures: callables
941+
|(callables, chosen_overload_index, active_argument)| {
942+
let signatures = callables
908943
.into_iter()
909-
.map(|t| Self::create_signature_information(t, arg_index))
910-
.collect_vec(),
911-
active_signature: Some(chosen_overload_index as u32),
912-
active_parameter: Some(arg_index as u32),
944+
.map(|t| Self::create_signature_information(t, &active_argument))
945+
.collect_vec();
946+
let active_parameter = signatures
947+
.get(chosen_overload_index)
948+
.and_then(|info| info.active_parameter);
949+
SignatureHelp {
950+
signatures,
951+
active_signature: Some(chosen_overload_index as u32),
952+
active_parameter,
953+
}
913954
},
914955
)
915956
}
916957

917-
fn create_signature_information(type_: Type, arg_index: usize) -> SignatureInformation {
958+
fn create_signature_information(
959+
type_: Type,
960+
active_argument: &ActiveArgument,
961+
) -> SignatureInformation {
918962
let type_ = type_.deterministic_printing();
919963
let label = type_.as_hover_string();
920964
let (parameters, active_parameter) =
921965
if let Some(params) = Self::normalize_singleton_function_type_into_params(type_) {
922-
let active_parameter = if arg_index < params.len() {
923-
Some(arg_index as u32)
924-
} else {
925-
None
926-
};
966+
let active_parameter =
967+
Self::active_parameter_index(&params, active_argument).map(|idx| idx as u32);
927968
(
928969
Some(params.map(|param| ParameterInformation {
929970
label: ParameterLabel::Simple(format!("{param}")),
@@ -942,6 +983,46 @@ impl<'a> Transaction<'a> {
942983
}
943984
}
944985

986+
fn active_parameter_index(params: &[Param], active_argument: &ActiveArgument) -> Option<usize> {
987+
match active_argument {
988+
ActiveArgument::Positional(index) | ActiveArgument::Next(index) => {
989+
(*index < params.len()).then_some(*index)
990+
}
991+
ActiveArgument::Keyword(name) => params.iter().position(|param| {
992+
Self::parameter_name(param).is_some_and(|param_name| param_name == name)
993+
}),
994+
}
995+
}
996+
997+
fn parameter_name(param: &Param) -> Option<&Name> {
998+
match param {
999+
Param::PosOnly(Some(name), ..)
1000+
| Param::Pos(name, ..)
1001+
| Param::VarArg(Some(name), ..)
1002+
| Param::KwOnly(name, ..)
1003+
| Param::Kwargs(Some(name), ..) => Some(name),
1004+
_ => None,
1005+
}
1006+
}
1007+
1008+
fn count_argument_separators_before(
1009+
&self,
1010+
handle: &Handle,
1011+
arguments_range: TextRange,
1012+
position: TextSize,
1013+
) -> Option<usize> {
1014+
let module = self.get_module_info(handle)?;
1015+
let contents = module.contents();
1016+
let start = arguments_range.start().to_usize();
1017+
let end = arguments_range.end().to_usize().min(contents.len());
1018+
if start >= end {
1019+
return Some(0);
1020+
}
1021+
let pos = position.to_usize().clamp(start, end);
1022+
let slice = &contents[start..pos];
1023+
Some(slice.bytes().filter(|&b| b == b',').count())
1024+
}
1025+
9451026
fn normalize_singleton_function_type_into_params(type_: Type) -> Option<Vec<Param>> {
9461027
let callable = type_.to_callable()?;
9471028
// We will drop the self parameter for signature help
@@ -2169,11 +2250,12 @@ impl<'a> Transaction<'a> {
21692250
position: TextSize,
21702251
completions: &mut Vec<CompletionItem>,
21712252
) {
2172-
if let Some((callables, chosen_overload_index, arg_index)) =
2253+
if let Some((callables, chosen_overload_index, active_argument)) =
21732254
self.get_callables_from_call(handle, position)
21742255
&& let Some(callable) = callables.get(chosen_overload_index)
21752256
&& let Some(params) =
21762257
Self::normalize_singleton_function_type_into_params(callable.clone())
2258+
&& let Some(arg_index) = Self::active_parameter_index(&params, &active_argument)
21772259
&& let Some(param) = params.get(arg_index)
21782260
{
21792261
Self::add_literal_completions_from_type(param.as_type(), completions);

pyrefly/lib/test/lsp/signature_help.rs

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,99 @@ Signature Help Result: active=0
128128
);
129129
}
130130

131+
#[test]
132+
fn positional_arguments_test() {
133+
let code = r#"
134+
def f(x: int, y: int, z: int) -> None: ...
135+
136+
f(1,,)
137+
# ^
138+
f(1,,)
139+
# ^
140+
f(1,,3)
141+
# ^
142+
"#;
143+
let report = get_batched_lsp_operations_report_allow_error(&[("main", code)], get_test_report);
144+
assert_eq!(
145+
r#"
146+
# main.py
147+
4 | f(1,,)
148+
^
149+
Signature Help Result: active=0
150+
- def f(
151+
x: int,
152+
y: int,
153+
z: int
154+
) -> None, parameters=[x: int, y: int, z: int], active parameter = 1
155+
156+
6 | f(1,,)
157+
^
158+
Signature Help Result: active=0
159+
- def f(
160+
x: int,
161+
y: int,
162+
z: int
163+
) -> None, parameters=[x: int, y: int, z: int], active parameter = 2
164+
165+
8 | f(1,,3)
166+
^
167+
Signature Help Result: active=0
168+
- def f(
169+
x: int,
170+
y: int,
171+
z: int
172+
) -> None, parameters=[x: int, y: int, z: int], active parameter = 1
173+
"#
174+
.trim(),
175+
report.trim(),
176+
);
177+
}
178+
179+
#[test]
180+
fn keyword_arguments_test() {
181+
let code = r#"
182+
def f(a: str, b: int) -> None: ...
183+
184+
f(a)
185+
# ^
186+
f(a=)
187+
# ^
188+
f(b=)
189+
# ^
190+
"#;
191+
let report = get_batched_lsp_operations_report_allow_error(&[("main", code)], get_test_report);
192+
assert_eq!(
193+
r#"
194+
# main.py
195+
4 | f(a)
196+
^
197+
Signature Help Result: active=0
198+
- def f(
199+
a: str,
200+
b: int
201+
) -> None, parameters=[a: str, b: int], active parameter = 0
202+
203+
6 | f(a=)
204+
^
205+
Signature Help Result: active=0
206+
- def f(
207+
a: str,
208+
b: int
209+
) -> None, parameters=[a: str, b: int], active parameter = 0
210+
211+
8 | f(b=)
212+
^
213+
Signature Help Result: active=0
214+
- def f(
215+
a: str,
216+
b: int
217+
) -> None, parameters=[a: str, b: int], active parameter = 1
218+
"#
219+
.trim(),
220+
report.trim(),
221+
);
222+
}
223+
131224
#[test]
132225
fn simple_incomplete_function_call_test() {
133226
let code = r#"

0 commit comments

Comments
 (0)