diff --git a/mockito/invocation.py b/mockito/invocation.py index b76026e..2b7cd8c 100644 --- a/mockito/invocation.py +++ b/mockito/invocation.py @@ -27,7 +27,7 @@ from collections import deque from typing import TYPE_CHECKING, Union -from . import matchers, signature +from . import matchers, sameish, signature from . import verification as verificationModule from .mock_registry import mock_registry from .utils import contains_strict @@ -628,6 +628,21 @@ def transition_to_chain(self) -> ChainContinuation: continuation = self.get_continuation() if isinstance(continuation, ChainContinuation): + if ( + continuation.invocation is not self + and sameish.invocations_have_distinct_captors( + self, + continuation.invocation, + ) + ): + self.forget_self() + raise InvocationError( + "'%s' is already configured with a different captor " + "instance for the same selector. Reuse the same " + "captor() / call_captor() object across chain branches." + % self.method_name + ) + self.rollback_if_not_configured_by(continuation) return continuation diff --git a/mockito/matchers.py b/mockito/matchers.py index 6cfe184..2cd8f04 100644 --- a/mockito/matchers.py +++ b/mockito/matchers.py @@ -62,6 +62,14 @@ from abc import ABC, abstractmethod import functools import re +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + try: + from typing import TypeGuard + except ImportError: + from typing_extensions import TypeGuard + builtin_any = any __all__ = [ @@ -473,15 +481,15 @@ def __repr__(self): return "" % self.captor -def is_call_captor(value): +def is_call_captor(value: object) -> 'TypeGuard[CallCaptor]': return isinstance(value, CallCaptor) -def is_captor_args_sentinel(value): +def is_captor_args_sentinel(value: object) -> 'TypeGuard[CaptorArgsSentinel]': return isinstance(value, CaptorArgsSentinel) -def is_captor_kwargs_sentinel(value): +def is_captor_kwargs_sentinel(value: object) -> 'TypeGuard[CaptorKwargsSentinel]': return isinstance(value, CaptorKwargsSentinel) diff --git a/mockito/mocking.py b/mockito/mocking.py index 1c25231..8df45b9 100644 --- a/mockito/mocking.py +++ b/mockito/mocking.py @@ -28,7 +28,7 @@ from dataclasses import dataclass from typing import Any, AsyncIterator, Callable, Iterable, Iterator, cast -from . import invocation, signature, utils +from . import invocation, sameish, signature, utils from . import verification as verificationModule from .mock_registry import mock_registry from .patching import Patch, patcher @@ -407,12 +407,12 @@ def set_continuation(self, continuation: invocation.ConfiguredContinuation) -> N def _sameish_invocations( self, same: invocation.StubbedInvocation ) -> list[invocation.StubbedInvocation]: - """Find prior stubs that are *mutually* signature-compatible. + """Find prior stubs that are signature-compatible. This is used only for continuation bookkeeping (value-vs-chain mode), - not for runtime call dispatch. We intentionally do a symmetric check - (`a.matches(b)` and `b.matches(a)`) to approximate "same signature" - despite one-way matchers like `any()`. + not for runtime call dispatch. The comparison is structural and avoids + executing matcher predicates, so `arg_that(...)` and other custom + matchers cannot crash internal equivalence probing. Why this exists: repeated selectors such as @@ -439,13 +439,7 @@ def _invocations_are_sameish( left: invocation.StubbedInvocation, right: invocation.StubbedInvocation, ) -> bool: - # Be conservative in internal equivalence probing: user predicates from - # `arg_that` can throw when evaluated against matcher/sentinel objects. - # In this phase, exceptions should mean "not equivalent", not failure. - try: - return left.matches(right) and right.matches(left) - except Exception: - return False + return sameish.invocations_are_sameish(left, right) def get_original_method(self, method_name: str) -> object | None: return self._original_methods.get(method_name, None) diff --git a/mockito/sameish.py b/mockito/sameish.py new file mode 100644 index 0000000..1a9b7b3 --- /dev/null +++ b/mockito/sameish.py @@ -0,0 +1,195 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from . import matchers + +if TYPE_CHECKING: + from .invocation import StubbedInvocation + + +def invocations_are_sameish( + left: StubbedInvocation, + right: StubbedInvocation, +) -> bool: + """Structural signature-compatibility checks for continuation bookkeeping. + + Intentionally avoids executing user-provided matcher predicates + (e.g. `arg_that(...)) while comparing stub signatures. + """ + + return ( + _params_are_sameish(left.params, right.params) + and _named_params_are_sameish( + left.named_params, + right.named_params, + ) + ) + + +def invocations_have_distinct_captors( + left: StubbedInvocation, + right: StubbedInvocation, +) -> bool: + """Return True when equivalent selectors bind different captor instances.""" + + for left_value, right_value in zip(left.params, right.params): + if _values_bind_distinct_captors(left_value, right_value): + return True + + for key in set(left.named_params) & set(right.named_params): + if _values_bind_distinct_captors( + left.named_params[key], + right.named_params[key], + ): + return True + + return False + + +def _params_are_sameish(left: tuple, right: tuple) -> bool: + if len(left) != len(right): + return False + + return all( + _values_are_sameish(left_value, right_value) + for left_value, right_value in zip(left, right) + ) + + +def _named_params_are_sameish(left: dict, right: dict) -> bool: + if set(left) != set(right): + return False + + return all( + _values_are_sameish(left[key], right[key]) + for key in left + ) + + +def _values_are_sameish(left: object, right: object) -> bool: + if left is right: + return True + + if left is Ellipsis or right is Ellipsis: + return left is right + + if matchers.is_call_captor(left) and matchers.is_call_captor(right): + return True + + if matchers.is_call_captor(left) or matchers.is_call_captor(right): + return False + + if ( + matchers.is_captor_args_sentinel(left) + and matchers.is_captor_args_sentinel(right) + ): + return _values_are_sameish(left.captor.matcher, right.captor.matcher) + + if ( + matchers.is_captor_kwargs_sentinel(left) + and matchers.is_captor_kwargs_sentinel(right) + ): + return _values_are_sameish(left.captor.matcher, right.captor.matcher) + + if ( + matchers.is_captor_args_sentinel(left) + or matchers.is_captor_args_sentinel(right) + or matchers.is_captor_kwargs_sentinel(left) + or matchers.is_captor_kwargs_sentinel(right) + ): + return False + + if isinstance(left, matchers.Matcher) and isinstance(right, matchers.Matcher): + return _matchers_are_sameish(left, right) + + if isinstance(left, matchers.Matcher) or isinstance(right, matchers.Matcher): + return False + + return _equals_or_identity(left, right) + + +def _matchers_are_sameish( # noqa: C901 + left: matchers.Matcher, + right: matchers.Matcher, +) -> bool: + if left is right: + return True + + if type(left) is not type(right): + return False + + if isinstance(left, matchers.Any) and isinstance(right, matchers.Any): + return _equals_or_identity(left.wanted_type, right.wanted_type) + + if ( + isinstance(left, matchers.ValueMatcher) + and isinstance(right, matchers.ValueMatcher) + ): + return _values_are_sameish(left.value, right.value) + + if ( + isinstance(left, (matchers.And, matchers.Or)) + and isinstance(right, (matchers.And, matchers.Or)) + ): + return _params_are_sameish( + tuple(left.matchers), + tuple(right.matchers), + ) + + if isinstance(left, matchers.Not) and isinstance(right, matchers.Not): + return _values_are_sameish(left.matcher, right.matcher) + + if isinstance(left, matchers.ArgThat) and isinstance(right, matchers.ArgThat): + return left.predicate is right.predicate + + if isinstance(left, matchers.Contains) and isinstance(right, matchers.Contains): + return _values_are_sameish(left.sub, right.sub) + + if isinstance(left, matchers.Matches) and isinstance(right, matchers.Matches): + return ( + left.regex.pattern == right.regex.pattern + and left.flags == right.flags + ) + + if ( + isinstance(left, matchers.ArgumentCaptor) + and isinstance(right, matchers.ArgumentCaptor) + ): + return _values_are_sameish(left.matcher, right.matcher) + + return _equals_or_identity(left, right) + + +def _values_bind_distinct_captors(left: object, right: object) -> bool: + left_binding = _captor_binding(left) + right_binding = _captor_binding(right) + + return ( + left_binding is not None + and right_binding is not None + and left_binding is not right_binding + ) + + +def _captor_binding(value: object) -> object | None: + if matchers.is_call_captor(value): + return value + + if isinstance(value, matchers.ArgumentCaptor): + return value + + if matchers.is_captor_args_sentinel(value): + return value.captor + + if matchers.is_captor_kwargs_sentinel(value): + return value.captor + + return None + + +def _equals_or_identity(left: object, right: object) -> bool: + try: + return left == right + except Exception: + return left is right diff --git a/tests/chaining_test.py b/tests/chaining_test.py index b3f86a9..c90d245 100644 --- a/tests/chaining_test.py +++ b/tests/chaining_test.py @@ -1,7 +1,7 @@ import pytest from mockito import any as any_ -from mockito import arg_that, expect, mock, verify, unstub, when +from mockito import arg_that, call_captor, captor, expect, mock, verify, unstub, when from mockito.invocation import AnswerError, InvocationError @@ -37,6 +37,101 @@ def test_multiple_chain_branches_on_same_root_are_supported(): assert cat_that_meowed.roll() == "playful" +def test_multiple_chain_branches_with_equivalent_typed_any_matchers_share_root(): + cat = mock() + + when(cat).meow(any_(int)).purr().thenReturn("friendly") + when(cat).meow(any_(int)).roll().thenReturn("playful") + + cat_that_meowed = cat.meow(1) + assert cat_that_meowed.purr() == "friendly" + assert cat_that_meowed.roll() == "playful" + + +def test_multiple_chain_branches_with_same_arg_that_matcher_share_root(): + cat = mock() + pred = arg_that(lambda value: value > 0) + + when(cat).meow(pred).purr().thenReturn("friendly") + when(cat).meow(pred).roll().thenReturn("playful") + + cat_that_meowed = cat.meow(1) + assert cat_that_meowed.purr() == "friendly" + assert cat_that_meowed.roll() == "playful" + + +def test_multiple_chain_branches_with_same_call_captor_instance_share_root(): + cat = mock() + call = call_captor() + + when(cat).meow(call).purr().thenReturn("friendly") + when(cat).meow(call).roll().thenReturn("playful") + + cat_that_meowed = cat.meow(1) + assert cat_that_meowed.purr() == "friendly" + assert cat_that_meowed.roll() == "playful" + + +def test_multiple_chain_branches_with_distinct_call_captor_roots_are_rejected(): + cat = mock() + + when(cat).meow(call_captor()).purr().thenReturn("friendly") + + with pytest.raises(InvocationError) as exc: + when(cat).meow(call_captor()).roll().thenReturn("playful") + + assert str(exc.value) == ( + "'meow' is already configured with a different captor instance for " + "the same selector. Reuse the same captor() / call_captor() object " + "across chain branches." + ) + + +def test_multiple_chain_branches_with_distinct_args_captor_roots_are_rejected(): + cat = mock() + + when(cat).meow(*captor()).purr().thenReturn("friendly") + + with pytest.raises(InvocationError) as exc: + when(cat).meow(*captor()).roll().thenReturn("playful") + + assert str(exc.value) == ( + "'meow' is already configured with a different captor instance for " + "the same selector. Reuse the same captor() / call_captor() object " + "across chain branches." + ) + + +def test_multiple_chain_branches_with_distinct_kwargs_captor_roots_are_rejected(): + cat = mock() + + when(cat).meow(**captor()).purr().thenReturn("friendly") + + with pytest.raises(InvocationError) as exc: + when(cat).meow(**captor()).roll().thenReturn("playful") + + assert str(exc.value) == ( + "'meow' is already configured with a different captor instance for " + "the same selector. Reuse the same captor() / call_captor() object " + "across chain branches." + ) + + +def test_multiple_chain_branches_with_distinct_typed_args_captor_roots_are_rejected(): + cat = mock() + + when(cat).meow(*captor(any_(int))).purr().thenReturn("friendly") + + with pytest.raises(InvocationError) as exc: + when(cat).meow(*captor(any_(int))).roll().thenReturn("playful") + + assert str(exc.value) == ( + "'meow' is already configured with a different captor instance for " + "the same selector. Reuse the same captor() / call_captor() object " + "across chain branches." + ) + + def test_unstub_child_chain_then_reconfigure_does_not_leave_stale_root_stub(): cat = mock() diff --git a/tests/sameish_test.py b/tests/sameish_test.py new file mode 100644 index 0000000..4012dd1 --- /dev/null +++ b/tests/sameish_test.py @@ -0,0 +1,217 @@ +from dataclasses import dataclass, field + +from mockito import and_, any as any_, arg_that, call_captor, captor, eq, gt, neq, or_ +from mockito import sameish + + +@dataclass +class FakeInvocation: + params: tuple = () + named_params: dict = field(default_factory=dict) + + +def bar(*params, **named_params): + return FakeInvocation(params=params, named_params=named_params) + + +def test_concrete_values_must_match_exactly(): + assert sameish.invocations_are_sameish( + bar(1, "x"), + bar(1, "x"), + ) + assert not sameish.invocations_are_sameish( + bar(1, "x"), + bar(2, "x"), + ) + + +def test_keyword_names_must_match_independent_of_order(): + assert sameish.invocations_are_sameish( + bar(a=1, b=2), + bar(b=2, a=1), + ) + assert not sameish.invocations_are_sameish( + bar(a=1), + bar(a=1, b=2), + ) + + +def test_any_matchers_are_compared_structurally(): + assert sameish.invocations_are_sameish( + bar(any_(int)), + bar(any_(int)), + ) + assert not sameish.invocations_are_sameish( + bar(any_(int)), + bar(any_()), + ) + assert not sameish.invocations_are_sameish( + bar(any_()), + bar(1), + ) + + +def test_composite_matchers_are_compared_recursively(): + assert sameish.invocations_are_sameish( + bar(and_(any_(int), gt(1))), + bar(and_(any_(int), gt(1))), + ) + assert not sameish.invocations_are_sameish( + bar(and_(any_(int), gt(1))), + bar(and_(any_(int), gt(2))), + ) + + +def test_distinct_matcher_types_are_not_sameish_even_with_equal_payload(): + assert not sameish.invocations_are_sameish( + bar(eq(1)), + bar(neq(1)), + ) + assert not sameish.invocations_are_sameish( + bar(and_(any_(int), gt(1))), + bar(or_(any_(int), gt(1))), + ) + + +def test_arg_that_uses_predicate_identity_and_does_not_execute_predicate(): + calls = [] + + def predicate(value): + calls.append(value) + raise RuntimeError("must not be executed") + + assert sameish.invocations_are_sameish( + bar(arg_that(predicate)), + bar(arg_that(predicate)), + ) + assert calls == [] + + +def test_arg_that_with_different_predicates_is_not_sameish(): + assert not sameish.invocations_are_sameish( + bar(arg_that(lambda value: value > 0)), + bar(arg_that(lambda value: value > 0)), + ) + + +def test_arg_that_predicate_side_effects_are_not_triggered(): + seen = [] + + def predicate(value): + seen.append(value) + return True + + assert sameish.invocations_are_sameish( + bar(arg_that(predicate)), + bar(arg_that(predicate)), + ) + assert seen == [] + + +def test_call_captor_instances_are_sameish_for_root_deduping(): + left = call_captor() + right = call_captor() + + assert sameish.invocations_are_sameish( + bar(left), + bar(left), + ) + assert sameish.invocations_are_sameish( + bar(left), + bar(right), + ) + + +def test_argument_captor_instances_are_sameish_for_root_deduping(): + left = captor() + right = captor() + + assert sameish.invocations_are_sameish( + bar(left), + bar(left), + ) + assert sameish.invocations_are_sameish( + bar(left), + bar(right), + ) + + +def test_star_argument_captor_instances_are_sameish_for_root_deduping(): + left = captor() + right = captor() + + assert sameish.invocations_are_sameish( + bar(1, *left), + bar(1, *left), + ) + assert sameish.invocations_are_sameish( + bar(1, *left), + bar(1, *right), + ) + + +def test_kwargs_argument_captor_instances_are_sameish_for_root_deduping(): + left = captor() + right = captor() + + assert sameish.invocations_are_sameish( + bar(1, **left), + bar(1, **left), + ) + assert sameish.invocations_are_sameish( + bar(1, **left), + bar(1, **right), + ) + + +def test_star_argument_captors_with_different_matchers_are_not_sameish(): + assert not sameish.invocations_are_sameish( + bar(1, *captor(any_(int))), + bar(1, *captor(any_(str))), + ) + + +def test_kwargs_argument_captors_with_different_matchers_are_not_sameish(): + assert not sameish.invocations_are_sameish( + bar(1, **captor(any_(int))), + bar(1, **captor(any_(str))), + ) + + +def test_star_argument_captor_any_and_typed_any_are_not_sameish(): + assert not sameish.invocations_are_sameish( + bar(1, *captor()), + bar(1, *captor(any_(int))), + ) + + +def test_kwargs_argument_captor_any_and_typed_any_are_not_sameish(): + assert not sameish.invocations_are_sameish( + bar(1, **captor()), + bar(1, **captor(any_(int))), + ) + + +def test_argument_captor_instances_with_different_matchers_are_not_sameish(): + assert not sameish.invocations_are_sameish( + bar(captor(any_(int))), + bar(captor(any_(str))), + ) + + +def test_eq_failures_fallback_to_identity(): + class EqBoom: + def __eq__(self, other): + raise RuntimeError("boom") + + first = EqBoom() + second = EqBoom() + + assert sameish.invocations_are_sameish( + bar(first), + bar(first), + ) + assert not sameish.invocations_are_sameish( + bar(first), + bar(second), + )