Skip to content

Commit cc4df52

Browse files
feat: server-side client state persistence (#8314)
## Summary Move client state persistence from browser to server. - Add new client state persistence service to handle reading and writing client state to db & associated router. The API mirrors that of LocalStorage/IndexedDB where the set/get methods both operate on _keys_. For example, when we persist the canvas state, we send only the new canvas state to the backend - not the whole app state. - The data is very flexibly-typed as a pydantic `JsonValue`. The client is expected to handle all data parsing/validation (it must do this anyways, and does this today). - Change persistence from debounced to throttled at 2 seconds. Maybe less is OK? Trying to not hammer the server. - Add new persistence storage driver in client and use it in redux-remember. It does its best to avoid extraneous persist requests, caching the last data it persisted and noop-ing if there are no changes. - Storage driver tracks pending persist actions using ref counts (bc each slice is persisted independently). If there user navigates away from the page during a persist request, it will give them the "you may lose something if you navigate away" alert. - This "lose something" alert message is not customizable (browser security reasons). - The alert is triggered only when the user closes the tape while a persist network request is mid-flight. It's possible that the user makes a change and closes the page before we start persisting. In this case, they will lose the last 2 seconds of data. - I tried making triggering the alert when a persist was waiting to start, and it felt off. - Maybe the alert isn't even necessary. Again you'd lose 2s of data at most, probably a non issue. IMO after trying it, a subtle indicator somewhere on the page is probably less confusing/intrusive. - Fix an issue where the `redux-remember` enhancer was added _last_ in the enhancer chain, which prevented us detecting when a persist has succeeded. This required a small change to the `unserialze` utility (used during rehydration) to ensure slices enhanced with `redux-undo` are set up correctly as they are rehydrated. - Restructure the redux store code to avoid circular dependencies. I couldn't figure out how to do this without just smooshing it all into the main `store.ts` file. Oh well. Implications: - Because client state is now on the server, different browsers will have the same studio state. For example, if I start working on something in Firefox, if I switch to Chrome, I have the same client state. - Incognito windows won't do anything bc client state is server-side. - It takes a bit longer for persistence to happen thanks to the debounce, but there's now an indicator that tells you your stuff isn't saved yet. - Resetting the browser won't fix an issue with your studio state. You must use `Reset Web UI` to fix it (or otherwise hit the appropriate endpoint). It may be possible to end up in a Catch-22 where you can't click the button and get stuck w/ a borked studio - I think to think through this a bit more, might not be an issue. - It probably takes a bit longer to start up, since we need to retrieve client state over network instead of directly with browser APIs. Other notes: - We could explore adding an "incognito" mode, enabled via `invokeai.yaml` setting or maybe in the UI. This would temporarily disable persistence. Actually, I don't think this really makes sense, bc all the images would be saved to disk. - The studio state is stored in a single row in the DB. Currently, a static row ID is used to force the studio state to be a singleton. It is _possible_ to support multiple saved states. Might be a solve for app workspaces. ## Related Issues / Discussions n/a ## QA Instructions Try it out. It's pretty straightforward. Error states are the main things to test - for example, network blips. The new server-side persistence driver is the only real functional change - everything else is just kinda shuffling things around to support it. ## Merge Plan n/a ## Checklist - [x] _The PR has a short but descriptive title, suitable for a changelog_ - [ ] _Tests added / updated (if applicable)_ - [ ] _Documentation added / updated (if applicable)_ - [ ] _Updated `What's New` copy (if doing a release after this PR)_
2 parents 2571e19 + 1cb4ef0 commit cc4df52

File tree

89 files changed

+2277
-1306
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

89 files changed

+2277
-1306
lines changed

invokeai/app/api/dependencies.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
1111
from invokeai.app.services.boards.boards_default import BoardService
1212
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
13+
from invokeai.app.services.client_state_persistence.client_state_persistence_sqlite import ClientStatePersistenceSqlite
1314
from invokeai.app.services.config.config_default import InvokeAIAppConfig
1415
from invokeai.app.services.download.download_default import DownloadQueueService
1516
from invokeai.app.services.events.events_fastapievents import FastAPIEventService
@@ -151,6 +152,7 @@ def initialize(
151152
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
152153
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
153154
workflow_thumbnails = WorkflowThumbnailFileStorageDisk(workflow_thumbnails_folder)
155+
client_state_persistence = ClientStatePersistenceSqlite(db=db)
154156

155157
services = InvocationServices(
156158
board_image_records=board_image_records,
@@ -181,6 +183,7 @@ def initialize(
181183
style_preset_records=style_preset_records,
182184
style_preset_image_files=style_preset_image_files,
183185
workflow_thumbnails=workflow_thumbnails,
186+
client_state_persistence=client_state_persistence,
184187
)
185188

186189
ApiDependencies.invoker = Invoker(services)

invokeai/app/api/routers/app_info.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
from typing import Optional
66

77
import torch
8-
from fastapi import Body
8+
from fastapi import Body, HTTPException, Query
99
from fastapi.routing import APIRouter
10-
from pydantic import BaseModel, Field
10+
from pydantic import BaseModel, Field, JsonValue
1111

1212
from invokeai.app.api.dependencies import ApiDependencies
1313
from invokeai.app.invocations.upscale import ESRGAN_MODELS
@@ -173,3 +173,50 @@ async def disable_invocation_cache() -> None:
173173
async def get_invocation_cache_status() -> InvocationCacheStatus:
174174
"""Clears the invocation cache"""
175175
return ApiDependencies.invoker.services.invocation_cache.get_status()
176+
177+
178+
@app_router.get(
179+
"/client_state",
180+
operation_id="get_client_state_by_key",
181+
response_model=JsonValue | None,
182+
)
183+
async def get_client_state_by_key(
184+
key: str = Query(..., description="Key to get"),
185+
) -> JsonValue | None:
186+
"""Gets the client state"""
187+
try:
188+
return ApiDependencies.invoker.services.client_state_persistence.get_by_key(key)
189+
except Exception as e:
190+
logging.error(f"Error getting client state: {e}")
191+
raise HTTPException(status_code=500, detail="Error setting client state")
192+
193+
194+
@app_router.post(
195+
"/client_state",
196+
operation_id="set_client_state",
197+
response_model=None,
198+
)
199+
async def set_client_state(
200+
key: str = Query(..., description="Key to set"),
201+
value: JsonValue = Body(..., description="Value of the key"),
202+
) -> None:
203+
"""Sets the client state"""
204+
try:
205+
ApiDependencies.invoker.services.client_state_persistence.set_by_key(key, value)
206+
except Exception as e:
207+
logging.error(f"Error setting client state: {e}")
208+
raise HTTPException(status_code=500, detail="Error setting client state")
209+
210+
211+
@app_router.delete(
212+
"/client_state",
213+
operation_id="delete_client_state",
214+
responses={204: {"description": "Client state deleted"}},
215+
)
216+
async def delete_client_state() -> None:
217+
"""Deletes the client state"""
218+
try:
219+
ApiDependencies.invoker.services.client_state_persistence.delete()
220+
except Exception as e:
221+
logging.error(f"Error deleting client state: {e}")
222+
raise HTTPException(status_code=500, detail="Error deleting client state")
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from abc import ABC, abstractmethod
2+
3+
from pydantic import JsonValue
4+
5+
6+
class ClientStatePersistenceABC(ABC):
7+
"""
8+
Base class for client persistence implementations.
9+
This class defines the interface for persisting client data.
10+
"""
11+
12+
@abstractmethod
13+
def set_by_key(self, key: str, value: JsonValue) -> None:
14+
"""
15+
Store the data for the client.
16+
17+
:param data: The client data to be stored.
18+
"""
19+
pass
20+
21+
@abstractmethod
22+
def get_by_key(self, key: str) -> JsonValue | None:
23+
"""
24+
Get the data for the client.
25+
26+
:return: The client data.
27+
"""
28+
pass
29+
30+
@abstractmethod
31+
def delete(self) -> None:
32+
"""
33+
Delete the data for the client.
34+
"""
35+
pass
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import json
2+
3+
from pydantic import JsonValue
4+
5+
from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
6+
from invokeai.app.services.invoker import Invoker
7+
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
8+
9+
10+
class ClientStatePersistenceSqlite(ClientStatePersistenceABC):
11+
"""
12+
Base class for client persistence implementations.
13+
This class defines the interface for persisting client data.
14+
"""
15+
16+
def __init__(self, db: SqliteDatabase) -> None:
17+
super().__init__()
18+
self._db = db
19+
self._default_row_id = 1
20+
21+
def start(self, invoker: Invoker) -> None:
22+
self._invoker = invoker
23+
24+
def set_by_key(self, key: str, value: JsonValue) -> None:
25+
state = self.get() or {}
26+
state.update({key: value})
27+
28+
with self._db.transaction() as cursor:
29+
cursor.execute(
30+
f"""
31+
INSERT INTO client_state (id, data)
32+
VALUES ({self._default_row_id}, ?)
33+
ON CONFLICT(id) DO UPDATE
34+
SET data = excluded.data;
35+
""",
36+
(json.dumps(state),),
37+
)
38+
39+
def get(self) -> dict[str, JsonValue] | None:
40+
with self._db.transaction() as cursor:
41+
cursor.execute(
42+
f"""
43+
SELECT data FROM client_state
44+
WHERE id = {self._default_row_id}
45+
"""
46+
)
47+
row = cursor.fetchone()
48+
if row is None:
49+
return None
50+
return json.loads(row[0])
51+
52+
def get_by_key(self, key: str) -> JsonValue | None:
53+
state = self.get()
54+
if state is None:
55+
return None
56+
return state.get(key, None)
57+
58+
def delete(self) -> None:
59+
with self._db.transaction() as cursor:
60+
cursor.execute(
61+
f"""
62+
DELETE FROM client_state
63+
WHERE id = {self._default_row_id}
64+
"""
65+
)

invokeai/app/services/invocation_services.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase
1818
from invokeai.app.services.boards.boards_base import BoardServiceABC
1919
from invokeai.app.services.bulk_download.bulk_download_base import BulkDownloadBase
20+
from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
2021
from invokeai.app.services.config import InvokeAIAppConfig
2122
from invokeai.app.services.download import DownloadQueueServiceBase
2223
from invokeai.app.services.events.events_base import EventServiceBase
@@ -73,6 +74,7 @@ def __init__(
7374
style_preset_records: "StylePresetRecordsStorageBase",
7475
style_preset_image_files: "StylePresetImageFileStorageBase",
7576
workflow_thumbnails: "WorkflowThumbnailServiceBase",
77+
client_state_persistence: "ClientStatePersistenceABC",
7678
):
7779
self.board_images = board_images
7880
self.board_image_records = board_image_records
@@ -102,3 +104,4 @@ def __init__(
102104
self.style_preset_records = style_preset_records
103105
self.style_preset_image_files = style_preset_image_files
104106
self.workflow_thumbnails = workflow_thumbnails
107+
self.client_state_persistence = client_state_persistence

invokeai/app/services/shared/sqlite/sqlite_util.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_18 import build_migration_18
2424
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_19 import build_migration_19
2525
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_20 import build_migration_20
26+
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_21 import build_migration_21
2627
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
2728

2829

@@ -63,6 +64,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
6364
migrator.register_migration(build_migration_18())
6465
migrator.register_migration(build_migration_19(app_config=config))
6566
migrator.register_migration(build_migration_20())
67+
migrator.register_migration(build_migration_21())
6668
migrator.run_migrations()
6769

6870
return db
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import sqlite3
2+
3+
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
4+
5+
6+
class Migration21Callback:
7+
def __call__(self, cursor: sqlite3.Cursor) -> None:
8+
cursor.execute(
9+
"""
10+
CREATE TABLE client_state (
11+
id INTEGER PRIMARY KEY CHECK(id = 1),
12+
data TEXT NOT NULL, -- Frontend will handle the shape of this data
13+
updated_at DATETIME NOT NULL DEFAULT (CURRENT_TIMESTAMP)
14+
);
15+
"""
16+
)
17+
cursor.execute(
18+
"""
19+
CREATE TRIGGER tg_client_state_updated_at
20+
AFTER UPDATE ON client_state
21+
FOR EACH ROW
22+
BEGIN
23+
UPDATE client_state
24+
SET updated_at = CURRENT_TIMESTAMP
25+
WHERE id = OLD.id;
26+
END;
27+
"""
28+
)
29+
30+
31+
def build_migration_21() -> Migration:
32+
"""Builds the migration object for migrating from version 20 to version 21. This includes:
33+
- Creating the `client_state` table.
34+
- Adding a trigger to update the `updated_at` field on updates.
35+
"""
36+
return Migration(
37+
from_version=20,
38+
to_version=21,
39+
callback=Migration21Callback(),
40+
)

invokeai/frontend/web/.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,5 @@ yalc.lock
4444

4545
# vitest
4646
tsconfig.vitest-temp.json
47-
coverage/
47+
coverage/
48+
*.tgz

invokeai/frontend/web/.storybook/preview.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ i18n.use(initReactI18next).init({
2626
returnNull: false,
2727
});
2828

29-
const store = createStore(undefined, false);
29+
const store = createStore({ driver: { getItem: () => {}, setItem: () => {} }, persistThrottle: 2000 });
3030
$store.set(store);
3131
$baseUrl.set('http://localhost:9090');
3232

invokeai/frontend/web/eslint.config.mjs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,10 @@ export default [
197197
importNames: ['isEqual'],
198198
message: 'Please use objectEquals from @observ33r/object-equals instead.',
199199
},
200+
{
201+
name: 'zod/v3',
202+
message: 'Import from zod instead.',
203+
},
200204
],
201205
},
202206
],

0 commit comments

Comments
 (0)