From 11cd4c61cf7f900573f7bcf8576feb4f027a8240 Mon Sep 17 00:00:00 2001 From: "vincent.min" Date: Wed, 13 Aug 2025 12:59:37 +0200 Subject: [PATCH 1/2] add failing test for retry batch ordering --- .../unit_tests/runnables/test_runnable.py | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index d9a9db349e8ab..105e39392626d 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -3923,6 +3923,58 @@ def _lambda(x: int) -> Union[int, Runnable]: lambda_mock.reset_mock() +def test_retry_batch_preserves_order() -> None: + """Regression test: batch with retry should preserve input order. + + The previous implementation stored successful results in a map keyed by the + index within the *pending* (filtered) list rather than the original input + index, causing collisions after retries. This produced duplicated outputs + and dropped earlier successes (e.g. [0,1,2] -> [1,1,2]). + """ + # Fail only the middle element on the first attempt to trigger the bug. + first_fail: set[int] = {1} + + def sometimes_fail(x: int) -> int: # pragma: no cover - trivial + if x in first_fail: + first_fail.remove(x) + msg = "fail once" + raise ValueError(msg) + return x + + runnable = RunnableLambda(sometimes_fail) + + results = runnable.with_retry( + stop_after_attempt=2, + wait_exponential_jitter=False, + retry_if_exception_type=(ValueError,), + ).batch([0, 1, 2]) + + # Expect exact ordering preserved. + assert results == [0, 1, 2] + + +async def test_async_retry_batch_preserves_order() -> None: + """Async variant of order preservation regression test.""" + first_fail: set[int] = {1} + + def sometimes_fail(x: int) -> int: # pragma: no cover - trivial + if x in first_fail: + first_fail.remove(x) + msg = "fail once" + raise ValueError(msg) + return x + + runnable = RunnableLambda(sometimes_fail) + + results = await runnable.with_retry( + stop_after_attempt=2, + wait_exponential_jitter=False, + retry_if_exception_type=(ValueError,), + ).abatch([0, 1, 2]) + + assert results == [0, 1, 2] + + async def test_async_retrying(mocker: MockerFixture) -> None: def _lambda(x: int) -> Union[int, Runnable]: if x == 1: From cec7e8913e0e7fb516c233663ffc1ce12d0a7866 Mon Sep 17 00:00:00 2001 From: "vincent.min" Date: Wed, 13 Aug 2025 13:08:50 +0200 Subject: [PATCH 2/2] fix index issue in RunnableRetry batch and abatch methods --- libs/core/langchain_core/runnables/retry.py | 53 ++++++++++++++------- 1 file changed, 35 insertions(+), 18 deletions(-) diff --git a/libs/core/langchain_core/runnables/retry.py b/libs/core/langchain_core/runnables/retry.py index d495c59e6e9e9..b6b81df8e07c5 100644 --- a/libs/core/langchain_core/runnables/retry.py +++ b/libs/core/langchain_core/runnables/retry.py @@ -234,31 +234,40 @@ def _batch( ) -> list[Union[Output, Exception]]: results_map: dict[int, Output] = {} - def pending(iterable: list[U]) -> list[U]: - return [item for idx, item in enumerate(iterable) if idx not in results_map] - not_set: list[Output] = [] result = not_set try: for attempt in self._sync_retrying(): with attempt: - # Get the results of the inputs that have not succeeded yet. + # Retry for inputs that have not yet succeeded + # Determine which original indices remain. + remaining_indices = [ + i for i in range(len(inputs)) if i not in results_map + ] + if not remaining_indices: + break + pending_inputs = [inputs[i] for i in remaining_indices] + pending_configs = [config[i] for i in remaining_indices] + pending_run_managers = [run_manager[i] for i in remaining_indices] + # Invoke underlying batch only on remaining elements. result = super().batch( - pending(inputs), + pending_inputs, self._patch_config_list( - pending(config), pending(run_manager), attempt.retry_state + pending_configs, pending_run_managers, attempt.retry_state ), return_exceptions=True, **kwargs, ) - # Register the results of the inputs that have succeeded. + # Register the results of the inputs that have succeeded, mapping + # back to their original indices. first_exception = None - for i, r in enumerate(result): + for offset, r in enumerate(result): if isinstance(r, Exception): if not first_exception: first_exception = r continue - results_map[i] = r + orig_idx = remaining_indices[offset] + results_map[orig_idx] = r # If any exception occurred, raise it, to retry the failed ones if first_exception: raise first_exception @@ -301,31 +310,39 @@ async def _abatch( ) -> list[Union[Output, Exception]]: results_map: dict[int, Output] = {} - def pending(iterable: list[U]) -> list[U]: - return [item for idx, item in enumerate(iterable) if idx not in results_map] - not_set: list[Output] = [] result = not_set try: async for attempt in self._async_retrying(): with attempt: - # Get the results of the inputs that have not succeeded yet. + # Retry for inputs that have not yet succeeded + # Determine which original indices remain. + remaining_indices = [ + i for i in range(len(inputs)) if i not in results_map + ] + if not remaining_indices: + break + pending_inputs = [inputs[i] for i in remaining_indices] + pending_configs = [config[i] for i in remaining_indices] + pending_run_managers = [run_manager[i] for i in remaining_indices] result = await super().abatch( - pending(inputs), + pending_inputs, self._patch_config_list( - pending(config), pending(run_manager), attempt.retry_state + pending_configs, pending_run_managers, attempt.retry_state ), return_exceptions=True, **kwargs, ) - # Register the results of the inputs that have succeeded. + # Register the results of the inputs that have succeeded, mapping + # back to their original indices. first_exception = None - for i, r in enumerate(result): + for offset, r in enumerate(result): if isinstance(r, Exception): if not first_exception: first_exception = r continue - results_map[i] = r + orig_idx = remaining_indices[offset] + results_map[orig_idx] = r # If any exception occurred, raise it, to retry the failed ones if first_exception: raise first_exception