Skip to content

Commit 45eb37d

Browse files
committed
Enhance _check_reference to unwrap ObjectProxy for assertions
1 parent d94bbd1 commit 45eb37d

File tree

2 files changed

+50
-2
lines changed

2 files changed

+50
-2
lines changed

src/pynguin/assertion/assertiontraceobserver.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import pynguin.testcase.testcase as tc
2424
import pynguin.testcase.variablereference as vr
2525
import pynguin.utils.generic.genericaccessibleobject as gao
26+
import pynguin.utils.typetracing as tt
2627
from pynguin.analyses.typesystem import ANY, TypeInfo
2728
from pynguin.utils.type_utils import (
2829
is_assertable,
@@ -120,7 +121,8 @@ def _handle( # noqa: C901
120121
trace = self._assertion_local_state.trace
121122

122123
if not statement.ret_val.is_none_type():
123-
if is_primitive_type(type(exec_ctx.get_reference_value(statement.ret_val))):
124+
ret_value = tt.unwrap(exec_ctx.get_reference_value(statement.ret_val))
125+
if is_primitive_type(type(ret_value)):
124126
# Primitives won't change, so we only check them once.
125127
self._check_reference(module_provider, exec_ctx, statement.ret_val, position, trace)
126128
elif type(exec_ctx.get_reference_value(statement.ret_val)).__module__ != "builtins":
@@ -155,7 +157,7 @@ def _handle( # noqa: C901
155157

156158
# Check fields of classes whose constructors were used.
157159
for seen_type in [
158-
type(exec_ctx.get_reference_value(ref))
160+
type(tt.unwrap(exec_ctx.get_reference_value(ref)))
159161
for ref in self._assertion_local_state.watch_list
160162
]:
161163
if (
@@ -208,6 +210,7 @@ def _check_reference(
208210
max_depth: The maximum recursion depth.
209211
"""
210212
value = exec_ctx.get_reference_value(ref)
213+
value = tt.unwrap(value)
211214
if isinstance(value, float):
212215
trace.add_entry(position, ass.FloatAssertion(ref, value))
213216
return

tests/assertion/test_assertiontraceobserver.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
from unittest import mock
99
from unittest.mock import MagicMock
1010

11+
import pynguin.assertion.assertion as ass
1112
import pynguin.assertion.assertiontraceobserver as ato
13+
import pynguin.utils.typetracing as tt
1214
from pynguin.testcase.execution import ExecutionContext, TestCaseExecutor
1315
from pynguin.testcase.statement import Statement
1416

@@ -47,3 +49,46 @@ def test_after_test_case_execution():
4749
trace_mock.clone.return_value = clone
4850
observer.after_test_case_execution(MagicMock(), MagicMock(), result)
4951
assert result.assertion_trace == clone
52+
53+
54+
def test_check_reference_unwraps_object_proxy():
55+
"""Regression test: _check_reference must unwrap ObjectProxy to generate assertions."""
56+
observer = ato.RemoteAssertionTraceObserver()
57+
58+
wrapped_value = "test_string"
59+
proxy = tt.ObjectProxy(wrapped_value)
60+
61+
exec_ctx = MagicMock()
62+
exec_ctx.get_reference_value.return_value = proxy
63+
module_provider = MagicMock()
64+
ref = MagicMock()
65+
66+
trace = observer._assertion_local_state.trace
67+
observer._check_reference(module_provider, exec_ctx, ref, position=0, trace=trace)
68+
69+
assertions = list(trace.trace.get(0, []))
70+
assert len(assertions) == 1
71+
assert isinstance(assertions[0], ass.ObjectAssertion)
72+
assert assertions[0].object == wrapped_value
73+
74+
75+
def test_check_reference_unwraps_nested_object_proxy():
76+
"""Regression test: _check_reference handles nested ObjectProxy correctly."""
77+
observer = ato.RemoteAssertionTraceObserver()
78+
79+
inner_value = 42
80+
inner_proxy = tt.ObjectProxy(inner_value)
81+
outer_proxy = tt.ObjectProxy(inner_proxy)
82+
83+
exec_ctx = MagicMock()
84+
exec_ctx.get_reference_value.return_value = outer_proxy
85+
module_provider = MagicMock()
86+
ref = MagicMock()
87+
88+
trace = observer._assertion_local_state.trace
89+
observer._check_reference(module_provider, exec_ctx, ref, position=0, trace=trace)
90+
91+
assertions = list(trace.trace.get(0, []))
92+
assert len(assertions) == 1
93+
assert isinstance(assertions[0], ass.ObjectAssertion)
94+
assert assertions[0].object == inner_value

0 commit comments

Comments
 (0)