Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions mockito/mocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,11 @@ def _ensure_target_is_callable(theMock: Mock, method_name: str) -> None:
if not was_in_spec and target is None:
return

if _should_continue_with_stubbed_invocation(target, allow_classes=True):
if _should_continue_with_stubbed_invocation(
target,
allow_classes=True,
spec=theMock.spec,
):
return

raise invocation.InvocationError("'%s' is not callable." % method_name)
Expand All @@ -253,7 +257,7 @@ def _ensure_target_is_not_callable(theMock: Mock, method_name: str) -> None:
else:
return

if _should_continue_with_stubbed_invocation(value):
if _should_continue_with_stubbed_invocation(value, spec=spec):
raise invocation.InvocationError(
f"expected an invocation of '{method_name}'"
)
Expand All @@ -262,6 +266,7 @@ def _ensure_target_is_not_callable(theMock: Mock, method_name: str) -> None:
def _should_continue_with_stubbed_invocation(
value: object,
allow_classes: bool = False,
spec: object | None = None,
) -> bool:
if (
inspect.isfunction(value)
Expand All @@ -276,12 +281,21 @@ def _should_continue_with_stubbed_invocation(
):
return True

# Generic callable fallback, but keep custom descriptors/property-like
# attributes on the property stubbing path.
# For class specs, callable descriptors (objects implementing both
# `__call__` and `__get__`) are generally meant to be stubbed through
# the property path. For non-class specs (e.g. module attributes such as
# `numpy.vstack`), `__get__` should not disqualify callable targets.
treat_callable_descriptors_as_non_callable = inspect.isclass(spec)

# Generic callable fallback, with optional handling for callable
# descriptor-like objects (`__call__` + `__get__`).
return (
callable(value)
and (allow_classes or not inspect.isclass(value))
and not hasattr(value, '__get__')
and (
not treat_callable_descriptors_as_non_callable
or not hasattr(value, '__get__')
)
)


Expand Down
5 changes: 5 additions & 0 deletions tests/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,8 @@ def testEnsureNumpyArrayAllowedWhenCalling(self):
when(module).one_arg(Ellipsis).thenReturn('yep')
assert module.one_arg(array) == 'yep'


def test_np_vstack_is_callable():
when(np).vstack(...).thenReturn("ok.")

assert np.vstack([np.array([1]), np.array([2])]) == "ok."
Loading