Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion mockito/invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from collections import deque
from typing import TYPE_CHECKING, Union

from . import matchers, signature
from . import matchers, sameish, signature
from . import verification as verificationModule
from .mock_registry import mock_registry
from .utils import contains_strict
Expand Down Expand Up @@ -628,6 +628,21 @@ def transition_to_chain(self) -> ChainContinuation:
continuation = self.get_continuation()

if isinstance(continuation, ChainContinuation):
if (
continuation.invocation is not self
and sameish.invocations_have_distinct_captors(
self,
continuation.invocation,
)
):
self.forget_self()
raise InvocationError(
"'%s' is already configured with a different captor "
"instance for the same selector. Reuse the same "
"captor() / call_captor() object across chain branches."
% self.method_name
)

self.rollback_if_not_configured_by(continuation)
return continuation

Expand Down
14 changes: 11 additions & 3 deletions mockito/matchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@
from abc import ABC, abstractmethod
import functools
import re
from typing import TYPE_CHECKING

if TYPE_CHECKING:
try:
from typing import TypeGuard
except ImportError:
from typing_extensions import TypeGuard

builtin_any = any

__all__ = [
Expand Down Expand Up @@ -473,15 +481,15 @@ def __repr__(self):
return "<CaptorKwargsSentinel: %r>" % self.captor


def is_call_captor(value):
def is_call_captor(value: object) -> 'TypeGuard[CallCaptor]':
return isinstance(value, CallCaptor)


def is_captor_args_sentinel(value):
def is_captor_args_sentinel(value: object) -> 'TypeGuard[CaptorArgsSentinel]':
return isinstance(value, CaptorArgsSentinel)


def is_captor_kwargs_sentinel(value):
def is_captor_kwargs_sentinel(value: object) -> 'TypeGuard[CaptorKwargsSentinel]':
return isinstance(value, CaptorKwargsSentinel)


Expand Down
18 changes: 6 additions & 12 deletions mockito/mocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from dataclasses import dataclass
from typing import Any, AsyncIterator, Callable, Iterable, Iterator, cast

from . import invocation, signature, utils
from . import invocation, sameish, signature, utils
from . import verification as verificationModule
from .mock_registry import mock_registry
from .patching import Patch, patcher
Expand Down Expand Up @@ -407,12 +407,12 @@ def set_continuation(self, continuation: invocation.ConfiguredContinuation) -> N
def _sameish_invocations(
self, same: invocation.StubbedInvocation
) -> list[invocation.StubbedInvocation]:
"""Find prior stubs that are *mutually* signature-compatible.
"""Find prior stubs that are signature-compatible.

This is used only for continuation bookkeeping (value-vs-chain mode),
not for runtime call dispatch. We intentionally do a symmetric check
(`a.matches(b)` and `b.matches(a)`) to approximate "same signature"
despite one-way matchers like `any()`.
not for runtime call dispatch. The comparison is structural and avoids
executing matcher predicates, so `arg_that(...)` and other custom
matchers cannot crash internal equivalence probing.

Why this exists: repeated selectors such as

Expand All @@ -439,13 +439,7 @@ def _invocations_are_sameish(
left: invocation.StubbedInvocation,
right: invocation.StubbedInvocation,
) -> bool:
# Be conservative in internal equivalence probing: user predicates from
# `arg_that` can throw when evaluated against matcher/sentinel objects.
# In this phase, exceptions should mean "not equivalent", not failure.
try:
return left.matches(right) and right.matches(left)
except Exception:
return False
return sameish.invocations_are_sameish(left, right)

def get_original_method(self, method_name: str) -> object | None:
return self._original_methods.get(method_name, None)
Expand Down
195 changes: 195 additions & 0 deletions mockito/sameish.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from . import matchers

if TYPE_CHECKING:
from .invocation import StubbedInvocation


def invocations_are_sameish(
left: StubbedInvocation,
right: StubbedInvocation,
) -> bool:
"""Structural signature-compatibility checks for continuation bookkeeping.

Intentionally avoids executing user-provided matcher predicates
(e.g. `arg_that(...)) while comparing stub signatures.
"""

return (
_params_are_sameish(left.params, right.params)
and _named_params_are_sameish(
left.named_params,
right.named_params,
)
)


def invocations_have_distinct_captors(
left: StubbedInvocation,
right: StubbedInvocation,
) -> bool:
"""Return True when equivalent selectors bind different captor instances."""

for left_value, right_value in zip(left.params, right.params):
if _values_bind_distinct_captors(left_value, right_value):
return True

for key in set(left.named_params) & set(right.named_params):
if _values_bind_distinct_captors(
left.named_params[key],
right.named_params[key],
):
return True

return False


def _params_are_sameish(left: tuple, right: tuple) -> bool:
if len(left) != len(right):
return False

return all(
_values_are_sameish(left_value, right_value)
for left_value, right_value in zip(left, right)
)


def _named_params_are_sameish(left: dict, right: dict) -> bool:
if set(left) != set(right):
return False

return all(
_values_are_sameish(left[key], right[key])
for key in left
)


def _values_are_sameish(left: object, right: object) -> bool:
if left is right:
return True

if left is Ellipsis or right is Ellipsis:
return left is right

if matchers.is_call_captor(left) and matchers.is_call_captor(right):
return True

if matchers.is_call_captor(left) or matchers.is_call_captor(right):
return False

if (
matchers.is_captor_args_sentinel(left)
and matchers.is_captor_args_sentinel(right)
):
return _values_are_sameish(left.captor.matcher, right.captor.matcher)

if (
matchers.is_captor_kwargs_sentinel(left)
and matchers.is_captor_kwargs_sentinel(right)
):
return _values_are_sameish(left.captor.matcher, right.captor.matcher)

if (
matchers.is_captor_args_sentinel(left)
or matchers.is_captor_args_sentinel(right)
or matchers.is_captor_kwargs_sentinel(left)
or matchers.is_captor_kwargs_sentinel(right)
):
return False

if isinstance(left, matchers.Matcher) and isinstance(right, matchers.Matcher):
return _matchers_are_sameish(left, right)

if isinstance(left, matchers.Matcher) or isinstance(right, matchers.Matcher):
return False

return _equals_or_identity(left, right)


def _matchers_are_sameish( # noqa: C901
left: matchers.Matcher,
right: matchers.Matcher,
) -> bool:
if left is right:
return True

if type(left) is not type(right):
return False

if isinstance(left, matchers.Any) and isinstance(right, matchers.Any):
return _equals_or_identity(left.wanted_type, right.wanted_type)

if (
isinstance(left, matchers.ValueMatcher)
and isinstance(right, matchers.ValueMatcher)
):
return _values_are_sameish(left.value, right.value)

if (
isinstance(left, (matchers.And, matchers.Or))
and isinstance(right, (matchers.And, matchers.Or))
):
return _params_are_sameish(
tuple(left.matchers),
tuple(right.matchers),
)

if isinstance(left, matchers.Not) and isinstance(right, matchers.Not):
return _values_are_sameish(left.matcher, right.matcher)

if isinstance(left, matchers.ArgThat) and isinstance(right, matchers.ArgThat):
return left.predicate is right.predicate

if isinstance(left, matchers.Contains) and isinstance(right, matchers.Contains):
return _values_are_sameish(left.sub, right.sub)

if isinstance(left, matchers.Matches) and isinstance(right, matchers.Matches):
return (
left.regex.pattern == right.regex.pattern
and left.flags == right.flags
)

if (
isinstance(left, matchers.ArgumentCaptor)
and isinstance(right, matchers.ArgumentCaptor)
):
return _values_are_sameish(left.matcher, right.matcher)

return _equals_or_identity(left, right)


def _values_bind_distinct_captors(left: object, right: object) -> bool:
left_binding = _captor_binding(left)
right_binding = _captor_binding(right)

return (
left_binding is not None
and right_binding is not None
and left_binding is not right_binding
)


def _captor_binding(value: object) -> object | None:
if matchers.is_call_captor(value):
return value

if isinstance(value, matchers.ArgumentCaptor):
return value

if matchers.is_captor_args_sentinel(value):
return value.captor

if matchers.is_captor_kwargs_sentinel(value):
return value.captor

return None


def _equals_or_identity(left: object, right: object) -> bool:
try:
return left == right
except Exception:
return left is right
Loading
Loading