Skip to content

Commit 5d74ad5

Browse files
authored
Merge pull request #21419 from github/tausbn/python-improve-overloaded-method-resolution
Python: Improve modelling of overloaded methods
2 parents be9c1d0 + f2bad1e commit 5d74ad5

File tree

5 files changed

+92
-1
lines changed

5 files changed

+92
-1
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
category: minorAnalysis
3+
---
4+
5+
- The call graph resolution no longer considers methods marked using [`@typing.overload`](https://typing.python.org/en/latest/spec/overload.html#overloads) as valid targets. This ensures that only the method that contains the actual implementation gets resolved as a target.

python/ql/lib/semmle/python/dataflow/new/internal/DataFlowDispatch.qll

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,25 @@ predicate hasContextmanagerDecorator(Function func) {
304304
)
305305
}
306306

307+
/**
308+
* Holds if the function `func` has a `typing.overload` decorator.
309+
* Such functions are type stubs that declare an overload signature but are
310+
* not the actual implementation.
311+
*
312+
* Normally we would want to model this using API graphs for more precision, but since this
313+
* predicate is used in the call graph computation, we have to use a more syntactic approach.
314+
*/
315+
overlay[local]
316+
private predicate hasOverloadDecorator(Function func) {
317+
exists(ControlFlowNode overload |
318+
overload.(NameNode).getId() = "overload" and overload.(NameNode).isGlobal()
319+
or
320+
overload.(AttrNode).getObject("overload").(NameNode).isGlobal()
321+
|
322+
func.getADecorator() = overload.getNode()
323+
)
324+
}
325+
307326
// =============================================================================
308327
// Callables
309328
// =============================================================================
@@ -849,7 +868,8 @@ private Class getNextClassInMro(Class cls) {
849868
*/
850869
Function findFunctionAccordingToMro(Class cls, string name) {
851870
result = cls.getAMethod() and
852-
result.getName() = name
871+
result.getName() = name and
872+
not hasOverloadDecorator(result)
853873
or
854874
not class_has_method(cls, name) and
855875
result = findFunctionAccordingToMro(getNextClassInMro(cls), name)
@@ -891,6 +911,7 @@ Class getNextClassInMroKnownStartingClass(Class cls, Class startingClass) {
891911
Function findFunctionAccordingToMroKnownStartingClass(Class cls, Class startingClass, string name) {
892912
result = cls.getAMethod() and
893913
result.getName() = name and
914+
not hasOverloadDecorator(result) and
894915
cls = getADirectSuperclass*(startingClass)
895916
or
896917
not class_has_method(cls, name) and

python/ql/test/library-tests/dataflow/calls-overload/OverloadCallTest.expected

Whitespace-only changes.
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/**
2+
* Test that `@typing.overload` stubs are not resolved as call targets.
3+
*/
4+
5+
import python
6+
import semmle.python.dataflow.new.internal.DataFlowDispatch as DataFlowDispatch
7+
import utils.test.InlineExpectationsTest
8+
9+
module OverloadCallTest implements TestSig {
10+
string getARelevantTag() { result = "init" }
11+
12+
predicate hasActualResult(Location location, string element, string tag, string value) {
13+
exists(location.getFile().getRelativePath()) and
14+
exists(DataFlowDispatch::DataFlowCall call, Function target |
15+
location = call.getLocation() and
16+
element = call.toString() and
17+
DataFlowDispatch::resolveCall(call.getNode(), target, _) and
18+
target.getName() = "__init__"
19+
|
20+
value = target.getQualifiedName() + ":" + target.getLocation().getStartLine().toString() and
21+
tag = "init"
22+
)
23+
}
24+
}
25+
26+
import MakeTest<OverloadCallTest>
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import typing
2+
3+
4+
class OverloadedInit:
5+
@typing.overload
6+
def __init__(self, x: int) -> None: ...
7+
8+
@typing.overload
9+
def __init__(self, x: str, y: str) -> None: ...
10+
11+
def __init__(self, x, y=None):
12+
pass
13+
14+
OverloadedInit(1) # $ init=OverloadedInit.__init__:11
15+
OverloadedInit("a", "b") # $ init=OverloadedInit.__init__:11
16+
17+
18+
from typing import overload
19+
20+
21+
class OverloadedInitFromImport:
22+
@overload
23+
def __init__(self, x: int) -> None: ...
24+
25+
@overload
26+
def __init__(self, x: str, y: str) -> None: ...
27+
28+
def __init__(self, x, y=None):
29+
pass
30+
31+
OverloadedInitFromImport(1) # $ init=OverloadedInitFromImport.__init__:28
32+
OverloadedInitFromImport("a", "b") # $ init=OverloadedInitFromImport.__init__:28
33+
34+
35+
class NoOverloads:
36+
def __init__(self, x):
37+
pass
38+
39+
NoOverloads(1) # $ init=NoOverloads.__init__:36

0 commit comments

Comments
 (0)