From 1c051ece3301c9afcf5a92f7c7e082e1bf118969 Mon Sep 17 00:00:00 2001 From: noelpenne Date: Tue, 6 May 2025 13:28:04 -0500 Subject: [PATCH] fix: cleanup blobs and writes for shallow classes --- langgraph/checkpoint/redis/ashallow.py | 36 ++++++++++++++++-------- langgraph/checkpoint/redis/shallow.py | 39 +++++++++++++++----------- tests/test_shallow_async.py | 16 +++++++++-- tests/test_shallow_sync.py | 12 ++++++-- 4 files changed, 70 insertions(+), 33 deletions(-) diff --git a/langgraph/checkpoint/redis/ashallow.py b/langgraph/checkpoint/redis/ashallow.py index 90a4560..932f648 100644 --- a/langgraph/checkpoint/redis/ashallow.py +++ b/langgraph/checkpoint/redis/ashallow.py @@ -6,7 +6,6 @@ import json import os from contextlib import asynccontextmanager -from functools import partial from types import TracebackType from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple, Type, cast @@ -25,7 +24,6 @@ from redisvl.index import AsyncSearchIndex from redisvl.query import FilterQuery from redisvl.query.filter import Num, Tag -from redisvl.redis.connection import RedisConnectionFactory from langgraph.checkpoint.redis.base import ( CHECKPOINT_BLOB_PREFIX, @@ -34,6 +32,10 @@ REDIS_KEY_SEPARATOR, BaseRedisSaver, ) +from langgraph.checkpoint.redis.util import ( + to_storage_safe_id, + to_storage_safe_str, +) SCHEMAS = [ { @@ -794,16 +796,26 @@ def put_writes( @staticmethod def _make_shallow_redis_checkpoint_key(thread_id: str, checkpoint_ns: str) -> str: """Create a key for shallow checkpoints using only thread_id and checkpoint_ns.""" - return REDIS_KEY_SEPARATOR.join([CHECKPOINT_PREFIX, thread_id, checkpoint_ns]) + return REDIS_KEY_SEPARATOR.join( + [ + CHECKPOINT_PREFIX, + str(to_storage_safe_id(thread_id)), + to_storage_safe_str(checkpoint_ns), + ] + ) @staticmethod def _make_shallow_redis_checkpoint_blob_key_pattern( thread_id: str, checkpoint_ns: str ) -> str: """Create a pattern to match all blob keys for a thread and namespace.""" - return ( - 
REDIS_KEY_SEPARATOR.join([CHECKPOINT_BLOB_PREFIX, thread_id, checkpoint_ns]) - + ":*" + return REDIS_KEY_SEPARATOR.join( + [ + CHECKPOINT_BLOB_PREFIX, + str(to_storage_safe_id(thread_id)), + to_storage_safe_str(checkpoint_ns), + "*", + ] ) @staticmethod @@ -811,9 +823,11 @@ def _make_shallow_redis_checkpoint_writes_key_pattern( thread_id: str, checkpoint_ns: str ) -> str: """Create a pattern to match all writes keys for a thread and namespace.""" - return ( - REDIS_KEY_SEPARATOR.join( - [CHECKPOINT_WRITE_PREFIX, thread_id, checkpoint_ns] - ) - + ":*" + return REDIS_KEY_SEPARATOR.join( + [ + CHECKPOINT_WRITE_PREFIX, + str(to_storage_safe_id(thread_id)), + to_storage_safe_str(checkpoint_ns), + "*", + ] ) diff --git a/langgraph/checkpoint/redis/shallow.py b/langgraph/checkpoint/redis/shallow.py index ad8bcd4..8f451e2 100644 --- a/langgraph/checkpoint/redis/shallow.py +++ b/langgraph/checkpoint/redis/shallow.py @@ -26,6 +26,10 @@ REDIS_KEY_SEPARATOR, BaseRedisSaver, ) +from langgraph.checkpoint.redis.util import ( + to_storage_safe_id, + to_storage_safe_str, +) SCHEMAS = [ { @@ -688,15 +692,12 @@ def _load_pending_sends( @staticmethod def _make_shallow_redis_checkpoint_key(thread_id: str, checkpoint_ns: str) -> str: """Create a key for shallow checkpoints using only thread_id and checkpoint_ns.""" - return REDIS_KEY_SEPARATOR.join([CHECKPOINT_PREFIX, thread_id, checkpoint_ns]) - - @staticmethod - def _make_shallow_redis_checkpoint_blob_key( - thread_id: str, checkpoint_ns: str, channel: str - ) -> str: - """Create a key for a blob in a shallow checkpoint.""" return REDIS_KEY_SEPARATOR.join( - [CHECKPOINT_BLOB_PREFIX, thread_id, checkpoint_ns, channel] + [ + CHECKPOINT_PREFIX, + str(to_storage_safe_id(thread_id)), + to_storage_safe_str(checkpoint_ns), + ] ) @staticmethod @@ -704,9 +705,13 @@ def _make_shallow_redis_checkpoint_blob_key_pattern( thread_id: str, checkpoint_ns: str ) -> str: """Create a pattern to match all blob keys for a thread and namespace.""" - return 
( - REDIS_KEY_SEPARATOR.join([CHECKPOINT_BLOB_PREFIX, thread_id, checkpoint_ns]) - + ":*" + return REDIS_KEY_SEPARATOR.join( + [ + CHECKPOINT_BLOB_PREFIX, + str(to_storage_safe_id(thread_id)), + to_storage_safe_str(checkpoint_ns), + "*", + ] ) @staticmethod @@ -714,9 +719,11 @@ def _make_shallow_redis_checkpoint_writes_key_pattern( thread_id: str, checkpoint_ns: str ) -> str: """Create a pattern to match all writes keys for a thread and namespace.""" - return ( - REDIS_KEY_SEPARATOR.join( - [CHECKPOINT_WRITE_PREFIX, thread_id, checkpoint_ns] - ) - + ":*" + return REDIS_KEY_SEPARATOR.join( + [ + CHECKPOINT_WRITE_PREFIX, + str(to_storage_safe_id(thread_id)), + to_storage_safe_str(checkpoint_ns), + "*", + ] ) diff --git a/tests/test_shallow_async.py b/tests/test_shallow_async.py index 0f10943..d81dc4d 100644 --- a/tests/test_shallow_async.py +++ b/tests/test_shallow_async.py @@ -12,6 +12,7 @@ from redis.exceptions import ConnectionError as RedisConnectionError from langgraph.checkpoint.redis.ashallow import AsyncShallowRedisSaver +from langgraph.checkpoint.redis.base import CHECKPOINT_BLOB_PREFIX @pytest.fixture @@ -96,7 +97,10 @@ async def test_only_latest_checkpoint( } ) checkpoint_1 = test_data["checkpoints"][0] - await saver.aput(config_1, checkpoint_1, test_data["metadata"][0], {}) + channel_versions_1 = {"test_channel": "1"} + await saver.aput( + config_1, checkpoint_1, test_data["metadata"][0], channel_versions_1 + ) # Create second checkpoint config_2 = RunnableConfig( @@ -108,13 +112,19 @@ async def test_only_latest_checkpoint( } ) checkpoint_2 = test_data["checkpoints"][1] - await saver.aput(config_2, checkpoint_2, test_data["metadata"][1], {}) + channel_versions_2 = {"test_channel": "2"} + await saver.aput( + config_2, checkpoint_2, test_data["metadata"][1], channel_versions_2 + ) - # Verify only latest checkpoint exists + # Verify only latest checkpoint and blobs exist results = [c async for c in saver.alist(None)] assert len(results) == 1 assert 
results[0].config["configurable"]["checkpoint_id"] == checkpoint_2["id"] + blobs = list(await saver._redis.keys(CHECKPOINT_BLOB_PREFIX + ":*")) + assert len(blobs) == 1 + @pytest.mark.asyncio @pytest.mark.parametrize( diff --git a/tests/test_shallow_sync.py b/tests/test_shallow_sync.py index 02944cf..232d609 100644 --- a/tests/test_shallow_sync.py +++ b/tests/test_shallow_sync.py @@ -12,6 +12,7 @@ from redis import Redis from redis.exceptions import ConnectionError as RedisConnectionError +from langgraph.checkpoint.redis.base import CHECKPOINT_BLOB_PREFIX from langgraph.checkpoint.redis.shallow import ShallowRedisSaver @@ -102,7 +103,8 @@ def test_only_latest_checkpoint( } } checkpoint_1 = test_data["checkpoints"][0] - saver.put(config_1, checkpoint_1, test_data["metadata"][0], {}) + channel_versions_1 = {"test_channel": "1"} + saver.put(config_1, checkpoint_1, test_data["metadata"][0], channel_versions_1) # Create second checkpoint config_2 = { @@ -112,13 +114,17 @@ def test_only_latest_checkpoint( } } checkpoint_2 = test_data["checkpoints"][1] - saver.put(config_2, checkpoint_2, test_data["metadata"][1], {}) + channel_versions_2 = {"test_channel": "2"} + saver.put(config_2, checkpoint_2, test_data["metadata"][1], channel_versions_2) - # Verify only latest checkpoint exists + # Verify only latest checkpoint and blobs exist results = list(saver.list(None)) assert len(results) == 1 assert results[0].config["configurable"]["checkpoint_id"] == checkpoint_2["id"] + blobs = list(saver._redis.keys(CHECKPOINT_BLOB_PREFIX + ":*")) + assert len(blobs) == 1 + @pytest.mark.parametrize( "query, expected_count",