Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions src/core/services/streaming/tool_call_repair_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class ToolCallRepairProcessor(IStreamProcessor):
def __init__(self, tool_call_repair_service: IToolCallRepairService) -> None:
self.tool_call_repair_service = tool_call_repair_service
self._buffers: dict[str, str] = {}
self._max_buffer_bytes = self._resolve_buffer_cap(tool_call_repair_service)

async def process(self, content: StreamingContent) -> StreamingContent:
"""
Expand All @@ -36,6 +37,7 @@ async def process(self, content: StreamingContent) -> StreamingContent:
buffer = self._buffers.get(stream_id, "")

buffer += content.content or ""
buffer = self._enforce_buffer_cap(stream_id, buffer)

repaired_content_parts: list[str] = []
remaining_buffer = buffer
Expand Down Expand Up @@ -101,3 +103,58 @@ async def process(self, content: StreamingContent) -> StreamingContent:
content="",
is_cancellation=content.is_cancellation,
) # Return empty if nothing to yield

def _resolve_buffer_cap(self, service: IToolCallRepairService) -> int:
"""Determine the maximum buffer size supported by the repair service."""

default_cap = 64 * 1024
candidate = getattr(service, "max_buffer_bytes", default_cap)
try:
cap_value = int(candidate)
except (TypeError, ValueError):
logger.warning(
"Invalid tool call repair buffer cap %r; using default %d bytes",
candidate,
default_cap,
)
return default_cap
if cap_value < 0:
logger.warning(
"Negative tool call repair buffer cap %d received; treating as zero",
cap_value,
)
return 0
return cap_value

def _enforce_buffer_cap(self, stream_id: str, buffer: str) -> str:
"""Ensure per-stream buffer usage stays within configured limits."""

cap = self._max_buffer_bytes
if cap == 0:
if buffer:
logger.warning(
"Dropping streaming tool call buffer for stream %s because cap is 0",
stream_id,
)
return ""

if cap < 0:
return buffer

if not buffer:
return buffer

buffer_bytes = buffer.encode("utf-8")
current_size = len(buffer_bytes)
if current_size <= cap:
return buffer

truncated_bytes = buffer_bytes[-cap:]
dropped = current_size - cap
logger.warning(
"Tool call repair buffer for stream %s exceeded %d bytes; dropping %d bytes",
stream_id,
cap,
dropped,
)
return truncated_bytes.decode("utf-8", errors="ignore")
12 changes: 11 additions & 1 deletion src/core/services/tool_call_repair_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,17 @@ def __init__(self, max_buffer_bytes: int | None = None) -> None:
)

# Cap per-session buffer to guard against pathological streams
self._max_buffer_bytes: int = max_buffer_bytes or (64 * 1024) # default 64 KB
self._max_buffer_bytes: int = (
int(max_buffer_bytes) if max_buffer_bytes is not None else 64 * 1024
)
if self._max_buffer_bytes < 0:
self._max_buffer_bytes = 0

@property
def max_buffer_bytes(self) -> int:
"""Return the configured maximum buffer size for streaming repair."""

return self._max_buffer_bytes

def repair_tool_calls(self, response_content: str) -> dict[str, Any] | None:
"""
Expand Down
47 changes: 45 additions & 2 deletions tests/unit/core/services/test_tool_call_repair.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest
from pytest_mock import MockerFixture
from src.core.domain.streaming_response_processor import StreamingContent
from src.core.interfaces.response_processor_interface import ProcessedResponse
from src.core.services.streaming.tool_call_repair_processor import (
ToolCallRepairProcessor,
Expand Down Expand Up @@ -72,8 +73,6 @@ async def test_process_chunks_with_tool_call(
streaming_processor: StreamingToolCallRepairProcessor,
mocker: MockerFixture,
) -> None:
from src.core.domain.streaming_response_processor import StreamingContent

# Mock the underlying ToolCallRepairProcessor's process method
# This is where the actual repair logic is now encapsulated
mock_tool_call_repair_processor_process = mocker.AsyncMock(
Expand Down Expand Up @@ -150,3 +149,47 @@ async def mock_async_chunks_generator() -> (
)
assert actual_calls[2].content == "World."
assert actual_calls[3].is_done is True and actual_calls[3].content == ""


class TestToolCallRepairProcessorBehavior:
@pytest.mark.asyncio
async def test_buffer_truncated_when_cap_exceeded(
self, mocker: MockerFixture
) -> None:
service = ToolCallRepairService(max_buffer_bytes=32)
processor = ToolCallRepairProcessor(service)

stream_id = "stream-cap"
large_payload = "x" * 100

repair_mock = mocker.patch.object(
service, "repair_tool_calls", return_value=None
)

await processor.process(
StreamingContent(content=large_payload, metadata={"stream_id": stream_id})
)

stored_buffer = processor._buffers.get(stream_id, "")
assert len(stored_buffer.encode("utf-8")) <= service.max_buffer_bytes

repair_mock.assert_called_once()
processed_buffer = repair_mock.call_args[0][0]
assert len(processed_buffer.encode("utf-8")) <= service.max_buffer_bytes

@pytest.mark.asyncio
async def test_buffer_dropped_when_cap_zero(self, mocker: MockerFixture) -> None:
service = ToolCallRepairService(max_buffer_bytes=0)
processor = ToolCallRepairProcessor(service)

stream_id = "stream-zero"
repair_mock = mocker.patch.object(
service, "repair_tool_calls", return_value=None
)

await processor.process(
StreamingContent(content="payload", metadata={"stream_id": stream_id})
)

assert stream_id not in processor._buffers
repair_mock.assert_called_once_with("")
Loading