Skip to content

Commit 63cab3f

Browse files
feat: server-side client state persistence
1 parent cacfb18 commit 63cab3f

35 files changed

+586
-121
lines changed

invokeai/app/api/dependencies.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
1111
from invokeai.app.services.boards.boards_default import BoardService
1212
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
13+
from invokeai.app.services.client_state_persistence.client_state_persistence_sqlite import ClientStatePersistenceSqlite
1314
from invokeai.app.services.config.config_default import InvokeAIAppConfig
1415
from invokeai.app.services.download.download_default import DownloadQueueService
1516
from invokeai.app.services.events.events_fastapievents import FastAPIEventService
@@ -151,6 +152,7 @@ def initialize(
151152
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
152153
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
153154
workflow_thumbnails = WorkflowThumbnailFileStorageDisk(workflow_thumbnails_folder)
155+
client_state_persistence = ClientStatePersistenceSqlite(db=db)
154156

155157
services = InvocationServices(
156158
board_image_records=board_image_records,
@@ -181,6 +183,7 @@ def initialize(
181183
style_preset_records=style_preset_records,
182184
style_preset_image_files=style_preset_image_files,
183185
workflow_thumbnails=workflow_thumbnails,
186+
client_state_persistence=client_state_persistence,
184187
)
185188

186189
ApiDependencies.invoker = Invoker(services)

invokeai/app/api/routers/app_info.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
from typing import Optional
66

77
import torch
8-
from fastapi import Body
8+
from fastapi import Body, HTTPException, Query
99
from fastapi.routing import APIRouter
10-
from pydantic import BaseModel, Field
10+
from pydantic import BaseModel, Field, JsonValue
1111

1212
from invokeai.app.api.dependencies import ApiDependencies
1313
from invokeai.app.invocations.upscale import ESRGAN_MODELS
@@ -173,3 +173,50 @@ async def disable_invocation_cache() -> None:
173173
async def get_invocation_cache_status() -> InvocationCacheStatus:
174174
"""Clears the invocation cache"""
175175
return ApiDependencies.invoker.services.invocation_cache.get_status()
176+
177+
178+
@app_router.get(
179+
"/client_state",
180+
operation_id="get_client_state_by_key",
181+
response_model=JsonValue | None,
182+
)
183+
async def get_client_state_by_key(
184+
key: str = Query(..., description="Key to retrieve from client state persistence"),
185+
) -> JsonValue | None:
186+
"""Gets the client state"""
187+
try:
188+
return ApiDependencies.invoker.services.client_state_persistence.get_by_key(key)
189+
except Exception as e:
190+
logging.error(f"Error getting client state: {e}")
191+
raise HTTPException(status_code=500, detail="Error setting client state")
192+
193+
194+
@app_router.post(
195+
"/client_state",
196+
operation_id="set_client_state",
197+
response_model=None,
198+
)
199+
async def set_client_state(
200+
key: str = Body(..., description="Key to set"),
201+
value: JsonValue = Body(..., description="Value of the key"),
202+
) -> None:
203+
"""Sets the client state"""
204+
try:
205+
ApiDependencies.invoker.services.client_state_persistence.set_by_key(key, value)
206+
except Exception as e:
207+
logging.error(f"Error setting client state: {e}")
208+
raise HTTPException(status_code=500, detail="Error setting client state")
209+
210+
211+
@app_router.delete(
212+
"/client_state",
213+
operation_id="delete_client_state",
214+
responses={204: {"description": "Client state deleted"}},
215+
)
216+
async def delete_client_state() -> None:
217+
"""Deletes the client state"""
218+
try:
219+
ApiDependencies.invoker.services.client_state_persistence.delete()
220+
except Exception as e:
221+
logging.error(f"Error deleting client state: {e}")
222+
raise HTTPException(status_code=500, detail="Error deleting client state")
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from abc import ABC, abstractmethod
2+
3+
from pydantic import JsonValue
4+
5+
6+
class ClientStatePersistenceABC(ABC):
7+
"""
8+
Base class for client persistence implementations.
9+
This class defines the interface for persisting client data.
10+
"""
11+
12+
@abstractmethod
13+
def set_by_key(self, key: str, value: JsonValue) -> None:
14+
"""
15+
Store the data for the client.
16+
17+
:param data: The client data to be stored.
18+
"""
19+
pass
20+
21+
@abstractmethod
22+
def get_by_key(self, key: str) -> JsonValue | None:
23+
"""
24+
Get the data for the client.
25+
26+
:return: The client data.
27+
"""
28+
pass
29+
30+
@abstractmethod
31+
def delete(self) -> None:
32+
"""
33+
Delete the data for the client.
34+
"""
35+
pass
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import json
2+
3+
from pydantic import JsonValue
4+
5+
from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
6+
from invokeai.app.services.invoker import Invoker
7+
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
8+
9+
10+
class ClientStatePersistenceSqlite(ClientStatePersistenceABC):
11+
"""
12+
Base class for client persistence implementations.
13+
This class defines the interface for persisting client data.
14+
"""
15+
16+
def __init__(self, db: SqliteDatabase) -> None:
17+
super().__init__()
18+
self._db = db
19+
self._default_row_id = 1
20+
21+
def start(self, invoker: Invoker) -> None:
22+
self._invoker = invoker
23+
24+
def set_by_key(self, key: str, value: JsonValue) -> None:
25+
state = self.get() or {}
26+
state.update({key: value})
27+
28+
with self._db.transaction() as cursor:
29+
cursor.execute(
30+
f"""
31+
INSERT INTO client_state (id, data)
32+
VALUES ({self._default_row_id}, ?)
33+
ON CONFLICT(id) DO UPDATE
34+
SET data = excluded.data;
35+
""",
36+
(json.dumps(state),),
37+
)
38+
39+
def get(self) -> dict[str, JsonValue] | None:
40+
with self._db.transaction() as cursor:
41+
cursor.execute(
42+
f"""
43+
SELECT data FROM client_state
44+
WHERE id = {self._default_row_id}
45+
"""
46+
)
47+
row = cursor.fetchone()
48+
if row is None:
49+
return None
50+
return json.loads(row[0])
51+
52+
def get_by_key(self, key: str) -> JsonValue | None:
53+
state = self.get()
54+
if state is None:
55+
return None
56+
return state.get(key, None)
57+
58+
def delete(self) -> None:
59+
with self._db.transaction() as cursor:
60+
cursor.execute(
61+
f"""
62+
DELETE FROM client_state
63+
WHERE id = {self._default_row_id}
64+
"""
65+
)

invokeai/app/services/invocation_services.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase
1818
from invokeai.app.services.boards.boards_base import BoardServiceABC
1919
from invokeai.app.services.bulk_download.bulk_download_base import BulkDownloadBase
20+
from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
2021
from invokeai.app.services.config import InvokeAIAppConfig
2122
from invokeai.app.services.download import DownloadQueueServiceBase
2223
from invokeai.app.services.events.events_base import EventServiceBase
@@ -73,6 +74,7 @@ def __init__(
7374
style_preset_records: "StylePresetRecordsStorageBase",
7475
style_preset_image_files: "StylePresetImageFileStorageBase",
7576
workflow_thumbnails: "WorkflowThumbnailServiceBase",
77+
client_state_persistence: "ClientStatePersistenceABC",
7678
):
7779
self.board_images = board_images
7880
self.board_image_records = board_image_records
@@ -102,3 +104,4 @@ def __init__(
102104
self.style_preset_records = style_preset_records
103105
self.style_preset_image_files = style_preset_image_files
104106
self.workflow_thumbnails = workflow_thumbnails
107+
self.client_state_persistence = client_state_persistence

invokeai/app/services/shared/sqlite/sqlite_util.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_18 import build_migration_18
2424
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_19 import build_migration_19
2525
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_20 import build_migration_20
26+
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_21 import build_migration_21
2627
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
2728

2829

@@ -63,6 +64,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
6364
migrator.register_migration(build_migration_18())
6465
migrator.register_migration(build_migration_19(app_config=config))
6566
migrator.register_migration(build_migration_20())
67+
migrator.register_migration(build_migration_21())
6668
migrator.run_migrations()
6769

6870
return db
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import sqlite3
2+
3+
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
4+
5+
6+
class Migration21Callback:
7+
def __call__(self, cursor: sqlite3.Cursor) -> None:
8+
cursor.execute(
9+
"""
10+
CREATE TABLE client_state (
11+
id INTEGER PRIMARY KEY CHECK(id = 1),
12+
data TEXT NOT NULL, -- Frontend will handle the shape of this data
13+
updated_at DATETIME NOT NULL DEFAULT (CURRENT_TIMESTAMP)
14+
);
15+
"""
16+
)
17+
cursor.execute(
18+
"""
19+
CREATE TRIGGER tg_client_state_updated_at
20+
AFTER UPDATE ON client_state
21+
FOR EACH ROW
22+
BEGIN
23+
UPDATE client_state
24+
SET updated_at = CURRENT_TIMESTAMP
25+
WHERE id = OLD.id;
26+
END;
27+
"""
28+
)
29+
30+
31+
def build_migration_21() -> Migration:
32+
"""Builds the migration object for migrating from version 20 to version 21. This includes:
33+
- Creating the `client_state` table.
34+
- Adding a trigger to update the `updated_at` field on updates.
35+
"""
36+
return Migration(
37+
from_version=20,
38+
to_version=21,
39+
callback=Migration21Callback(),
40+
)

invokeai/frontend/web/src/app/store/enhancers/reduxRemember/driver.ts

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import { StorageError } from 'app/store/enhancers/reduxRemember/errors';
2+
import { $authToken } from 'app/store/nanostores/authToken';
23
import { $projectId } from 'app/store/nanostores/projectId';
34
import type { UseStore } from 'idb-keyval';
45
import { clear, createStore as createIDBKeyValStore, get, set } from 'idb-keyval';
56
import { atom } from 'nanostores';
67
import type { Driver } from 'redux-remember';
8+
import { getBaseUrl } from 'services/api';
9+
import { buildAppInfoUrl } from 'services/api/endpoints/appInfo';
710

811
// Create a custom idb-keyval store (just needed to customize the name)
912
const $idbKeyValStore = atom<UseStore>(createIDBKeyValStore('invoke', 'invoke-store'));
@@ -38,3 +41,73 @@ export const idbKeyValDriver: Driver = {
3841
}
3942
},
4043
};
44+
45+
const getHeaders = (extra?: Record<string, string>) => {
46+
const headers = new Headers();
47+
const authToken = $authToken.get();
48+
if (authToken) {
49+
headers.set('Authorization', `Bearer ${authToken}`);
50+
}
51+
const projectId = $projectId.get();
52+
if (projectId) {
53+
headers.set('project-id', projectId);
54+
}
55+
for (const [key, value] of Object.entries(extra ?? {})) {
56+
headers.set(key, value);
57+
}
58+
return headers;
59+
};
60+
61+
export const serverBackedDriver: Driver = {
62+
getItem: async (key) => {
63+
try {
64+
const baseUrl = getBaseUrl();
65+
const path = buildAppInfoUrl('client_state', { key });
66+
const url = `${baseUrl}/${path}`;
67+
const headers = getHeaders();
68+
const res = await fetch(url, { headers, method: 'GET' });
69+
if (!res.ok) {
70+
throw new Error(`Response status: ${res.status}`);
71+
}
72+
const json = await res.json();
73+
return json;
74+
} catch (originalError) {
75+
throw new StorageError({
76+
key,
77+
projectId: $projectId.get(),
78+
originalError,
79+
});
80+
}
81+
},
82+
setItem: async (key, value) => {
83+
try {
84+
const baseUrl = getBaseUrl();
85+
const path = buildAppInfoUrl('client_state');
86+
const url = `${baseUrl}/${path}`;
87+
const headers = getHeaders({ 'content-type': 'application/json' });
88+
const res = await fetch(url, { headers, method: 'POST', body: JSON.stringify({ key, value }) });
89+
if (!res.ok) {
90+
throw new Error(`Response status: ${res.status}`);
91+
}
92+
return value;
93+
} catch (originalError) {
94+
throw new StorageError({
95+
key,
96+
value,
97+
projectId: $projectId.get(),
98+
originalError,
99+
});
100+
}
101+
},
102+
};
103+
104+
export const resetClientState = async () => {
105+
const baseUrl = getBaseUrl();
106+
const path = buildAppInfoUrl('client_state');
107+
const url = `${baseUrl}/${path}`;
108+
const headers = getHeaders();
109+
const res = await fetch(url, { headers, method: 'DELETE' });
110+
if (!res.ok) {
111+
throw new Error(`Response status: ${res.status}`);
112+
}
113+
};

0 commit comments

Comments
 (0)