Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 35 additions & 18 deletions libs/core/langchain_core/runnables/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,31 +238,40 @@ def _batch(
) -> list[Union[Output, Exception]]:
results_map: dict[int, Output] = {}

def pending(iterable: list[U]) -> list[U]:
return [item for idx, item in enumerate(iterable) if idx not in results_map]

not_set: list[Output] = []
result = not_set
try:
for attempt in self._sync_retrying():
with attempt:
# Get the results of the inputs that have not succeeded yet.
# Retry for inputs that have not yet succeeded
# Determine which original indices remain.
remaining_indices = [
i for i in range(len(inputs)) if i not in results_map
]
if not remaining_indices:
break
pending_inputs = [inputs[i] for i in remaining_indices]
pending_configs = [config[i] for i in remaining_indices]
pending_run_managers = [run_manager[i] for i in remaining_indices]
# Invoke underlying batch only on remaining elements.
result = super().batch(
pending(inputs),
pending_inputs,
self._patch_config_list(
pending(config), pending(run_manager), attempt.retry_state
pending_configs, pending_run_managers, attempt.retry_state
),
return_exceptions=True,
**kwargs,
)
# Register the results of the inputs that have succeeded.
# Register the results of the inputs that have succeeded, mapping
# back to their original indices.
first_exception = None
for i, r in enumerate(result):
for offset, r in enumerate(result):
if isinstance(r, Exception):
if not first_exception:
first_exception = r
continue
results_map[i] = r
orig_idx = remaining_indices[offset]
results_map[orig_idx] = r
# If any exception occurred, raise it, to retry the failed ones
if first_exception:
raise first_exception
Expand Down Expand Up @@ -305,31 +314,39 @@ async def _abatch(
) -> list[Union[Output, Exception]]:
results_map: dict[int, Output] = {}

def pending(iterable: list[U]) -> list[U]:
return [item for idx, item in enumerate(iterable) if idx not in results_map]

not_set: list[Output] = []
result = not_set
try:
async for attempt in self._async_retrying():
with attempt:
# Get the results of the inputs that have not succeeded yet.
# Retry for inputs that have not yet succeeded
# Determine which original indices remain.
remaining_indices = [
i for i in range(len(inputs)) if i not in results_map
]
if not remaining_indices:
break
pending_inputs = [inputs[i] for i in remaining_indices]
pending_configs = [config[i] for i in remaining_indices]
pending_run_managers = [run_manager[i] for i in remaining_indices]
result = await super().abatch(
pending(inputs),
pending_inputs,
self._patch_config_list(
pending(config), pending(run_manager), attempt.retry_state
pending_configs, pending_run_managers, attempt.retry_state
),
return_exceptions=True,
**kwargs,
)
# Register the results of the inputs that have succeeded.
# Register the results of the inputs that have succeeded, mapping
# back to their original indices.
first_exception = None
for i, r in enumerate(result):
for offset, r in enumerate(result):
if isinstance(r, Exception):
if not first_exception:
first_exception = r
continue
results_map[i] = r
orig_idx = remaining_indices[offset]
results_map[orig_idx] = r
# If any exception occurred, raise it, to retry the failed ones
if first_exception:
raise first_exception
Expand Down
52 changes: 52 additions & 0 deletions libs/core/tests/unit_tests/runnables/test_runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -3919,6 +3919,58 @@ def _lambda(x: int) -> Union[int, Runnable]:
lambda_mock.reset_mock()


def test_retry_batch_preserves_order() -> None:
    """Regression test: batch with retry should preserve input order.

    Previously, successful results were keyed by position within the filtered
    (pending) list instead of by the original input index. After a retry the
    indices collided, duplicating some outputs and overwriting earlier
    successes (e.g. [0, 1, 2] -> [1, 1, 2]).
    """
    # Only the middle input fails, and only on its first invocation.
    pending_failures: set[int] = {1}

    def flaky(value: int) -> int:  # pragma: no cover - trivial
        if value not in pending_failures:
            return value
        pending_failures.discard(value)
        msg = "fail once"
        raise ValueError(msg)

    retrying = RunnableLambda(flaky).with_retry(
        stop_after_attempt=2,
        wait_exponential_jitter=False,
        retry_if_exception_type=(ValueError,),
    )

    # The retried element must land back at its original position.
    assert retrying.batch([0, 1, 2]) == [0, 1, 2]


async def test_async_retry_batch_preserves_order() -> None:
    """Async variant of order preservation regression test."""
    # Only the middle input fails, and only on its first invocation.
    pending_failures: set[int] = {1}

    def flaky(value: int) -> int:  # pragma: no cover - trivial
        if value not in pending_failures:
            return value
        pending_failures.discard(value)
        msg = "fail once"
        raise ValueError(msg)

    retrying = RunnableLambda(flaky).with_retry(
        stop_after_attempt=2,
        wait_exponential_jitter=False,
        retry_if_exception_type=(ValueError,),
    )

    # The retried element must land back at its original position.
    assert await retrying.abatch([0, 1, 2]) == [0, 1, 2]


async def test_async_retrying(mocker: MockerFixture) -> None:
def _lambda(x: int) -> Union[int, Runnable]:
if x == 1:
Expand Down